Refactor server to use factory

Andrei Betlen
2023-05-01 22:38:46 -04:00
parent dd9ad1c759
commit 9eafc4c49a
3 changed files with 47 additions and 31 deletions

tests/test_llama.py

@@ -22,9 +22,11 @@ def test_llama_patch(monkeypatch):
     ## Set up mock function
     def mock_eval(*args, **kwargs):
         return 0

     def mock_get_logits(*args, **kwargs):
-        return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])
+        return (llama_cpp.c_float * n_vocab)(
+            *[llama_cpp.c_float(0) for _ in range(n_vocab)]
+        )

     monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
     monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
@@ -88,6 +90,7 @@ def test_llama_patch(monkeypatch):
 def test_llama_pickle():
     import pickle
     import tempfile
+
     fp = tempfile.TemporaryFile()
     llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
     pickle.dump(llama, fp)
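
Aside (not part of the diff): pickling to a TemporaryFile round-trips only if the file is rewound before loading, which is the pattern this test relies on. A minimal sketch, with a plain dict as a hypothetical stand-in for the Llama instance:

import pickle
import tempfile

obj = {"vocab_only": True}  # hypothetical stand-in for the Llama object
fp = tempfile.TemporaryFile()
pickle.dump(obj, fp)
fp.seek(0)  # rewind; pickle.load reads from the current file position
assert pickle.load(fp) == obj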
@@ -101,6 +104,7 @@ def test_llama_pickle():
     assert llama.detokenize(llama.tokenize(text)) == text

+
 def test_utf8(monkeypatch):
     llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
     n_vocab = int(llama_cpp.llama_n_vocab(llama.ctx))
@@ -110,7 +114,9 @@ def test_utf8(monkeypatch):
         return 0

     def mock_get_logits(*args, **kwargs):
-        return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])
+        return (llama_cpp.c_float * n_vocab)(
+            *[llama_cpp.c_float(0) for _ in range(n_vocab)]
+        )

     monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
     monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
@@ -143,11 +149,12 @@ def test_utf8(monkeypatch):
 def test_llama_server():
     from fastapi.testclient import TestClient
-    from llama_cpp.server.app import app, init_llama, Settings
-    s = Settings()
-    s.model = MODEL
-    s.vocab_only = True
-    init_llama(s)
+    from llama_cpp.server.app import create_app, Settings
+
+    settings = Settings()
+    settings.model = MODEL
+    settings.vocab_only = True
+    app = create_app(settings)
     client = TestClient(app)
     response = client.get("/v1/models")
     assert response.json() == {
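
Note on the refactor: the llama_cpp/server/app.py side is not shown in this excerpt. The diff above shows its effect, though: instead of importing a module-level app and mutating it through init_llama, the test now builds its own app from a Settings object via create_app. A minimal sketch of an app factory in that style, assuming FastAPI plus pydantic v1 BaseSettings; the defaults and the route body are illustrative, not the library's actual implementation:

from typing import Optional

from fastapi import FastAPI
from pydantic import BaseSettings


class Settings(BaseSettings):
    model: str = "model.bin"
    vocab_only: bool = False


def create_app(settings: Optional[Settings] = None) -> FastAPI:
    # Build and return a fresh, fully configured app rather than
    # mutating a shared module-level instance.
    if settings is None:
        settings = Settings()
    app = FastAPI()

    @app.get("/v1/models")
    def list_models():
        # Illustrative response shape only.
        return {"object": "list", "data": [{"id": settings.model, "object": "model"}]}

    return app

The factory keeps tests isolated: each test constructs an app from its own Settings, exactly as the updated test_llama_server does before handing the app to TestClient.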