Mirror of https://github.com/abetlen/llama-cpp-python.git (synced 2023-09-07 17:34:22 +03:00)
Refactor server to use factory
--- a/llama_cpp/server/__main__.py
+++ b/llama_cpp/server/__main__.py
@@ -24,10 +24,10 @@ Then visit http://localhost:8000/docs to see the interactive API docs.
 import os
 import uvicorn
 
-from llama_cpp.server.app import app, init_llama
+from llama_cpp.server.app import create_app
 
 if __name__ == "__main__":
-    init_llama()
+    app = create_app()
 
     uvicorn.run(
         app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))
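For context, a minimal sketch (not part of the commit) of driving the refactored entrypoint programmatically; the model path below is a placeholder, and Settings reads it from the MODEL environment variable:

# Sketch only: run the server without `python3 -m llama_cpp.server`.
# Assumes llama-cpp-python is installed; the model path is a placeholder.
import os
import uvicorn
from llama_cpp.server.app import create_app

os.environ["MODEL"] = "./models/ggml-model-q4_0.bin"  # hypothetical path
app = create_app()  # Settings() picks up MODEL from the environment
uvicorn.run(app, host="localhost", port=8000)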
--- a/llama_cpp/server/app.py
+++ b/llama_cpp/server/app.py
@@ -2,18 +2,18 @@ import os
 import json
 from threading import Lock
 from typing import List, Optional, Union, Iterator, Dict
-from typing_extensions import TypedDict, Literal
+from typing_extensions import TypedDict, Literal, Annotated
 
 import llama_cpp
 
-from fastapi import Depends, FastAPI
+from fastapi import Depends, FastAPI, APIRouter
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
 from sse_starlette.sse import EventSourceResponse
 
 
 class Settings(BaseSettings):
-    model: str = os.environ.get("MODEL", "null")
+    model: str
     n_ctx: int = 2048
     n_batch: int = 512
     n_threads: int = max((os.cpu_count() or 2) // 2, 1)
@@ -27,25 +27,29 @@ class Settings(BaseSettings):
     vocab_only: bool = False
 
 
-app = FastAPI(
-    title="🦙 llama.cpp Python API",
-    version="0.0.1",
-)
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
+router = APIRouter()
 
-llama: llama_cpp.Llama = None
-def init_llama(settings: Settings = None):
+llama: Optional[llama_cpp.Llama] = None
+
+
+def create_app(settings: Optional[Settings] = None):
     if settings is None:
         settings = Settings()
+    app = FastAPI(
+        title="🦙 llama.cpp Python API",
+        version="0.0.1",
+    )
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=["*"],
+        allow_credentials=True,
+        allow_methods=["*"],
+        allow_headers=["*"],
+    )
+    app.include_router(router)
     global llama
     llama = llama_cpp.Llama(
-        settings.model,
+        model_path=settings.model,
         f16_kv=settings.f16_kv,
         use_mlock=settings.use_mlock,
         use_mmap=settings.use_mmap,
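One motivation for the factory, sketched under the assumption that a local model file exists (this snippet is not part of the commit): an app can now be built with explicit Settings, e.g. in tests, instead of importing a module-level FastAPI instance.

# Sketch only: building an app with explicit settings, e.g. in a test.
from fastapi.testclient import TestClient
from llama_cpp.server.app import Settings, create_app

# "./models/ggml-model-q4_0.bin" is a placeholder path.
app = create_app(Settings(model="./models/ggml-model-q4_0.bin", n_ctx=512))
client = TestClient(app)
print(client.get("/v1/models").json())  # routes are registered via the shared router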
@@ -60,12 +64,17 @@ def init_llama(settings: Settings = None):
     if settings.cache:
         cache = llama_cpp.LlamaCache()
         llama.set_cache(cache)
+    return app
+
+
 llama_lock = Lock()
 
+
 def get_llama():
     with llama_lock:
         yield llama
 
+
 class CreateCompletionRequest(BaseModel):
     prompt: Union[str, List[str]]
     suffix: Optional[str] = Field(None)
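How the get_llama() dependency is meant to be consumed, as an illustrative sketch; the route and the llama.model_path attribute are assumptions for the example and are not part of the commit.

# Sketch only (hypothetical route): routes added to the shared router before
# create_app() runs are picked up by app.include_router(router).
import llama_cpp
from fastapi import Depends
from llama_cpp.server.app import create_app, get_llama, router

@router.get("/v1/internal/model-path")  # hypothetical path, for illustration
def model_path(llama: llama_cpp.Llama = Depends(get_llama)):
    # get_llama() holds llama_lock while yielding, so access is serialized.
    return {"model_path": llama.model_path}

app = create_app()  # MODEL must be set in the environment for Settings()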
@@ -102,7 +111,7 @@ class CreateCompletionRequest(BaseModel):
 CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
 
 
-@app.post(
+@router.post(
     "/v1/completions",
     response_model=CreateCompletionResponse,
 )
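For reference, a hedged example of calling the completions route once the server is running on localhost:8000; the request fields follow the CreateCompletionRequest model above, and the snippet is not part of the commit.

# Sketch only: a client call against the running server.
import requests  # assumes the requests package is available

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={"prompt": "Q: Name the planets in the solar system. A: ", "max_tokens": 64},
)
print(resp.json()["choices"][0]["text"])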
@@ -148,7 +157,7 @@ class CreateEmbeddingRequest(BaseModel):
 CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
 
 
-@app.post(
+@router.post(
     "/v1/embeddings",
     response_model=CreateEmbeddingResponse,
 )
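A similar sketch for the embeddings route (not part of the commit); it assumes the model was loaded with embeddings enabled in Settings.

# Sketch only: requesting an embedding from the same server.
import requests

resp = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={"input": "The quick brown fox"},
)
print(len(resp.json()["data"][0]["embedding"]))  # embedding dimension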
@@ -202,7 +211,7 @@ class CreateChatCompletionRequest(BaseModel):
 CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion)
 
 
-@app.post(
+@router.post(
     "/v1/chat/completions",
     response_model=CreateChatCompletionResponse,
 )
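And a sketch of the chat route, which follows the OpenAI-style message format (not part of the commit).

# Sketch only: an OpenAI-style chat request against the new router path.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={"messages": [{"role": "user", "content": "What is the capital of France?"}]},
)
print(resp.json()["choices"][0]["message"]["content"])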
@@ -256,7 +265,7 @@ class ModelList(TypedDict):
 GetModelResponse = create_model_from_typeddict(ModelList)
 
 
-@app.get("/v1/models", response_model=GetModelResponse)
+@router.get("/v1/models", response_model=GetModelResponse)
 def get_models() -> ModelList:
     return {
         "object": "list",