Skip to content

Commit ac8fa45

Browse files
[Backend] Add ROCm support (#652)
Depending on this PR: apache/tvm#15464 On 7900 xtx. ROCm 5.6 ``` ~/mlc-llm (rocm ✔) ./build/mlc_chat_cli --local-id Llama-2-7b-chat-hf-q4f16_1 Use MLC config: "/home/bohan/mlc-llm/dist/Llama-2-7b-chat-hf-q4f16_1/params/mlc-chat-config.json" Use model weights: "/home/bohan/mlc-llm/dist/Llama-2-7b-chat-hf-q4f16_1/params/ndarray-cache.json" Use model library: "/home/bohan/mlc-llm/dist/Llama-2-7b-chat-hf-q4f16_1/Llama-2-7b-chat-hf-q4f16_1-rocm.so" You can use the following special commands: /help print the special commands /exit quit the cli /stats print out the latest stats (token/sec) /reset restart a fresh chat /reload [local_id] reload model `local_id` from disk, or reload the current model if `local_id` is not specified Loading model... Loading finished Running system prompts... System prompts finished [INST]: Hi [/INST]: Hello! It's nice to meet you. I'm here to help you with any questions or tasks you may have, while always being safe and respectful. Is there something specific you would like to know or discuss? Please feel free to ask me anything, and I will do my best to provide a helpful and positive response. [INST]: /stats prefill: 507.3 tok/s, decode: 92.0 tok/s ``` ``` ~/mlc-llm (rocm ✗) ./build/mlc_chat_cli --local-id Llama-2-13b-chat-hf-q4f16_1 Use MLC config: "/home/bohan/mlc-llm/dist/Llama-2-13b-chat-hf-q4f16_1/params/mlc-chat-config.json" Use model weights: "/home/bohan/mlc-llm/dist/Llama-2-13b-chat-hf-q4f16_1/params/ndarray-cache.json" Use model library: "/home/bohan/mlc-llm/dist/Llama-2-13b-chat-hf-q4f16_1/Llama-2-13b-chat-hf-q4f16_1-rocm.so" You can use the following special commands: /help print the special commands /exit quit the cli /stats print out the latest stats (token/sec) /reset restart a fresh chat /reload [local_id] reload model `local_id` from disk, or reload the current model if `local_id` is not specified Loading model... Loading finished Running system prompts... System prompts finished [INST]: Hi [/INST]: Hello! 
I'm here to assist you with any questions you may have. Please keep in mind that I strive to provide safe and positive responses that are free of harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. If a question does not make sense or is not factually coherent, I will do my best to explain why instead of providing an incorrect answer. If I don't know the answer to a question, I will not provide false information. Is there anything specific you would like to know or discuss? [INST]: /stats prefill: 495.7 tok/s, decode: 69.0 tok/s ```
1 parent 3c53eeb commit ac8fa45

File tree

3 files changed

+24
-3
lines changed

3 files changed

+24
-3
lines changed

cpp/cli_main.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ std::string DetectDeviceName(std::string device_name) {
4242
if (DeviceAPI::Get(DLDevice{kDLMetal, 0}, allow_missing)) {
4343
return "metal";
4444
}
45+
if (DeviceAPI::Get(DLDevice{kDLROCM, 0}, allow_missing)) {
46+
return "rocm";
47+
}
4548
if (DeviceAPI::Get(DLDevice{kDLVulkan, 0}, allow_missing)) {
4649
return "vulkan";
4750
}
@@ -56,6 +59,7 @@ std::string DetectDeviceName(std::string device_name) {
5659
DLDevice GetDevice(const std::string& device_name, int device_id) {
5760
if (device_name == "cuda") return DLDevice{kDLCUDA, device_id};
5861
if (device_name == "metal") return DLDevice{kDLMetal, device_id};
62+
if (device_name == "rocm") return DLDevice{kDLROCM, device_id};
5963
if (device_name == "vulkan") return DLDevice{kDLVulkan, device_id};
6064
if (device_name == "opencl") return DLDevice{kDLOpenCL, device_id};
6165
LOG(FATAL) << "Do not recognize device name " << device_name;

mlc_llm/utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,20 @@ def _detect_local_cuda():
308308
)
309309

310310

311+
def _detect_local_rocm():
312+
dev = tvm.rocm()
313+
if not dev.exist:
314+
return None
315+
return tvm.target.Target(
316+
{
317+
"kind": "rocm",
318+
"max_shared_memory_per_block": dev.max_shared_memory_per_block,
319+
"max_threads_per_block": dev.max_threads_per_block,
320+
"thread_warp_size": dev.warp_size,
321+
}
322+
)
323+
324+
311325
def _detect_local_vulkan():
312326
dev = tvm.vulkan()
313327
if not dev.exist:
@@ -336,6 +350,7 @@ def _detect_local_opencl():
336350
def detect_local_target():
337351
for method in [
338352
_detect_local_metal,
353+
_detect_local_rocm,
339354
_detect_local_cuda,
340355
_detect_local_vulkan,
341356
_detect_local_opencl,

tests/debug/compare_lib.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ def compare(
5252
super().compare(name, ref_args, new_args, ret_indices)
5353

5454
if self.time_eval and name not in self.time_eval_results:
55-
res = self.mod.time_evaluator(name, self.device, number=100, repeat=3)(
56-
*new_args
57-
)
55+
res = self.mod.time_evaluator(
56+
name, self.device, number=20, repeat=3#, cache_flush_bytes=256 * 10**6
57+
)(*new_args)
5858
self.time_eval_results[name] = (res.mean, 1)
5959
print(f"Time-eval result {name} on {self.device}: {res}")
6060

@@ -212,6 +212,8 @@ def _parse_args():
212212
parsed.primary_device = "cuda"
213213
elif tvm.metal().exist:
214214
parsed.primary_device = "metal"
215+
elif tvm.rocm().exist:
216+
parsed.primary_device = "rocm"
215217
else:
216218
raise ValueError("Cannot auto deduce device-name, please set it")
217219
return parsed

0 commit comments

Comments (0)