import asyncio
import json
import time
import uuid
from typing import List, Optional

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

app = FastAPI(title="OpenAI-compatible API")


# Configuration for available models
AVAILABLE_MODELS = [
    {
        "id": "dummy-model",
        "object": "model",
        "created": 1686935002,
        "owned_by": "alihan",
        "permission": [
            {
                "id": "modeldummy-" + str(uuid.uuid4())[:8],
                "object": "model_permission",
                "created": int(time.time()),
                "allow_create_engine": False,
                "allow_sampling": True,
                "allow_logprobs": True,
                "allow_search_indices": False,
                "allow_view": True,
                "allow_fine_tuning": False,
                "organization": "*",
                "group": None,
                "is_blocking": False
            }
        ],
        "root": "dummy-model",
        "parent": None
    },
]


class ChatMessage(BaseModel):
    role: str
    content: str


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = 1.0
    stream: Optional[bool] = False
    max_tokens: Optional[int] = None


async def process_with_your_model(messages: List[ChatMessage], model: str) -> str:
    """
    Replace this with your actual model processing logic.
    You might want to route to different models based on the model parameter.
    """
    last_user_message = next((msg.content for msg in reversed(messages) if msg.role == "user"), "")
    return f"Response from {model}: {last_user_message}"


def generate_id() -> str:
    return str(uuid.uuid4())[:8]


async def stream_response(content: str, model: str):
    """Fake token streaming: emit the response word by word as SSE chunks."""
    # All chunks of one streamed completion share a single id, matching the
    # OpenAI chat.completion.chunk format.
    completion_id = f"chatcmpl-{generate_id()}"
    words = content.split()

    for word in words:
        chunk = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model,
            "choices": [{
                "index": 0,
                "delta": {
                    "content": word + " "
                },
                "finish_reason": None
            }]
        }
        yield f"data: {json.dumps(chunk)}\n\n"
        await asyncio.sleep(0.1)

    # Final chunk carries an empty delta and the finish reason.
    final_chunk = {
        "id": completion_id,
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model,
        "choices": [{
            "index": 0,
            "delta": {},
            "finish_reason": "stop"
        }]
    }
    yield f"data: {json.dumps(final_chunk)}\n\n"
    yield "data: [DONE]\n\n"


@app.get("/v1/models")
async def list_models():
    """List all available models"""
    return {
        "object": "list",
        "data": AVAILABLE_MODELS
    }


@app.get("/v1/models/{model_id}")
async def get_model(model_id: str):
    """Get details of a specific model"""
    model = next((m for m in AVAILABLE_MODELS if m["id"] == model_id), None)
    if not model:
        raise HTTPException(status_code=404, detail=f"Model {model_id} not found")
    return model


@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    # Validate that the requested model exists
    if not any(model["id"] == request.model for model in AVAILABLE_MODELS):
        raise HTTPException(status_code=404, detail=f"Model {request.model} not found")

    response_content = await process_with_your_model(request.messages, request.model)

    if not request.stream:
        # Whitespace word counts stand in for real tokenizer-based usage accounting.
        prompt_tokens = len(" ".join(msg.content for msg in request.messages).split())
        completion_tokens = len(response_content.split())
        return {
            "id": f"chatcmpl-{generate_id()}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": request.model,
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": response_content
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens
            }
        }
    else:
        return StreamingResponse(
            stream_response(response_content, request.model),
            media_type="text/event-stream"
        )
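
# How this endpoint might be exercised once the server is running (localhost:8000
# matches the uvicorn call at the bottom of this file):
#
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "dummy-model", "messages": [{"role": "user", "content": "Hello"}]}'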


# Optional: Add a root endpoint that redirects to documentation
@app.get("/")
async def root():
    return {"message": "OpenAI-compatible API server. Visit /docs for documentation."}


# Optional: Add a health check endpoint
@app.get("/health")
async def health_check():
    return {"status": "healthy"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
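
# A quick client-side smoke test, assuming the official `openai` Python package
# (v1.x) is installed; the api_key value is arbitrary since this server ignores it:
#
#   from openai import OpenAI
#
#   client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")
#   resp = client.chat.completions.create(
#       model="dummy-model",
#       messages=[{"role": "user", "content": "Hello"}],
#   )
#   print(resp.choices[0].message.content)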