diff --git a/cpp/cli_main.cc b/cpp/cli_main.cc
index f084f93e41..ab4713f48d 100644
--- a/cpp/cli_main.cc
+++ b/cpp/cli_main.cc
@@ -42,6 +42,9 @@ std::string DetectDeviceName(std::string device_name) {
     if (DeviceAPI::Get(DLDevice{kDLMetal, 0}, allow_missing)) {
       return "metal";
     }
+    if (DeviceAPI::Get(DLDevice{kDLROCM, 0}, allow_missing)) {
+      return "rocm";
+    }
     if (DeviceAPI::Get(DLDevice{kDLVulkan, 0}, allow_missing)) {
       return "vulkan";
     }
@@ -56,6 +59,7 @@ std::string DetectDeviceName(std::string device_name) {
 DLDevice GetDevice(const std::string& device_name, int device_id) {
   if (device_name == "cuda") return DLDevice{kDLCUDA, device_id};
   if (device_name == "metal") return DLDevice{kDLMetal, device_id};
+  if (device_name == "rocm") return DLDevice{kDLROCM, device_id};
   if (device_name == "vulkan") return DLDevice{kDLVulkan, device_id};
   if (device_name == "opencl") return DLDevice{kDLOpenCL, device_id};
   LOG(FATAL) << "Do not recognize device name " << device_name;
diff --git a/mlc_llm/utils.py b/mlc_llm/utils.py
index c1bdda9777..0713cc7180 100644
--- a/mlc_llm/utils.py
+++ b/mlc_llm/utils.py
@@ -308,6 +308,20 @@ def _detect_local_cuda():
     )
 
 
+def _detect_local_rocm():
+    dev = tvm.rocm()
+    if not dev.exist:
+        return None
+    return tvm.target.Target(
+        {
+            "kind": "rocm",
+            "max_shared_memory_per_block": dev.max_shared_memory_per_block,
+            "max_threads_per_block": dev.max_threads_per_block,
+            "thread_warp_size": dev.warp_size,
+        }
+    )
+
+
 def _detect_local_vulkan():
     dev = tvm.vulkan()
     if not dev.exist:
@@ -336,6 +350,7 @@ def _detect_local_opencl():
 def detect_local_target():
     for method in [
         _detect_local_metal,
+        _detect_local_rocm,
         _detect_local_cuda,
         _detect_local_vulkan,
         _detect_local_opencl,
diff --git a/tests/debug/compare_lib.py b/tests/debug/compare_lib.py
index 3dcf6fc389..9c2e35f014 100644
--- a/tests/debug/compare_lib.py
+++ b/tests/debug/compare_lib.py
@@ -52,9 +52,9 @@ def compare(
         super().compare(name, ref_args, new_args, ret_indices)
 
         if self.time_eval and name not in self.time_eval_results:
-            res = self.mod.time_evaluator(name, self.device, number=100, repeat=3)(
-                *new_args
-            )
+            res = self.mod.time_evaluator(
+                name, self.device, number=20, repeat=3
+            )(*new_args)
             self.time_eval_results[name] = (res.mean, 1)
             print(f"Time-eval result {name} on {self.device}: {res}")
 
@@ -212,6 +212,8 @@ def _parse_args():
             parsed.primary_device = "cuda"
         elif tvm.metal().exist:
             parsed.primary_device = "metal"
+        elif tvm.rocm().exist:
+            parsed.primary_device = "rocm"
         else:
             raise ValueError("Cannot auto deduce device-name, please set it")
     return parsed