
[AMD ROCm] Enable CK backend for ROCm gfx12 #2054

Closed
hyoon1 wants to merge 7 commits into Dao-AILab:main from hyoon1:enable-ck-gfx12

Conversation

hyoon1 (Contributor) commented Dec 8, 2025

This extends #2052, which updated to the latest Composable Kernel version. The latest CK now supports gfx12 architectures, but the CK kernel generator needs an explicit target specification to generate kernels for these GPUs.

  • Added gfx1200 and gfx1201 to the allowed architectures
  • Added GPU_ARCHS environment variable support for explicit CK target specification
  • Auto-detects the GPU when GPU_ARCHS is not set
  • Modified the CK generator to pass a --targets flag with the specified architectures
  • Disabled CK deterministic backward on gfx12 (forcing nondeterministic kernels and skipping deterministic CK tests there) because the deterministic path is unstable on these GPUs (GPU hangs occur)
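A minimal sketch of the target-selection logic described above, assuming a semicolon-separated GPU_ARCHS value; the names `ALLOWED_ARCHS` and `resolve_ck_targets`, the allowed-list contents, and the separator are illustrative assumptions, not the PR's actual code:

```python
import os

# Architectures the CK generator may target; per this PR, gfx1200 and
# gfx1201 are newly allowed (the rest of this list is illustrative).
ALLOWED_ARCHS = {"gfx90a", "gfx942", "gfx1200", "gfx1201"}

def resolve_ck_targets(detected_arch):
    """Honor GPU_ARCHS if set (semicolon-separated), otherwise fall
    back to the auto-detected GPU. Unknown archs are rejected."""
    raw = os.environ.get("GPU_ARCHS", "").strip()
    archs = [a for a in raw.split(";") if a] if raw else [detected_arch]
    bad = [a for a in archs if a not in ALLOWED_ARCHS]
    if bad:
        raise ValueError(f"unsupported GPU arch(s): {bad}")
    return archs

# Explicit targets via the environment variable override detection:
os.environ["GPU_ARCHS"] = "gfx1200;gfx1201"
print(resolve_ck_targets("gfx90a"))  # → ['gfx1200', 'gfx1201']
```

The resulting list would then be joined into the generator's --targets flag.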

tridao (Member) commented Dec 10, 2025

Cc @rocking5566

Logiquo commented Jan 3, 2026

Based on the commit history of the CK repo, gfx11 and gfx12 appear to have been supported at approximately the same time. Maybe also add the gfx11 archs?

at::cuda::CUDAGuard device_guard{q.device()};

auto opts = q.options();
// gfx12 deterministic bwd is unstable; always fall back to nondeterministic there.
rocking5566 (Contributor) commented Jan 20, 2026

I suggest adding a TORCH_CHECK to warn the user rather than switching to nondeterministic automatically.

hyoon1 (Author) replied: updated


return_softmax,
is_grad_enabled,
):
deterministic = _disable_gfx12_deterministic(deterministic, qkv.device)
rocking5566 (Contributor) commented

I think we should not change the parameter inside the API; just assert and warn the user instead.

hyoon1 (Author) replied

Updated. The Python API no longer mutates deterministic. Instead, the C++ CK backward now uses TORCH_CHECK to raise an error when deterministic=True on gfx12.
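For illustration, a Python analogue of the resulting guard; the real check is a TORCH_CHECK inside the C++ CK backward, and the function name and arch-string test here are assumptions:

```python
def check_gfx12_deterministic(deterministic, gcn_arch):
    """Reject deterministic backward on gfx12, where the deterministic
    CK path is unstable (GPU hangs, per this PR's discussion).
    Hypothetical sketch; not the PR's actual function."""
    if deterministic and gcn_arch.startswith("gfx12"):
        raise RuntimeError(
            "FlashAttention CK backend: deterministic backward is not "
            "supported on gfx12 (gfx1200/gfx1201); pass deterministic=False"
        )

# No-op on other architectures, or when nondeterministic is requested:
check_gfx12_deterministic(True, "gfx942")
check_gfx12_deterministic(False, "gfx1200")
```

Raising instead of silently flipping the flag keeps the user's deterministic=True request from being ignored, which was the reviewer's point.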


g = torch.randn_like(out)
if is_bwd_hdim_supported(d):
if is_bwd_hdim_supported(d) and not skip_deterministic_bwd(deterministic):
rocking5566 (Contributor) commented Jan 20, 2026

Use this function:
def is_bwd_supported(d):
    return is_bwd_hdim_supported(d) and not skip_deterministic_bwd(deterministic)

hyoon1 (Author) replied: updated

rocking5566 (Contributor) commented

Could you update the supported GPUs for the Composable Kernel backend in the README?

rocking5566 (Contributor) commented
LGTM
@tridao could you help to merge?

rocking5566 (Contributor) left a review

LGTM

bluefalcon13 commented Feb 6, 2026

I just wanted to say thank you, @hyoon1. I have been fighting to build a ROCm test bed on my 9070 XT desktop to figure out whether I want a Strix Halo for dedicated AI work. After quite a few days of ROCm dependency hell, I finally hit the mark with your pull request!

vLLM params

docker run -it --rm \
    --device=/dev/kfd --device=/dev/dri \
    --shm-size=16gb \
    -e VLLM_USE_TRITON_FLASH_ATTN=True \
    -p 8000:8000 vllm-rocm \
    --model neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 \
    --gpu-memory-utilization 0.8 \
    --max-model-len 4096 \
    --trust-remote-code

vLLM bench params

vllm bench serve \
    --model neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 \
    --dataset-name random \
    --random-input-len 1024 \
    --random-output-len 512 \
    --request-rate 10 \
    --num-prompts 200

results:

Traffic request rate: 10.0
Burstiness factor: 1.0 (Poisson process)
Maximum request concurrency: None
============ Serving Benchmark Result ============
Successful requests:                     200       
Failed requests:                         0         
Request rate configured (RPS):           10.00     
Benchmark duration (s):                  412.70    
Total input tokens:                      204600    
Total generated tokens:                  102400    
Request throughput (req/s):              0.48      
Output token throughput (tok/s):         248.12    
Peak output token throughput (tok/s):    325.00    
Peak concurrent requests:                194.00    
Total token throughput (tok/s):          743.88    
---------------Time to First Token----------------
Mean TTFT (ms):                          187872.60 
Median TTFT (ms):                        187549.63 
P99 TTFT (ms):                           372961.62 
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          22.94     
Median TPOT (ms):                        20.84     
P99 TPOT (ms):                           31.85     
---------------Inter-token Latency----------------
Mean ITL (ms):                           22.94     
Median ITL (ms):                         20.25     
P99 ITL (ms):                            21.11     
==================================================

A few days ago I was capped at approximately 32 tokens/sec. Not bad for a 16 GB card.

Build process:
https://github.com/bluefalcon13/vllm-rocm.git

liangshen68 commented

@tridao could you please help merge this PR? Many users are trying to use CK-based FlashAttention on gfx12. Thanks.

rocking5566 (Contributor) commented

Hi @hyoon1, we've opened #2400, which is a more complete version of this PR (it includes gfx11 support, LLC head grouping, and improvements from the code review). This PR can be closed in favor of #2400.

hyoon1 (Author) commented Mar 26, 2026

Closing this PR in favor of #2400

hyoon1 closed this on Mar 26, 2026