Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 43 additions & 6 deletions vllm/platforms/rocm.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# SPDX-License-Identifier: Apache-2.0

from functools import lru_cache
import os
from functools import lru_cache, wraps
from typing import TYPE_CHECKING, Dict, List, Optional

import torch
from amdsmi import (amdsmi_get_gpu_asic_info, amdsmi_get_processor_handles,
amdsmi_init, amdsmi_shut_down)

import vllm.envs as envs
from vllm.logger import init_logger
Expand Down Expand Up @@ -53,6 +56,41 @@
"by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
}

# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES`.
# When HIP_VISIBLE_DEVICES is set it is mirrored into CUDA_VISIBLE_DEVICES;
# if CUDA_VISIBLE_DEVICES already holds a non-empty value, both must agree.
if "HIP_VISIBLE_DEVICES" in os.environ:
    val = os.environ["HIP_VISIBLE_DEVICES"]
    if cuda_val := os.environ.get("CUDA_VISIBLE_DEVICES", None):
        # Raised explicitly instead of using `assert` so the check is not
        # stripped under `python -O`; the exception type stays AssertionError.
        if val != cuda_val:
            raise AssertionError(
                "HIP_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES are both set "
                f"but differ: {val!r} != {cuda_val!r}")
    else:
        # CUDA_VISIBLE_DEVICES is unset or empty: adopt the HIP value.
        os.environ["CUDA_VISIBLE_DEVICES"] = val

# AMDSMI utils
# Note that AMDSMI is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`;
# all the related functions work on real physical device ids.
# The major benefit of using AMDSMI is that it will not initialize CUDA.


def with_amdsmi_context(fn):
    """Decorate ``fn`` so each call runs between AMDSMI init and shutdown.

    ``amdsmi_init()`` is called before ``fn``; ``amdsmi_shut_down()`` is
    called afterwards whether ``fn`` returns or raises. The decorated
    function returns whatever ``fn`` returns.
    """

    def _run_inside_amdsmi(*call_args, **call_kwargs):
        # Init happens outside the try: if it fails there is nothing to
        # shut down.
        amdsmi_init()
        try:
            result = fn(*call_args, **call_kwargs)
        finally:
            amdsmi_shut_down()
        return result

    # Copy fn's metadata (name, docstring, ...) onto the wrapper.
    return wraps(fn)(_run_inside_amdsmi)


def device_id_to_physical_device_id(device_id: int) -> int:
    """Map a logical device index to a physical device id.

    When ``CUDA_VISIBLE_DEVICES`` is set, ``device_id`` indexes into its
    comma-separated entries and the selected entry is returned as an int.
    When the variable is not set, ``device_id`` is returned unchanged.

    Raises:
        IndexError: ``device_id`` does not index an entry of
            ``CUDA_VISIBLE_DEVICES``.
        ValueError: the selected entry is not an integer (this includes
            ``CUDA_VISIBLE_DEVICES`` being set to an empty string).
    """
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        return device_id

    device_ids = visible_devices.split(",")
    try:
        physical_device_id = device_ids[device_id]
    except IndexError:
        raise IndexError(
            f"device_id {device_id} is out of range for "
            f"CUDA_VISIBLE_DEVICES={visible_devices!r}") from None

    try:
        return int(physical_device_id)
    except ValueError:
        # Same exception type as a bare int() failure, but the message
        # points at the environment variable that caused it.
        raise ValueError(
            f"entry {physical_device_id!r} selected by device_id "
            f"{device_id} in CUDA_VISIBLE_DEVICES={visible_devices!r} "
            "is not an integer device id") from None


class RocmPlatform(Platform):
_enum = PlatformEnum.ROCM
Expand Down Expand Up @@ -96,13 +134,12 @@ def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
return DeviceCapability(major=major, minor=minor)

@classmethod
@lru_cache(maxsize=8)
@with_amdsmi_context
def get_device_name(cls, device_id: int = 0) -> str:
    """Return the ``market_name`` AMDSMI reports for ``device_id``.

    ``device_id`` is first translated with
    ``device_id_to_physical_device_id`` and the result is used to index
    the list returned by ``amdsmi_get_processor_handles()``.

    The name is read through AMDSMI rather than
    ``torch.cuda.get_device_name``: under V1 this method is called while
    overriding the engine args, and the torch call would initialize the
    ROCm context before other processes can be created.

    ``lru_cache`` is the outer decorator so that a cache hit returns
    without running AMDSMI init/shutdown again.
    """
    physical_device_id = device_id_to_physical_device_id(device_id)
    handle = amdsmi_get_processor_handles()[physical_device_id]
    return amdsmi_get_gpu_asic_info(handle)["market_name"]

@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
Expand Down