[AMD ROCm] Enable CK backend for ROCm gfx12 #2054
hyoon1 wants to merge 7 commits into Dao-AILab:main from
Conversation
Cc @rocking5566
I think gfx11 and gfx12 were supported by CK at approximately the same time, based on the commit history of the CK repo. Maybe also add the gfx11 archs?
csrc/flash_attn_ck/mha_bwd.cpp
Outdated
```cpp
at::cuda::CUDAGuard device_guard{q.device()};

auto opts = q.options();
// gfx12 deterministic bwd is unstable; always fall back to nondeterministic there.
```

I suggest adding a `TORCH_CHECK` to warn the user rather than switching to non-deterministic mode automatically.
flash_attn/flash_attn_interface.py
Outdated
```python
    return_softmax,
    is_grad_enabled,
):
    deterministic = _disable_gfx12_deterministic(deterministic, qkv.device)
```

I thought not to change the parameter inside the API; just assert and warn the user.
Updated. The Python API no longer mutates `deterministic`. Instead, the C++ CK backward now uses `TORCH_CHECK` to assert and raise an error when `deterministic=True` on gfx12.
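A minimal sketch of the guard described above, written in plain Python so the intended behavior is easy to see. Here `arch_name` stands in for the architecture string reported by the runtime (for example, `torch.cuda.get_device_properties(device).gcnArchName` on ROCm builds of PyTorch), and the `"gfx12"` prefix test plus the function name are illustrative assumptions, not the PR's actual implementation:

```python
def check_deterministic_supported(arch_name: str, deterministic: bool) -> None:
    """Raise, rather than silently flipping the flag, as TORCH_CHECK would.

    Hypothetical helper: mirrors the C++ TORCH_CHECK behavior described in
    the review thread; the gfx12 prefix match is an assumption.
    """
    if deterministic and arch_name.startswith("gfx12"):
        raise RuntimeError(
            f"deterministic backward is not supported by the CK backend on "
            f"{arch_name}; pass deterministic=False"
        )

# No-op on older archs, or when deterministic=False:
check_deterministic_supported("gfx90a", True)
check_deterministic_supported("gfx1201", False)
```

The key design point from the review: the error surfaces at the call site instead of the API silently changing the semantics the user asked for.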
tests/test_flash_attn_ck.py
Outdated
```diff
 g = torch.randn_like(out)
-if is_bwd_hdim_supported(d):
+if is_bwd_hdim_supported(d) and not skip_deterministic_bwd(deterministic):
```

Use this function:

```python
def is_bwd_supported(d):
    return is_bwd_hdim_supported(d) and not skip_deterministic_bwd(deterministic)
```
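For completeness, a self-contained sketch of the suggested helper. The bodies of `is_bwd_hdim_supported` and `skip_deterministic_bwd` are stand-in assumptions (the real support matrix lives in the test suite), and `deterministic` is passed explicitly rather than captured from the enclosing test scope as in the reviewer's one-liner:

```python
def is_bwd_hdim_supported(d: int) -> bool:
    # Stand-in gate: assume bwd supports head dims that are multiples of 8
    # up to 128 (placeholder logic, not CK's real support matrix).
    return d % 8 == 0 and d <= 128

def skip_deterministic_bwd(deterministic: bool) -> bool:
    # Stand-in gate: skip deterministic bwd cases (e.g. on gfx12).
    return deterministic

def is_bwd_supported(d: int, deterministic: bool) -> bool:
    """Single check combining both gates, as the reviewer suggested."""
    return is_bwd_hdim_supported(d) and not skip_deterministic_bwd(deterministic)
```

Folding both conditions into one predicate keeps every call site in the tests down to a single `if`, which is the point of the review suggestion.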
Could you revise the supported GPUs for the Composable Kernel backend in the README?
LGTM
I just wanted to say thank you @hyoon1. I have been fighting with building a test bed for ROCm on my 9070 XT desktop to figure out whether I wanted to get a Strix Halo for dedicated AI work. After quite a few days of fighting with ROCm dependency hell, I finally hit the mark with your pull request!

vLLM params:

```shell
docker run -it --rm \
  --device=/dev/kfd --device=/dev/dri \
  --shm-size=16gb \
  -e VLLM_USE_TRITON_FLASH_ATTN=True \
  -p 8000:8000 vllm-rocm \
  --model neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 \
  --gpu-memory-utilization 0.8 \
  --max-model-len 4096 \
  --trust-remote-code
```

vLLM bench params:

```shell
vllm bench serve \
  --model neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 \
  --dataset-name random \
  --random-input-len 1024 \
  --random-output-len 512 \
  --request-rate 10 \
  --num-prompts 200
```

Results: a few days ago I was capped at approximately 32 tokens/sec. Not bad for a 16 GB card.

Build process:
@tridao could you please help merge this PR? Many users are trying to use CK-based FA on gfx12. Thanks.
Closing this PR in favor of #2400
This extends #2052, which updated to the latest Composable Kernel version. The latest CK now supports gfx12 architectures, but the CK kernel generator needs an explicit target specification to generate kernels for these GPUs.
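As a hedged illustration of what "explicit target specification" can look like at build time: the `GPU_ARCHS` variable and the `gfx1200`/`gfx1201` target names below follow common ROCm build conventions and are assumptions, not details confirmed by this thread.

```shell
# Assumed build invocation: pass the gfx12 targets explicitly so the CK
# kernel generator emits kernels for them (variable and arch names are
# illustrative, not taken from this PR).
GPU_ARCHS="gfx1200;gfx1201" pip install --no-build-isolation .
```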