feat: add Azure support to OpenAIClient and update configuration

This commit is contained in:
stathi.fotiadis
2025-06-25 08:01:00 +00:00
parent 11863dcb1e
commit 9086d03b9b
3 changed files with 22 additions and 9 deletions

View File

@@ -17,10 +17,12 @@ from src.core.model_manager import model_manager
router = APIRouter()
openai_client = OpenAIClient(
config.openai_api_key, config.openai_base_url, config.request_timeout
config.openai_api_key,
config.openai_base_url,
config.request_timeout,
api_version=config.azure_api_version,
)
@router.post("/v1/messages")
async def create_message(request: ClaudeMessagesRequest, http_request: Request):
try:

View File

@@ -2,21 +2,31 @@ import asyncio
import json
from fastapi import HTTPException
from typing import Optional, AsyncGenerator, Dict, Any
from openai import AsyncOpenAI
from openai import AsyncOpenAI, AsyncAzureOpenAI
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai._exceptions import APIError, RateLimitError, AuthenticationError, BadRequestError
class OpenAIClient:
"""Async OpenAI client with cancellation support."""
def __init__(self, api_key: str, base_url: str, timeout: int = 90):
def __init__(self, api_key: str, base_url: str, timeout: int = 90, api_version: Optional[str] = None):
self.api_key = api_key
self.base_url = base_url
self.client = AsyncOpenAI(
api_key=api_key,
base_url=base_url,
timeout=timeout
)
# Detect if using Azure and instantiate the appropriate client
if api_version:
self.client = AsyncAzureOpenAI(
api_key=api_key,
azure_endpoint=base_url,
api_version=api_version,
timeout=timeout
)
else:
self.client = AsyncOpenAI(
api_key=api_key,
base_url=base_url,
timeout=timeout
)
self.active_requests: Dict[str, asyncio.Event] = {}
async def create_chat_completion(self, request: Dict[str, Any], request_id: Optional[str] = None) -> Dict[str, Any]:

View File

@@ -9,6 +9,7 @@ class Config:
raise ValueError("OPENAI_API_KEY not found in environment variables")
self.openai_base_url = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
self.azure_api_version = os.environ.get("AZURE_API_VERSION") # For Azure OpenAI
self.host = os.environ.get("HOST", "0.0.0.0")
self.port = int(os.environ.get("PORT", "8082"))
self.log_level = os.environ.get("LOG_LEVEL", "INFO")