mirror of
https://github.com/exo-explore/exo.git
synced 2025-10-23 02:57:14 +03:00
Merge pull request #763 from deftdawg/amdgpu
AMD/ROCm: Changes required to detect and run inference on AMD GPUs
This commit is contained in:
@@ -198,22 +198,19 @@ async def linux_device_capabilities() -> DeviceCapabilities:
|
||||
flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0)),
|
||||
)
|
||||
elif Device.DEFAULT == "AMD":
|
||||
# For AMD GPUs, pyrsmi is the way (Official python package for rocm-smi)
|
||||
from pyrsmi import rocml
|
||||
import pyamdgpuinfo
|
||||
|
||||
rocml.smi_initialize()
|
||||
gpu_name = rocml.smi_get_device_name(0).upper()
|
||||
gpu_memory_info = rocml.smi_get_device_memory_total(0)
|
||||
gpu_raw_info = pyamdgpuinfo.get_gpu(0)
|
||||
gpu_name = gpu_raw_info.name
|
||||
gpu_memory_info = gpu_raw_info.memory_info["vram_size"]
|
||||
|
||||
if DEBUG >= 2: print(f"AMD device {gpu_name=} {gpu_memory_info=}")
|
||||
|
||||
rocml.smi_shutdown()
|
||||
|
||||
return DeviceCapabilities(
|
||||
model="Linux Box ({gpu_name})",
|
||||
model="Linux Box (" + gpu_name + ")",
|
||||
chip=gpu_name,
|
||||
memory=gpu_memory_info // 2**20,
|
||||
flops=DeviceFlops(fp32=0, fp16=0, int8=0),
|
||||
flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0)),
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user