Merge branch 'main' into better-server-params-and-fields

This commit is contained in:
Andrei
2023-05-01 22:45:57 -04:00
committed by GitHub
5 changed files with 50 additions and 39 deletions

View File

@@ -306,7 +306,7 @@ class Llama:
llama_cpp.llama_sample_typical(
ctx=self.ctx,
candidates=llama_cpp.ctypes.pointer(candidates),
p=llama_cpp.c_float(1.0)
p=llama_cpp.c_float(1.0),
)
llama_cpp.llama_sample_top_p(
ctx=self.ctx,
@@ -637,10 +637,7 @@ class Llama:
self.detokenize([token]).decode("utf-8", errors="ignore")
for token in all_tokens
]
all_logprobs = [
Llama._logits_to_logprobs(row)
for row in self.eval_logits
]
all_logprobs = [Llama._logits_to_logprobs(row) for row in self.eval_logits]
for token, token_str, logprobs_token in zip(
all_tokens, all_token_strs, all_logprobs
):

View File

@@ -24,10 +24,10 @@ Then visit http://localhost:8000/docs to see the interactive API docs.
import os
import uvicorn
from llama_cpp.server.app import app, init_llama
from llama_cpp.server.app import create_app
if __name__ == "__main__":
init_llama()
app = create_app()
uvicorn.run(
app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))

View File

@@ -2,18 +2,18 @@ import os
import json
from threading import Lock
from typing import List, Optional, Union, Iterator, Dict
from typing_extensions import TypedDict, Literal
from typing_extensions import TypedDict, Literal, Annotated
import llama_cpp
from fastapi import Depends, FastAPI
from fastapi import Depends, FastAPI, APIRouter
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
from sse_starlette.sse import EventSourceResponse
class Settings(BaseSettings):
model: str = os.environ.get("MODEL", "null")
model: str
n_ctx: int = 2048
n_batch: int = 512
n_threads: int = max((os.cpu_count() or 2) // 2, 1)
@@ -27,6 +27,14 @@ class Settings(BaseSettings):
vocab_only: bool = False
router = APIRouter()
llama: Optional[llama_cpp.Llama] = None
def create_app(settings: Optional[Settings] = None):
if settings is None:
settings = Settings()
app = FastAPI(
title="🦙 llama.cpp Python API",
version="0.0.1",
@@ -38,14 +46,10 @@ app.add_middleware(
allow_methods=["*"],
allow_headers=["*"],
)
llama: llama_cpp.Llama = None
def init_llama(settings: Settings = None):
if settings is None:
settings = Settings()
app.include_router(router)
global llama
llama = llama_cpp.Llama(
settings.model,
model_path=settings.model,
f16_kv=settings.f16_kv,
use_mlock=settings.use_mlock,
use_mmap=settings.use_mmap,
@@ -60,8 +64,12 @@ def init_llama(settings: Settings = None):
if settings.cache:
cache = llama_cpp.LlamaCache()
llama.set_cache(cache)
return app
llama_lock = Lock()
def get_llama():
with llama_lock:
yield llama
@@ -117,8 +125,6 @@ repeat_penalty_field = Field(
"Repeat penalty is a hyperparameter used to penalize the repetition of token sequences during text generation. It helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient."
)
class CreateCompletionRequest(BaseModel):
prompt: Union[str, List[str]] = Field(
default="",
@@ -162,7 +168,7 @@ class CreateCompletionRequest(BaseModel):
CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
@app.post(
@router.post(
"/v1/completions",
response_model=CreateCompletionResponse,
)
@@ -204,7 +210,7 @@ class CreateEmbeddingRequest(BaseModel):
CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
@app.post(
@router.post(
"/v1/embeddings",
response_model=CreateEmbeddingResponse,
)
@@ -257,7 +263,7 @@ class CreateChatCompletionRequest(BaseModel):
CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion)
@app.post(
@router.post(
"/v1/chat/completions",
response_model=CreateChatCompletionResponse,
)
@@ -306,7 +312,7 @@ class ModelList(TypedDict):
GetModelResponse = create_model_from_typeddict(ModelList)
@app.get("/v1/models", response_model=GetModelResponse)
@router.get("/v1/models", response_model=GetModelResponse)
def get_models() -> ModelList:
return {
"object": "list",

View File

@@ -24,7 +24,9 @@ def test_llama_patch(monkeypatch):
return 0
def mock_get_logits(*args, **kwargs):
return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])
return (llama_cpp.c_float * n_vocab)(
*[llama_cpp.c_float(0) for _ in range(n_vocab)]
)
monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
@@ -88,6 +90,7 @@ def test_llama_patch(monkeypatch):
def test_llama_pickle():
import pickle
import tempfile
fp = tempfile.TemporaryFile()
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
pickle.dump(llama, fp)
@@ -101,6 +104,7 @@ def test_llama_pickle():
assert llama.detokenize(llama.tokenize(text)) == text
def test_utf8(monkeypatch):
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
n_vocab = int(llama_cpp.llama_n_vocab(llama.ctx))
@@ -110,7 +114,9 @@ def test_utf8(monkeypatch):
return 0
def mock_get_logits(*args, **kwargs):
return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])
return (llama_cpp.c_float * n_vocab)(
*[llama_cpp.c_float(0) for _ in range(n_vocab)]
)
monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
@@ -143,11 +149,13 @@ def test_utf8(monkeypatch):
def test_llama_server():
from fastapi.testclient import TestClient
from llama_cpp.server.app import app, init_llama, Settings
s = Settings()
s.model = MODEL
s.vocab_only = True
init_llama(s)
from llama_cpp.server.app import create_app, Settings
settings = Settings(
model=MODEL,
vocab_only=True,
)
app = create_app(settings)
client = TestClient(app)
response = client.get("/v1/models")
assert response.json() == {