diff --git a/README.md b/README.md
index 80518f6..9ee9199 100644
--- a/README.md
+++ b/README.md
@@ -90,7 +90,7 @@ This package is under active development and I welcome any contributions.
 To get started, clone the repository and install the package in development mode:
 
 ```bash
-git clone --recurse-submodules git@github.com:abetlen/llama-cpp-python.git
+git clone --recurse-submodules git@github.com:abetlen/llama-cpp-python.git
 # Will need to be re-run any time vendor/llama.cpp is updated
 python3 setup.py develop
 ```
diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index bec5be7..d201013 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -306,7 +306,7 @@ class Llama:
             llama_cpp.llama_sample_typical(
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.pointer(candidates),
-                p=llama_cpp.c_float(1.0)
+                p=llama_cpp.c_float(1.0),
             )
             llama_cpp.llama_sample_top_p(
                 ctx=self.ctx,
@@ -637,10 +637,7 @@ class Llama:
                 self.detokenize([token]).decode("utf-8", errors="ignore")
                 for token in all_tokens
             ]
-            all_logprobs = [
-                Llama._logits_to_logprobs(row)
-                for row in self.eval_logits
-            ]
+            all_logprobs = [Llama._logits_to_logprobs(row) for row in self.eval_logits]
             for token, token_str, logprobs_token in zip(
                 all_tokens, all_token_strs, all_logprobs
             ):
diff --git a/llama_cpp/server/__main__.py b/llama_cpp/server/__main__.py
index f57d68c..4fbee37 100644
--- a/llama_cpp/server/__main__.py
+++ b/llama_cpp/server/__main__.py
@@ -24,10 +24,10 @@ Then visit http://localhost:8000/docs to see the interactive API docs.
 import os
 import uvicorn
 
-from llama_cpp.server.app import app, init_llama
+from llama_cpp.server.app import create_app
 
 if __name__ == "__main__":
-    init_llama()
+    app = create_app()
 
     uvicorn.run(
         app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))
diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py
index 886ee6d..ef8aa4e 100644
--- a/llama_cpp/server/app.py
+++ b/llama_cpp/server/app.py
@@ -2,18 +2,18 @@ import os
 import json
 from threading import Lock
 from typing import List, Optional, Union, Iterator, Dict
-from typing_extensions import TypedDict, Literal
+from typing_extensions import TypedDict, Literal, Annotated
 
 import llama_cpp
 
-from fastapi import Depends, FastAPI
+from fastapi import Depends, FastAPI, APIRouter
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
 from sse_starlette.sse import EventSourceResponse
 
 
 class Settings(BaseSettings):
-    model: str = os.environ.get("MODEL", "null")
+    model: str
     n_ctx: int = 2048
     n_batch: int = 512
     n_threads: int = max((os.cpu_count() or 2) // 2, 1)
@@ -27,25 +27,29 @@ class Settings(BaseSettings):
     vocab_only: bool = False
 
-app = FastAPI(
-    title="🦙 llama.cpp Python API",
-    version="0.0.1",
-)
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
+router = APIRouter()
 
-llama: llama_cpp.Llama = None
 
-def init_llama(settings: Settings = None):
+llama: Optional[llama_cpp.Llama] = None
+
+
+def create_app(settings: Optional[Settings] = None):
     if settings is None:
         settings = Settings()
+    app = FastAPI(
+        title="🦙 llama.cpp Python API",
+        version="0.0.1",
+    )
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=["*"],
+        allow_credentials=True,
+        allow_methods=["*"],
+        allow_headers=["*"],
+    )
+    app.include_router(router)
     global llama
     llama = llama_cpp.Llama(
-        settings.model,
+        model_path=settings.model,
         f16_kv=settings.f16_kv,
         use_mlock=settings.use_mlock,
         use_mmap=settings.use_mmap,
@@ -60,8 +64,12 @@ def init_llama(settings: Settings = None):
     if settings.cache:
         cache = llama_cpp.LlamaCache()
         llama.set_cache(cache)
+    return app
+
 
 llama_lock = Lock()
+
+
 def get_llama():
     with llama_lock:
         yield llama
@@ -117,8 +125,6 @@ repeat_penalty_field = Field(
     "Repeat penalty is a hyperparameter used to penalize the repetition of token sequences during text generation. It helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient."
 )
 
-
-
 class CreateCompletionRequest(BaseModel):
     prompt: Union[str, List[str]] = Field(
         default="",
@@ -162,7 +168,7 @@ class CreateCompletionRequest(BaseModel):
 CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
 
 
-@app.post(
+@router.post(
     "/v1/completions",
     response_model=CreateCompletionResponse,
 )
@@ -204,7 +210,7 @@ class CreateEmbeddingRequest(BaseModel):
 CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
 
 
-@app.post(
+@router.post(
     "/v1/embeddings",
     response_model=CreateEmbeddingResponse,
 )
@@ -257,7 +263,7 @@ class CreateChatCompletionRequest(BaseModel):
 CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion)
 
 
-@app.post(
+@router.post(
     "/v1/chat/completions",
     response_model=CreateChatCompletionResponse,
 )
@@ -306,7 +312,7 @@ class ModelList(TypedDict):
 GetModelResponse = create_model_from_typeddict(ModelList)
 
 
-@app.get("/v1/models", response_model=GetModelResponse)
+@router.get("/v1/models", response_model=GetModelResponse)
 def get_models() -> ModelList:
     return {
         "object": "list",
diff --git a/tests/test_llama.py b/tests/test_llama.py
index 2bf38b3..b3426b8 100644
--- a/tests/test_llama.py
+++ b/tests/test_llama.py
@@ -22,9 +22,11 @@ def test_llama_patch(monkeypatch):
     ## Set up mock function
     def mock_eval(*args, **kwargs):
         return 0
-
+
     def mock_get_logits(*args, **kwargs):
-        return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])
+        return (llama_cpp.c_float * n_vocab)(
+            *[llama_cpp.c_float(0) for _ in range(n_vocab)]
+        )
 
     monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
     monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
@@ -88,6 +90,7 @@ def test_llama_patch(monkeypatch):
 def test_llama_pickle():
     import pickle
     import tempfile
+
     fp = tempfile.TemporaryFile()
     llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
     pickle.dump(llama, fp)
@@ -101,6 +104,7 @@ def test_llama_pickle():
 
     assert llama.detokenize(llama.tokenize(text)) == text
 
+
 def test_utf8(monkeypatch):
     llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
     n_vocab = int(llama_cpp.llama_n_vocab(llama.ctx))
@@ -110,7 +114,9 @@ def test_utf8(monkeypatch):
         return 0
 
     def mock_get_logits(*args, **kwargs):
-        return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])
+        return (llama_cpp.c_float * n_vocab)(
+            *[llama_cpp.c_float(0) for _ in range(n_vocab)]
+        )
 
     monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
     monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
@@ -143,11 +149,13 @@ def test_utf8(monkeypatch):
 
 def test_llama_server():
     from fastapi.testclient import TestClient
-    from llama_cpp.server.app import app, init_llama, Settings
-    s = Settings()
-    s.model = MODEL
-    s.vocab_only = True
-    init_llama(s)
+    from llama_cpp.server.app import create_app, Settings
+
+    settings = Settings(
+        model=MODEL,
+        vocab_only=True,
+    )
+    app = create_app(settings)
     client = TestClient(app)
     response = client.get("/v1/models")
     assert response.json() == {
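Not part of the patch: a minimal usage sketch of the `create_app()` factory introduced above, mirroring the updated `llama_cpp/server/__main__.py` and `tests/test_llama.py`. The model path below is a placeholder; since `Settings.model` no longer has a default, it must be supplied explicitly or (via pydantic's `BaseSettings`) through the `MODEL` environment variable.

```python
# Sketch of consuming the new create_app() factory; the model path is illustrative.
import uvicorn

from llama_cpp.server.app import Settings, create_app

# Settings now requires a model path; pass it directly or export MODEL=... instead.
settings = Settings(model="./models/7B/ggml-model.bin")
app = create_app(settings)  # builds the FastAPI app, middleware, router, and Llama instance

if __name__ == "__main__":
    uvicorn.run(app, host="localhost", port=8000)
```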