From e1b5b9bb0422b4536fa949166265ebdfcff11362 Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Wed, 5 Apr 2023 14:44:26 -0400
Subject: [PATCH] Update fastapi server example

---
 examples/high_level_api/fastapi_server.py | 93 +++++++++++++++++++++--
 1 file changed, 87 insertions(+), 6 deletions(-)

diff --git a/examples/high_level_api/fastapi_server.py b/examples/high_level_api/fastapi_server.py
index 760a6ca..b7d2565 100644
--- a/examples/high_level_api/fastapi_server.py
+++ b/examples/high_level_api/fastapi_server.py
@@ -13,7 +13,8 @@ Then visit http://localhost:8000/docs to see the interactive API docs.
 """
 import os
 import json
-from typing import List, Optional, Literal, Union, Iterator
+from typing import List, Optional, Literal, Union, Iterator, Dict
+from typing_extensions import TypedDict
 
 import llama_cpp
 
@@ -64,13 +65,24 @@ class CreateCompletionRequest(BaseModel):
     max_tokens: int = 16
     temperature: float = 0.8
     top_p: float = 0.95
-    logprobs: Optional[int] = Field(None)
     echo: bool = False
     stop: List[str] = []
-    repeat_penalty: float = 1.1
-    top_k: int = 40
     stream: bool = False
 
+    # ignored or currently unsupported
+    model: Optional[str] = Field(None)
+    n: Optional[int] = 1
+    logprobs: Optional[int] = Field(None)
+    presence_penalty: Optional[float] = 0
+    frequency_penalty: Optional[float] = 0
+    best_of: Optional[int] = 1
+    logit_bias: Optional[Dict[str, float]] = Field(None)
+    user: Optional[str] = Field(None)
+
+    # llama.cpp specific parameters
+    top_k: int = 40
+    repeat_penalty: float = 1.1
+
     class Config:
         schema_extra = {
             "example": {
@@ -91,7 +103,20 @@ def create_completion(request: CreateCompletionRequest):
     if request.stream:
         chunks: Iterator[llama_cpp.CompletionChunk] = llama(**request.dict())  # type: ignore
         return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks)
-    return llama(**request.dict())
+    return llama(
+        **request.dict(
+            exclude={
+                "model",
+                "n",
+                "logprobs",
+                "frequency_penalty",
+                "presence_penalty",
+                "best_of",
+                "logit_bias",
+                "user",
+            }
+        )
+    )
 
 
 class CreateEmbeddingRequest(BaseModel):
@@ -132,6 +157,16 @@ class CreateChatCompletionRequest(BaseModel):
     stream: bool = False
     stop: List[str] = []
     max_tokens: int = 128
+
+    # ignored or currently unsupported
+    model: Optional[str] = Field(None)
+    n: Optional[int] = 1
+    presence_penalty: Optional[float] = 0
+    frequency_penalty: Optional[float] = 0
+    logit_bias: Optional[Dict[str, float]] = Field(None)
+    user: Optional[str] = Field(None)
+
+    # llama.cpp specific parameters
     repeat_penalty: float = 1.1
 
     class Config:
@@ -160,7 +195,16 @@ async def create_chat_completion(
     request: CreateChatCompletionRequest,
 ) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
     completion_or_chunks = llama.create_chat_completion(
-        **request.dict(exclude={"model"}),
+        **request.dict(
+            exclude={
+                "model",
+                "n",
+                "presence_penalty",
+                "frequency_penalty",
+                "logit_bias",
+                "user",
+            }
+        ),
     )
 
     if request.stream:
@@ -179,3 +223,40 @@ async def create_chat_completion(
         )
     completion: llama_cpp.ChatCompletion = completion_or_chunks  # type: ignore
     return completion
+
+
+class ModelData(TypedDict):
+    id: str
+    object: Literal["model"]
+    owned_by: str
+    permissions: List[str]
+
+
+class ModelList(TypedDict):
+    object: Literal["list"]
+    data: List[ModelData]
+
+
+GetModelResponse = create_model_from_typeddict(ModelList)
+
+
+@app.get("/v1/models", response_model=GetModelResponse)
+def get_models() -> ModelList:
+    return {
+        "object": "list",
+        "data": [
+            {
+                "id": llama.model_path,
+                "object": "model",
+                "owned_by": "me",
+                "permissions": [],
+            }
+        ],
+    }
+
+
+if __name__ == "__main__":
+    import os
+    import uvicorn
+
+    uvicorn.run(app, host=os.getenv("HOST", "localhost"), port=os.getenv("PORT", 8000))