diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 90f2d5d238..5c499d69a6 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -88,10 +88,22 @@ def is_cdna(): @functools.lru_cache(1) def is_rdna(): - """Detect ROCm-supported RDNA consumer/workstation GPUs (RDNA3, RDNA4).""" + """Detect RDNA consumer/workstation GPUs (RDNA2, RDNA3, RDNA3.5, RDNA4).""" return is_hip() and triton.runtime.driver.active.get_current_target().arch in ( + # RDNA2 + "gfx1030", + "gfx1031", + "gfx1032", + # RDNA3 "gfx1100", "gfx1101", + "gfx1102", + "gfx1103", + # RDNA3.5 + "gfx1150", + "gfx1151", + "gfx1152", + # RDNA4 "gfx1200", "gfx1201", )