download status in parallel, support async ensure shard with using shard_downloader instead

This commit is contained in:
Alex Cheema
2025-01-05 02:31:59 +00:00
parent 7b1656140e
commit 8c191050a2

View File

@@ -245,7 +245,7 @@ class ChatGPTAPI:
)
await response.prepare(request)
for model_name, pretty in pretty_name.items():
async def process_model(model_name, pretty):
if model_name in model_cards:
model_info = model_cards[model_name]
@@ -273,6 +273,12 @@ class ChatGPTAPI:
await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
# Process all models in parallel
await asyncio.gather(*[
process_model(model_name, pretty)
for model_name, pretty in pretty_name.items()
])
await response.write(b"data: [DONE]\n\n")
return response
@@ -562,7 +568,7 @@ class ChatGPTAPI:
if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400)
shard = build_base_shard(model_name, self.inference_engine_classname)
if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
asyncio.create_task(self.node.inference_engine.ensure_shard(shard))
asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard))
return web.json_response({
"status": "success",