mirror of https://github.com/exo-explore/exo.git
synced 2025-10-23 02:57:14 +03:00
remove redundant sample_logits, put back opaque status for process_prompt so we have a way of preemptively starting downloads
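The second half of the commit message is the interesting part: by broadcasting a "start_process_prompt" status before inference begins, peers that will later serve part of the model get an early signal to fetch their shard. Below is a rough, hypothetical sketch of what a receiving node could do with that status; the handler name, the Shard.from_dict call, and shard_downloader.ensure_shard are assumptions for illustration, not code from this commit.

# Hypothetical sketch (not from this commit): start the shard download as soon
# as a peer announces it has begun processing a prompt.
import asyncio
import json

async def on_opaque_status(self, request_id: str, opaque_status: str) -> None:
  status_data = json.loads(opaque_status)
  is_start = (
    status_data.get("type") == "node_status"
    and status_data.get("status") == "start_process_prompt"
  )
  if is_start:
    # "shard" is the dict produced by shard.to_dict() in the diff below.
    shard = Shard.from_dict(status_data["shard"])  # assumed helper
    # Fire-and-forget: warm the download so the weights are likely local by
    # the time the actual inference request reaches this node.
    asyncio.create_task(self.shard_downloader.ensure_shard(shard))  # assumed API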
@@ -13,27 +13,6 @@ import asyncio
 from collections import OrderedDict
 from mlx_lm.models.cache import make_prompt_cache

-def sample_logits(
-  logits: mx.array,
-  temp: float = 0.0,
-  top_p: float = 1.0,
-  logit_bias: Optional[Dict[int, float]] = None
-) -> Tuple[mx.array, float]:
-  if logit_bias:
-    indices = mx.array(list(logit_bias.keys()))
-    values = mx.array(list(logit_bias.values()))
-    logits[:, indices] += values
-
-  if temp == 0:
-    token = mx.argmax(logits, axis=-1)
-  else:
-    if top_p > 0 and top_p < 1.0:
-      token = top_p_sampling(logits, top_p, temp)
-    else:
-      token = mx.random.categorical(logits*(1/temp))
-
-  return token
-
 class MLXDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
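For reference, the deleted helper applies an optional per-token logit bias, takes the argmax at temperature 0, and otherwise samples from softmax(logits / temp), optionally restricted by top-p. A minimal, hedged illustration of the temperature path follows (it assumes mlx is installed; this is not the replacement code the commit relies on).

# Illustration only: mx.random.categorical samples from softmax of its input,
# so scaling logits by 1/temp implements temperature sampling.
import mlx.core as mx

logits = mx.array([[2.0, 1.0, 0.1]])  # toy batch of one
temp = 0.7

greedy_token = mx.argmax(logits, axis=-1)                    # temp == 0 path
sampled_token = mx.random.categorical(logits * (1 / temp))   # temp > 0 path
print(greedy_token.item(), sampled_token.item())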
@@ -70,25 +70,28 @@ class Node:
   def on_node_status(self, request_id, opaque_status):
     try:
       status_data = json.loads(opaque_status)
-      if status_data.get("type", "") == "supported_inference_engines":
+      status_type = status_data.get("type", "")
+      if status_type == "supported_inference_engines":
         node_id = status_data.get("node_id")
         engines = status_data.get("engines", [])
         self.topology_inference_engines_pool.append(engines)
-      if status_data.get("type", "") == "node_status":
+      elif status_type == "node_status":
         if status_data.get("status", "").startswith("start_"):
           self.current_topology.active_node_id = status_data.get("node_id")
         elif status_data.get("status", "").startswith("end_"):
           if status_data.get("node_id") == self.current_topology.active_node_id:
             self.current_topology.active_node_id = None

       download_progress = None
-      if status_data.get("type", "") == "download_progress":
+      if status_type == "download_progress":
         if DEBUG >= 8: print(f"Download progress from {status_data.get('node_id')}: {status_data.get('progress')}")
         download_progress = RepoProgressEvent.from_dict(status_data.get('progress'))
         self.node_download_progress[status_data.get('node_id')] = download_progress

       if self.topology_viz:
         self.topology_viz.update_visualization(self.topology, self.partitioning_strategy.partition(self.topology), self.id, self.node_download_progress)
     except Exception as e:
-      if DEBUG >= 1: print(f"Error updating visualization: {e}")
+      if DEBUG >= 1: print(f"Error on_node_status: {e}")
       if DEBUG >= 1: traceback.print_exc()

   def get_supported_inference_engines(self):
@@ -153,10 +156,39 @@ class Node:
     request_id: Optional[str] = None,
   ) -> None:
     shard = self.get_current_shard(base_shard)
+    asyncio.create_task(
+      self.broadcast_opaque_status(
+        request_id,
+        json.dumps({
+          "type": "node_status",
+          "node_id": self.id,
+          "status": "start_process_prompt",
+          "base_shard": base_shard.to_dict(),
+          "shard": shard.to_dict(),
+          "prompt": prompt,
+          "request_id": request_id,
+        }),
+      )
+    )
     start_time = time.perf_counter_ns()
     await self._process_prompt(base_shard, prompt, request_id)
     end_time = time.perf_counter_ns()
     elapsed_time_ns = end_time - start_time
+    asyncio.create_task(
+      self.broadcast_opaque_status(
+        request_id,
+        json.dumps({
+          "type": "node_status",
+          "node_id": self.id,
+          "status": "end_process_prompt",
+          "base_shard": base_shard.to_dict(),
+          "shard": shard.to_dict(),
+          "prompt": prompt,
+          "request_id": request_id,
+          "elapsed_time_ns": elapsed_time_ns,
+        }),
+      )
+    )
     if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=} {elapsed_time_ns=}")

   async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]: