Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/composable_kernel
Submodule composable_kernel updated 82 files
+87 −72 CMakeLists.txt
+16 −0 CMakePresets.json
+2 −0 Dockerfile.aiter
+2 −0 Dockerfile.fa
+2 −1 Dockerfile.pytorch
+3 −9 Jenkinsfile
+15 −0 README.md
+67 −22 dispatcher/python/ctypes_utils.py
+294 −0 dispatcher/tests/test_library_caching.py
+11 −4 example/40_conv2d_fwd_quantization/run_conv2d_fwd_perlayer_quantization_example.inc
+8 −2 example/ck_tile/01_fmha/CMakeLists.txt
+1 −1 example/ck_tile/01_fmha/README.md
+101 −12 example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py
+33 −2 example/ck_tile/01_fmha/fmha_fwd.hpp
+102 −10 example/ck_tile/01_fmha/fmha_fwd_runner.hpp
+0 −46 example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh
+128 −0 example/ck_tile/52_cshuffle_lds/CMakeLists.txt
+61 −0 example/ck_tile/52_cshuffle_lds/README.md
+122 −0 example/ck_tile/52_cshuffle_lds/benchmark_cshuffle_lds.hpp
+100 −0 example/ck_tile/52_cshuffle_lds/benchmark_template.cpp.in
+3 −0 example/ck_tile/CMakeLists.txt
+5 −8 experimental/builder/include/ck_tile/builder/testing/conv/ck_tile.hpp
+6 −8 experimental/grouped_convolution_tile_instances/generate_instances.py
+2 −0 include/ck/host_utility/device_prop.hpp
+5 −4 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp
+6 −4 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp
+10 −5 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp
+1,216 −0 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp
+2 −0 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp
+1 −0 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp
+1 −0 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp
+1 −0 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp
+1 −0 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp
+1 −0 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
+1 −0 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp
+1 −0 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
+17 −3 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp
+167 −0 include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_fixed_nk_common.hpp
+7 −143 include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp
+14 −165 include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp
+6 −146 include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp
+13 −163 include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
+11 −0 include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp
+3 −2 include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp
+207 −121 include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp
+38 −0 include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp
+4 −0 include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp
+5 −0 include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp
+81 −0 include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp
+16 −7 include/ck_tile/core/tensor/tensor_descriptor.hpp
+168 −9 include/ck_tile/core/tensor/tile_scatter_gather.hpp
+14 −0 include/ck_tile/core/utility/type_traits.hpp
+1 −0 include/ck_tile/ops/fmha.hpp
+17 −0 include/ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp
+10 −8 include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp
+248 −125 include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp
+6 −0 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp
+7 −2 include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
+1 −1 include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp
+9 −4 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp
+45 −0 include/ck_tile/utility/tile_load_store_microkernels.hpp
+85 −0 ...ibrary/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp
+64 −4 ...ry/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp
+4 −0 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp
+28 −0 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc
+2 −0 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt
+49 −0 ...e/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
+49 −0 ...ce/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_f16_instance.cpp
+24 −9 ...gc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev5_instance.cpp
+26 −9 ...dhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev5_instance.cpp
+25 −9 ...wgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev5_instance.cpp
+25 −9 ...ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev5_instance.cpp
+1 −0 python/ck4inductor/grouped_conv_fwd/op.py
+47 −1 script/cmake-ck-dev.sh
+28 −0 script/run_inductor_tests.sh
+7 −3 test/ck_tile/CMakeLists.txt
+6 −0 test/ck_tile/gemm_block_scale/CMakeLists.txt
+31 −0 test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_eightwaves_padded_stride.cpp
+7 −2 test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp
+11 −0 test/grouped_convnd_bwd_data/CMakeLists.txt
+258 −0 test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_tile.cpp
+3 −3 tutorial/ck_tile/gemm/01_naive_gemm/README.md
8 changes: 7 additions & 1 deletion csrc/cpp_itfs/mha_fwd_batch_prefill.cu
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,13 @@ float mha_batch_prefill(mha_batch_prefill_args args,
int head_size_q = args.hdim_q;
int head_size_v = args.hdim_v;
bool has_dropout = args.p_drop > 0.f;
auto traits = get_mha_batch_prefill_traits(head_size_q,

// The kUseGlobalLoad decision (>2GB KV cache → use `global_load_lds_*`
// instead of SRD `buffer_load_*`) is made per-arm inside the auto-generated
// dispatcher in fmha_batch_prefill_api.cpp, where each arm knows its own
// compile-time bn0 and dtype element size. This wrapper just forwards args;
// there is no runtime trait field for the kUseGlobalLoad decision.
auto traits = get_mha_batch_prefill_traits(head_size_q,
head_size_v,
q_dtype_str,
is_group_mode,
Expand Down
8 changes: 7 additions & 1 deletion csrc/py_itfs_ck/mha_batch_prefill_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,13 @@ mha_batch_prefill(at::Tensor& q, // [total_q, hq, d]
has_lse,
qscale_type,
false);
TORCH_CHECK(t >= 0, "invalid argument for batch_prefill");
TORCH_CHECK(t >= 0,
"invalid argument for batch_prefill: no matching kernel found. "
"page_size=", args.page_block_size,
", num_pages=", args.num_total_pages,
", dtype=", dtype_str,
". If KV cache exceeds 2GB (INT32_MAX byte offset) with page_size < kN0, "
"CDNA3+ GPU (MI300/MI350) is required.");
}
else
{
Expand Down
Loading
Loading