mirror of
https://github.com/exo-explore/exo.git
synced 2025-10-23 02:57:14 +03:00
41 lines
1.2 KiB
Python
41 lines
1.2 KiB
Python
from abc import ABC, abstractmethod
|
|
from typing import List
|
|
from dataclasses import dataclass
|
|
from .topology import Topology
|
|
from exo.inference.shard import Shard
|
|
|
|
|
|
# Partitions shard-space into contiguous pieces, each represented by a floating-point range [start, end) between 0 and 1
|
|
@dataclass
class Partition:
    """One contiguous piece of shard-space, tied to a node identifier.

    The piece is described by the half-open floating point range
    ``[start, end)``, with both bounds expressed as fractions between 0 and 1.
    """

    # Identifier of the node this piece is associated with.
    node_id: str
    # Inclusive lower bound of the range (fraction of shard-space).
    start: float
    # Exclusive upper bound of the range (fraction of shard-space).
    end: float
|
|
|
|
|
|
class PartitioningStrategy(ABC):
    """Abstract interface for turning a topology into a list of partitions.

    Concrete strategies must override :meth:`partition`.
    """

    @abstractmethod
    def partition(self, topology: Topology) -> List[Partition]:
        """Return the partitions computed for ``topology``.

        The base implementation does nothing and returns ``None``; subclasses
        supply the actual partitioning logic.
        """
|
|
|
|
|
|
def map_partitions_to_shards(partitions: List[Partition], num_layers: int, model_id: str) -> List[Shard]:
    """Translate fractional partitions into inclusive layer-index shards.

    Each partition's ``[start, end)`` fraction is scaled by ``num_layers`` and
    truncated with ``int()``: the shard starts at ``int(start * num_layers)``
    and ends at ``int(end * num_layers) - 1`` (inclusive). The final partition
    in the list is always made to end at layer ``num_layers - 1``.

    Partitions that would cover no layer (start index greater than end index)
    produce no shard, so the result may be shorter than ``partitions``.

    Args:
        partitions: Partitions to convert, processed in list order.
        num_layers: Total number of layers; also passed through to each Shard.
        model_id: Model identifier passed through to each Shard.

    Returns:
        The shards built from the non-empty partitions, in input order. The
        list is empty when no partition covers at least one layer.
    """
    layer_shards = []
    for index, part in enumerate(partitions):
        first_layer = int(part.start * num_layers)
        scaled_last = int(part.end * num_layers) - 1

        # The final partition always runs through the very last layer,
        # regardless of what its scaled end bound truncates to.
        is_final = index == len(partitions) - 1
        last_layer = num_layers - 1 if is_final else scaled_last

        # Skip partitions that do not cover a single whole layer.
        if first_layer > last_layer:
            continue
        layer_shards.append(Shard(model_id, first_layer, last_layer, num_layers))

    # Safety net: if the last emitted shard stops short of the final layer
    # (e.g. because the final partition produced no shard), stretch it so the
    # shards still reach layer num_layers - 1.
    if layer_shards and layer_shards[-1].end_layer < num_layers - 1:
        tail = layer_shards[-1]
        layer_shards[-1] = Shard(model_id, tail.start_layer, num_layers - 1, num_layers)

    return layer_shards
|