# Files
# exo/exo/topology/partitioning_strategy.py
# 2024-08-22 14:12:00 +01:00
#
# 41 lines
# 1.2 KiB
# Python

from abc import ABC, abstractmethod
from typing import List
from dataclasses import dataclass
from .topology import Topology
from exo.inference.shard import Shard
@dataclass
class Partition:
    """One contiguous piece of shard-space, tagged with a node id.

    Shard-space is partitioned into pieces of contiguous shards, each
    represented by the floating point range ``[start, end)`` between 0 and 1.
    """

    node_id: str  # identifier of the node associated with this piece
    start: float  # inclusive lower bound of the range, between 0 and 1
    end: float  # exclusive upper bound of the range, between 0 and 1
class PartitioningStrategy(ABC):
    """Abstract interface for turning a topology into a list of partitions."""

    @abstractmethod
    def partition(self, topology: Topology) -> List[Partition]:
        """Return the partitions computed for ``topology``.

        Concrete strategies must override this method; the base
        implementation does nothing.
        """
        pass
def map_partitions_to_shards(partitions: List[Partition], num_layers: int, model_id: str) -> List[Shard]:
    """Convert fractional partitions into inclusive layer-range shards.

    Each partition's ``[start, end)`` fraction is scaled by ``num_layers``
    and truncated to integer layer indices. Partitions whose truncated range
    is empty are dropped, so the result may hold fewer entries than
    ``partitions``.

    Args:
        partitions: Partitions to convert. The final list element is treated
            as the last partition and is always extended to the last layer.
        num_layers: Total number of layers being split.
        model_id: Value passed through as the first argument of every shard.

    Returns:
        Shards built as ``Shard(model_id, start_layer, end_layer, num_layers)``
        where ``start_layer`` and ``end_layer`` are both inclusive indices.
        Empty when ``partitions`` is empty or no partition yields a layer.
    """
    final_layer = num_layers - 1
    final_index = len(partitions) - 1

    shards = []
    for i, partition in enumerate(partitions):
        start_layer = int(partition.start * num_layers)
        # ``end`` is exclusive, so the inclusive last layer sits one below it.
        end_layer = int(partition.end * num_layers) - 1

        # Ensure the last partition covers up to num_layers - 1
        if i == final_index:
            end_layer = final_layer

        # Ensure no empty shards
        if start_layer <= end_layer:
            shards.append(Shard(model_id, start_layer, end_layer, num_layers))

    # Ensure full coverage: only reachable when the final partition was
    # dropped as empty, in which case the last kept shard is stretched to
    # the final layer.
    if shards and shards[-1].end_layer < final_layer:
        shards[-1] = Shard(model_id, shards[-1].start_layer, final_layer, num_layers)

    return shards