Mirror of https://github.com/exo-explore/exo.git, synced 2025-10-23 02:57:14 +03:00
Fetch download status in parallel; support async ensure_shard using shard_downloader instead
@@ -245,7 +245,7 @@ class ChatGPTAPI:
       )
       await response.prepare(request)
 
-      for model_name, pretty in pretty_name.items():
+      async def process_model(model_name, pretty):
         if model_name in model_cards:
           model_info = model_cards[model_name]
 
@@ -273,6 +273,12 @@ class ChatGPTAPI:
 
           await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
 
+      # Process all models in parallel
+      await asyncio.gather(*[
+        process_model(model_name, pretty)
+        for model_name, pretty in pretty_name.items()
+      ])
+
       await response.write(b"data: [DONE]\n\n")
       return response
 
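The two hunks above replace the sequential per-model loop with a local process_model coroutine that is fanned out through asyncio.gather, so every model's download status is fetched and streamed concurrently instead of one model at a time. Below is a minimal, self-contained sketch of that pattern; MODELS, fetch_status, and handle_models are hypothetical stand-ins for illustration, not names from the exo codebase.

import asyncio
import json
from aiohttp import web

MODELS = {"llama-3.2-1b": "Llama 3.2 1B", "llama-3.2-3b": "Llama 3.2 3B"}

async def fetch_status(model_name):
  # Stands in for an async check of a model's download state.
  await asyncio.sleep(0.1)
  return {"id": model_name, "ready": False}

async def handle_models(request):
  response = web.StreamResponse(headers={"Content-Type": "text/event-stream"})
  await response.prepare(request)

  async def process_model(model_name, pretty):
    data = await fetch_status(model_name)
    data["name"] = pretty
    # One write() call per event, so a whole SSE event goes out together.
    await response.write(f"data: {json.dumps(data)}\n\n".encode())

  # Fan out: all status lookups run concurrently instead of sequentially.
  await asyncio.gather(*[process_model(m, p) for m, p in MODELS.items()])
  await response.write(b"data: [DONE]\n\n")
  return response

app = web.Application()
app.add_routes([web.get("/models", handle_models)])

if __name__ == "__main__":
  web.run_app(app, port=8080)

With this shape, events arrive in completion order rather than dictionary order; that is the trade-off for not blocking the stream on the slowest model.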
@@ -562,7 +568,7 @@ class ChatGPTAPI:
     if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400)
     shard = build_base_shard(model_name, self.inference_engine_classname)
     if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
-    asyncio.create_task(self.node.inference_engine.ensure_shard(shard))
+    asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard))
 
     return web.json_response({
       "status": "success",
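The last hunk reroutes the eager download through shard_downloader.ensure_shard while keeping the fire-and-forget asyncio.create_task call, so the handler returns its JSON response immediately and the download continues in the background. A minimal sketch of that pattern follows, under assumed names: Shard, ShardDownloader, and handle_download are placeholders, not the exo implementations.

import asyncio
from aiohttp import web

class Shard:
  def __init__(self, model_id):
    self.model_id = model_id

class ShardDownloader:
  async def ensure_shard(self, shard):
    # Stands in for the actual shard download.
    await asyncio.sleep(2)
    print(f"shard for {shard.model_id} is ready")

downloader = ShardDownloader()
background_tasks = set()

async def handle_download(request):
  shard = Shard(request.query.get("model", "llama-3.2-1b"))
  # Schedule the download without awaiting it, so the response returns at once.
  task = asyncio.create_task(downloader.ensure_shard(shard))
  # Keep a reference so the task is not garbage-collected before it finishes.
  background_tasks.add(task)
  task.add_done_callback(background_tasks.discard)
  return web.json_response({"status": "success", "model": shard.model_id})

app = web.Application()
app.add_routes([web.post("/download", handle_download)])

if __name__ == "__main__":
  web.run_app(app, port=8080)

Holding a reference to the task (here in background_tasks) matters for pure fire-and-forget calls: the event loop keeps only weak references to tasks, so an unreferenced task can be garbage-collected before it completes.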