mirror of https://github.com/abetlen/llama-cpp-python.git

Merge branch 'main' into better-server-params-and-fields
@@ -90,7 +90,7 @@ This package is under active development and I welcome any contributions.
 To get started, clone the repository and install the package in development mode:
 
 ```bash
 git clone --recurse-submodules git@github.com:abetlen/llama-cpp-python.git
 # Will need to be re-run any time vendor/llama.cpp is updated
 python3 setup.py develop
 ```
@@ -306,7 +306,7 @@ class Llama:
         llama_cpp.llama_sample_typical(
             ctx=self.ctx,
             candidates=llama_cpp.ctypes.pointer(candidates),
-            p=llama_cpp.c_float(1.0)
+            p=llama_cpp.c_float(1.0),
         )
         llama_cpp.llama_sample_top_p(
             ctx=self.ctx,
@@ -637,10 +637,7 @@ class Llama:
             self.detokenize([token]).decode("utf-8", errors="ignore")
             for token in all_tokens
         ]
-        all_logprobs = [
-            Llama._logits_to_logprobs(row)
-            for row in self.eval_logits
-        ]
+        all_logprobs = [Llama._logits_to_logprobs(row) for row in self.eval_logits]
         for token, token_str, logprobs_token in zip(
             all_tokens, all_token_strs, all_logprobs
         ):
@@ -24,10 +24,10 @@ Then visit http://localhost:8000/docs to see the interactive API docs.
 import os
 import uvicorn
 
-from llama_cpp.server.app import app, init_llama
+from llama_cpp.server.app import create_app
 
 if __name__ == "__main__":
-    init_llama()
+    app = create_app()
 
     uvicorn.run(
         app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))
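For context, the entry point above now builds the application through the new `create_app()` factory instead of importing a module-level `app`. A minimal sketch (not part of the commit) of embedding the refactored server in your own script, assuming `create_app` and `Settings` behave as shown in this diff; the model path is a placeholder:

```python
# Sketch only: run the API from your own code via the new app factory.
import uvicorn

from llama_cpp.server.app import Settings, create_app

settings = Settings(model="./models/ggml-model-q4_0.bin")  # placeholder path
app = create_app(settings)

if __name__ == "__main__":
    uvicorn.run(app, host="localhost", port=8000)
```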
@@ -2,18 +2,18 @@ import os
 import json
 from threading import Lock
 from typing import List, Optional, Union, Iterator, Dict
-from typing_extensions import TypedDict, Literal
+from typing_extensions import TypedDict, Literal, Annotated
 
 import llama_cpp
 
-from fastapi import Depends, FastAPI
+from fastapi import Depends, FastAPI, APIRouter
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
 from sse_starlette.sse import EventSourceResponse
 
 
 class Settings(BaseSettings):
-    model: str = os.environ.get("MODEL", "null")
+    model: str
     n_ctx: int = 2048
     n_batch: int = 512
     n_threads: int = max((os.cpu_count() or 2) // 2, 1)
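With the hard-coded `os.environ.get("MODEL", "null")` default removed, `model` becomes a required field. Pydantic v1 `BaseSettings` still reads it from the `MODEL` environment variable (env lookup is case-insensitive by default), or it can be passed explicitly. A small sketch of both options, using a placeholder path:

```python
# Sketch: two ways to satisfy the now-required `model` setting.
import os

from llama_cpp.server.app import Settings

# 1) Explicit argument.
settings = Settings(model="./models/ggml-model-q4_0.bin")  # placeholder path

# 2) Environment variable picked up by BaseSettings.
os.environ["MODEL"] = "./models/ggml-model-q4_0.bin"
settings_from_env = Settings()
```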
@@ -27,25 +27,29 @@ class Settings(BaseSettings):
     vocab_only: bool = False
 
 
-app = FastAPI(
-    title="🦙 llama.cpp Python API",
-    version="0.0.1",
-)
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
+router = APIRouter()
 
-llama: llama_cpp.Llama = None
-def init_llama(settings: Settings = None):
+llama: Optional[llama_cpp.Llama] = None
+
+
+def create_app(settings: Optional[Settings] = None):
     if settings is None:
         settings = Settings()
+    app = FastAPI(
+        title="🦙 llama.cpp Python API",
+        version="0.0.1",
+    )
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=["*"],
+        allow_credentials=True,
+        allow_methods=["*"],
+        allow_headers=["*"],
+    )
+    app.include_router(router)
     global llama
     llama = llama_cpp.Llama(
-        settings.model,
+        model_path=settings.model,
         f16_kv=settings.f16_kv,
         use_mlock=settings.use_mlock,
         use_mmap=settings.use_mmap,
@@ -60,8 +64,12 @@ def init_llama(settings: Settings = None):
     if settings.cache:
         cache = llama_cpp.LlamaCache()
         llama.set_cache(cache)
+    return app
 
 
 llama_lock = Lock()
 
 
 def get_llama():
     with llama_lock:
         yield llama
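The two hunks above switch the server to an application-factory pattern: routes hang off a module-level `APIRouter`, the shared `Llama` instance reaches handlers through the `get_llama` dependency (guarded by a lock), and `create_app()` wires everything into a fresh `FastAPI` instance and returns it. A stripped-down, generic sketch of that pattern follows; the `/v1/ping` endpoint is hypothetical and only demonstrates `Depends(get_llama)`:

```python
# Generic illustration of the router + factory + dependency pattern used above.
from threading import Lock
from typing import Optional

from fastapi import APIRouter, Depends, FastAPI

router = APIRouter()
llama: Optional[object] = None  # set by create_app(); stands in for llama_cpp.Llama
llama_lock = Lock()


def get_llama():
    # Hand the shared model to a request handler, one request at a time.
    with llama_lock:
        yield llama


@router.get("/v1/ping")  # hypothetical route, not part of the commit
def ping(model=Depends(get_llama)):
    return {"model_loaded": model is not None}


def create_app():
    app = FastAPI(title="demo")
    app.include_router(router)
    global llama
    llama = object()  # stand-in for llama_cpp.Llama(...)
    return app
```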
@@ -117,8 +125,6 @@ repeat_penalty_field = Field(
     "Repeat penalty is a hyperparameter used to penalize the repetition of token sequences during text generation. It helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient."
 )
 
-
-
 class CreateCompletionRequest(BaseModel):
     prompt: Union[str, List[str]] = Field(
         default="",
@@ -162,7 +168,7 @@ class CreateCompletionRequest(BaseModel):
 CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
 
 
-@app.post(
+@router.post(
     "/v1/completions",
     response_model=CreateCompletionResponse,
 )
@@ -204,7 +210,7 @@ class CreateEmbeddingRequest(BaseModel):
 CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
 
 
-@app.post(
+@router.post(
     "/v1/embeddings",
     response_model=CreateEmbeddingResponse,
 )
@@ -257,7 +263,7 @@ class CreateChatCompletionRequest(BaseModel):
 CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion)
 
 
-@app.post(
+@router.post(
     "/v1/chat/completions",
     response_model=CreateChatCompletionResponse,
 )
@@ -306,7 +312,7 @@ class ModelList(TypedDict):
 GetModelResponse = create_model_from_typeddict(ModelList)
 
 
-@app.get("/v1/models", response_model=GetModelResponse)
+@router.get("/v1/models", response_model=GetModelResponse)
 def get_models() -> ModelList:
     return {
         "object": "list",
@@ -22,9 +22,11 @@ def test_llama_patch(monkeypatch):
     ## Set up mock function
     def mock_eval(*args, **kwargs):
         return 0
 
     def mock_get_logits(*args, **kwargs):
-        return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])
+        return (llama_cpp.c_float * n_vocab)(
+            *[llama_cpp.c_float(0) for _ in range(n_vocab)]
+        )
 
     monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
     monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
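The reformatted `mock_get_logits` relies on the standard ctypes idiom in which `c_float * n_vocab` creates an array type and calling that type with values allocates a filled array. A quick standalone illustration of just that idiom (plain ctypes, independent of llama_cpp):

```python
# ctypes arrays: (c_float * n) is a type; calling it with n values builds the array.
import ctypes

n = 4
FloatArray = ctypes.c_float * n
zeros = FloatArray(*[ctypes.c_float(0) for _ in range(n)])
assert [v for v in zeros] == [0.0, 0.0, 0.0, 0.0]
```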
@@ -88,6 +90,7 @@ def test_llama_patch(monkeypatch):
 def test_llama_pickle():
     import pickle
     import tempfile
+
     fp = tempfile.TemporaryFile()
     llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
     pickle.dump(llama, fp)
@@ -101,6 +104,7 @@ def test_llama_pickle():
 
     assert llama.detokenize(llama.tokenize(text)) == text
 
+
 def test_utf8(monkeypatch):
     llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
     n_vocab = int(llama_cpp.llama_n_vocab(llama.ctx))
@@ -110,7 +114,9 @@ def test_utf8(monkeypatch):
         return 0
 
     def mock_get_logits(*args, **kwargs):
-        return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])
+        return (llama_cpp.c_float * n_vocab)(
+            *[llama_cpp.c_float(0) for _ in range(n_vocab)]
+        )
 
     monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
     monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
@@ -143,11 +149,13 @@ def test_utf8(monkeypatch):
 
 def test_llama_server():
     from fastapi.testclient import TestClient
-    from llama_cpp.server.app import app, init_llama, Settings
-    s = Settings()
-    s.model = MODEL
-    s.vocab_only = True
-    init_llama(s)
+    from llama_cpp.server.app import create_app, Settings
+
+    settings = Settings(
+        model=MODEL,
+        vocab_only=True,
+    )
+    app = create_app(settings)
     client = TestClient(app)
     response = client.get("/v1/models")
     assert response.json() == {
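With the factory in place, tests and callers can build an isolated app per configuration, as this hunk shows. A hedged sketch of exercising the completion route through the same `TestClient`, assuming the request body accepts the `prompt` field defined earlier in this diff and that a real (non-`vocab_only`) model is loaded; the model path is a placeholder:

```python
# Sketch: hitting /v1/completions on an app built by create_app().
from fastapi.testclient import TestClient

from llama_cpp.server.app import Settings, create_app

app = create_app(Settings(model="./models/ggml-model-q4_0.bin"))  # placeholder path
client = TestClient(app)
response = client.post("/v1/completions", json={"prompt": "Hello"})
print(response.json())
```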