Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 50 additions & 7 deletions csrc/cpp_itfs/mha_bwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ float mha_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
a.hdim_v,
a.data_type,
a.is_group_mode,
static_cast<mask_enum>(a.ck_mask_type),
static_cast<mask_enum>(a.mask_type),
static_cast<bias_enum>(a.bias_type),
a.has_dbias,
a.has_dropout,
Expand Down Expand Up @@ -220,7 +220,7 @@ float mha_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
/* split_stride_dq_acc*/ a.split_stride_dq_acc,
/* window_size_left */ a.window_size_left,
/* window_size_right */ a.window_size_right,
/* mask_type */ a.ck_mask_type,
/* mask_type */ a.mask_type,
/* p_drop */ a.p_drop,
/* p_undrop */ a.p_undrop,
/* drop_seed_offset */ a.drop_seed_offset,
Expand Down Expand Up @@ -249,6 +249,41 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
return -1;
}

// Translate the mask description carried by `a` (a.mask_type together with
// a.window_size_left / a.window_size_right) into the ASM mask type:
//    0: no mask
//    1: top-left triangular
//    2: bottom-right triangular
//    3: window mask
//   -1: unsupported (e.g., ck generic mask)
auto asm_mask_type = [&]() -> int {
    if(a.mask_type == static_cast<ck_tile::index_t>(mask_enum::no_mask))
    {
        return 0;
    }
    if(a.mask_type == static_cast<ck_tile::index_t>(mask_enum::window_generic))
    {
        // CK generic mask isn't supported here
        return -1;
    }

    // Every remaining mask type is classified by its window sizes.
    if(a.window_size_left == -1 && a.window_size_right == 0)
    {
        // This case covers both top-left and bottom-right masks; each keeps its own
        // code (1 vs. 2). Note: they share the same kernel selection logic in bwd
        // since the attention sink isn't supported in bwd yet.
        return (a.mask_type == static_cast<ck_tile::index_t>(mask_enum::mask_top_left)) ? 1 : 2;
    }
    if(a.window_size_left == -1 && a.window_size_right == -1)
    {
        // A (-1, -1) window is reported with the same code as "no mask".
        return 0;
    }
    return 3;
};

auto pre_cfgs = &cfg_fmha_bwd_odo;
auto dqdkdv_cfgs = &cfg_fmha_bwd_dqdkdv;
auto post_cfgs = [&]() {
Expand Down Expand Up @@ -279,13 +314,21 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
bool need_post_processing =
((arch_id == "gfx950") && (a.hdim_q != 64)) || (a.v3_atomic_fp32 == 1);

int mt = asm_mask_type();

if (mt == -1)
{
std::cout << "fmha_v3_bwd: unsupported mask type for asm kernels." << std::endl;
return -1;
}

auto [pre_kernel, dqdkdv_kernel, post_kernel] = get_heuristic_kernel(a.data_type,
arch_id,
a.seqlen_q,
a.seqlen_k,
a.hdim_q,
a.hdim_v,
a.mask_type,
mt,
a.v3_atomic_fp32,
a.v3_bf16_cvt,
a.is_group_mode,
Expand Down Expand Up @@ -465,7 +508,7 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
}
dqdkdv_args.max_seqlen_dq = a.v3_atomic_fp32 ? a.max_seqlen_q : (a.max_seqlen_q + 15) / 16 * 16;

if(a.mask_type == 3)
if(mt == 3)
{
// Note: sink_size=0 is passed as the 3rd parameter (attention sink not supported in bwd
// yet)
Expand All @@ -476,8 +519,8 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
sink_size,
a.seqlen_q,
a.seqlen_k,
(a.ck_mask_type == static_cast<ck_tile::index_t>(mask_enum::mask_top_left) ||
a.ck_mask_type == static_cast<ck_tile::index_t>(mask_enum::window_generic)));
(a.mask_type == static_cast<ck_tile::index_t>(mask_enum::mask_top_left) ||
a.mask_type == static_cast<ck_tile::index_t>(mask_enum::window_generic)));
dqdkdv_args.mask_y = generic_mask.at(ck_tile::number<0>{});
dqdkdv_args.mask_x = generic_mask.at(ck_tile::number<1>{});
}
Expand All @@ -489,7 +532,7 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
int gdy = a.nhead_q;
int gdz = a.batch;

if((a.mask_type == 1) || (a.mask_type == 2))
if((mt == 1) || (mt == 2))
{ // causal
gdx = (gdx + 1) / 2;
}
Expand Down
5 changes: 2 additions & 3 deletions csrc/include/mha_bwd.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#pragma once
// SPDX-License-Identifier: MIT
// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

// Include these 2 headers instead of torch/extension.h since we don't need all of the torch
// headers.
Expand All @@ -13,7 +13,6 @@ namespace aiter {
struct mha_bwd_args
{
// aiter args
int mask_type; // 0: no mask 1: top_left_causal 2: bottom_right_causal 3: sliding_window
bool use_asm_v3;
bool v3_atomic_fp32;
int v3_bf16_cvt;
Expand All @@ -24,7 +23,7 @@ struct mha_bwd_args
int hdim_v;
std::string data_type;
bool is_group_mode;
int ck_mask_type;
int mask_type;
int bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool has_dbias;
bool has_dropout;
Expand Down
22 changes: 2 additions & 20 deletions csrc/py_itfs_ck/mha_bwd_kernels.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

#include <torch/all.h>
#include <ATen/hip/HIPContext.h>
Expand Down Expand Up @@ -91,23 +91,6 @@ mha_bwd(const at::Tensor &dout, // [b, sq, hq, d_v]
mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // local
}

// Reduce the decoded `mask` to an integer mask code:
//   0 = no mask, 1 = top-left causal, 2 = bottom-right causal, 3 = any other window.
// window_generic is not expected here: it trips the assert, and when asserts are
// compiled out (NDEBUG) the lambda falls through to returning 0.
auto get_mask_type = [&]() {
    if (mask.type == mask_enum::no_mask) {
        return 0;
    }
    if (mask.type == mask_enum::window_generic) {
        assert(false);
        return 0;
    }

    // left == -1 and right == 0 selects the causal codes; anything else is code 3.
    const bool causal_window = (mask.left == -1) && (mask.right == 0);
    if (!causal_window) {
        return 3;
    }
    return (mask.type == mask_enum::mask_top_left) ? 1 : 2;
};

// q, k, v, out had been padded in mha_fwd
// dq_, dk_, dv_ are also padded tensor
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_q);
Expand Down Expand Up @@ -302,8 +285,7 @@ mha_bwd(const at::Tensor &dout, // [b, sq, hq, d_v]
nhead_stride_dbias = dbias.stride(2);
}

return mha_bwd_args{get_mask_type(),
false, // use_v3
return mha_bwd_args{false, // use_v3
false, // is_v3_atomic_fp32
false, // how_v3_bf16_cvt
false, // v3_api_check
Expand Down
22 changes: 2 additions & 20 deletions csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

#include <torch/all.h>
#include <ATen/hip/HIPContext.h>
Expand Down Expand Up @@ -109,23 +109,6 @@ mha_varlen_bwd(const at::Tensor &dout, // [total_q, hq, d_v]
mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // local
}

// Reduce the decoded `mask` to an integer mask code:
//   0 = no mask, 1 = top-left causal, 2 = bottom-right causal, 3 = any other window.
// window_generic is not expected here: it trips the assert, and when asserts are
// compiled out (NDEBUG) the lambda falls through to returning 0.
auto get_mask_type = [&]() {
    if (mask.type == mask_enum::no_mask) {
        return 0;
    }
    if (mask.type == mask_enum::window_generic) {
        assert(false);
        return 0;
    }

    // left == -1 and right == 0 selects the causal codes; anything else is code 3.
    const bool causal_window = (mask.left == -1) && (mask.right == 0);
    if (!causal_window) {
        return 3;
    }
    return (mask.type == mask_enum::mask_top_left) ? 1 : 2;
};

// q, k, v, out had been padded in mha_fwd
// dq_, dk_, dv_ are also padded tensor
CHECK_SHAPE(q, total_q, num_heads, head_size_q);
Expand Down Expand Up @@ -311,8 +294,7 @@ mha_varlen_bwd(const at::Tensor &dout, // [total_q, hq, d_v]
seqstart_q_ptr = cu_seqlens_q.data_ptr();
}

return mha_bwd_args{get_mask_type(),
false, // use_v3
return mha_bwd_args{false, // use_v3
false, // is_v3_atomic_fp32
false, // how_v3_bf16_cvt
false, // v3_api_check
Expand Down
22 changes: 2 additions & 20 deletions csrc/py_itfs_cu/asm_mha_bwd.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

#include <torch/all.h>
#include <ATen/hip/HIPContext.h>
Expand Down Expand Up @@ -89,23 +89,6 @@ std::vector<at::Tensor> fmha_v3_bwd(const at::Tensor &dout, // [b, sq, h
mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // local
}

// Reduce the decoded `mask` to an integer mask code:
//   0 = no mask, 1 = top-left causal, 2 = bottom-right causal, 3 = any other window.
// window_generic is not expected here: it trips the assert, and when asserts are
// compiled out (NDEBUG) the lambda falls through to returning 0.
auto get_mask_type = [&]() {
    if (mask.type == mask_enum::no_mask) {
        return 0;
    }
    if (mask.type == mask_enum::window_generic) {
        assert(false);
        return 0;
    }

    // left == -1 and right == 0 selects the causal codes; anything else is code 3.
    const bool causal_window = (mask.left == -1) && (mask.right == 0);
    if (!causal_window) {
        return 3;
    }
    return (mask.type == mask_enum::mask_top_left) ? 1 : 2;
};

// q, k, v, out had been padded in mha_fwd
// dq_, dk_, dv_ are also padded tensor
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_q);
Expand Down Expand Up @@ -264,8 +247,7 @@ std::vector<at::Tensor> fmha_v3_bwd(const at::Tensor &dout, // [b, sq, h
stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
}

return mha_bwd_args{get_mask_type(),
true,
return mha_bwd_args{true,
is_v3_atomic_fp32,
how_v3_bf16_cvt,
false,
Expand Down
22 changes: 2 additions & 20 deletions csrc/py_itfs_cu/asm_mha_varlen_bwd.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

#include <torch/all.h>
#include <ATen/hip/HIPContext.h>
Expand Down Expand Up @@ -109,23 +109,6 @@ fmha_v3_varlen_bwd(const at::Tensor &dout, // [total_q, hq, d_v
mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // local
}

// Reduce the decoded `mask` to an integer mask code:
//   0 = no mask, 1 = top-left causal, 2 = bottom-right causal, 3 = any other window.
// window_generic is not expected here: it trips the assert, and when asserts are
// compiled out (NDEBUG) the lambda falls through to returning 0.
auto get_mask_type = [&]() {
    if (mask.type == mask_enum::no_mask) {
        return 0;
    }
    if (mask.type == mask_enum::window_generic) {
        assert(false);
        return 0;
    }

    // left == -1 and right == 0 selects the causal codes; anything else is code 3.
    const bool causal_window = (mask.left == -1) && (mask.right == 0);
    if (!causal_window) {
        return 3;
    }
    return (mask.type == mask_enum::mask_top_left) ? 1 : 2;
};

// q, k, v, out had been padded in mha_fwd
// dq_, dk_, dv_ are also padded tensor
CHECK_SHAPE(q, total_q, num_heads, head_size_q);
Expand Down Expand Up @@ -325,8 +308,7 @@ fmha_v3_varlen_bwd(const at::Tensor &dout, // [total_q, hq, d_v
seqstart_q_ptr = cu_seqlens_q.data_ptr();
}

return mha_bwd_args{get_mask_type(),
true, // use_v3
return mha_bwd_args{true, // use_v3
is_v3_atomic_fp32, // is_v3_atomic_fp32
how_v3_bf16_cvt, // how_v3_bf16_cvt
false, // v3_api_check
Expand Down
31 changes: 2 additions & 29 deletions op_tests/cpp/mha/benchmark_mha_bwd.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (C) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (C) 2018-2026, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/host.hpp"
#include "mha_bwd.h"
#include "utils.hpp"
Expand Down Expand Up @@ -506,32 +506,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
<< " MByte memory workspace allocated" << std::endl;
}

// Reduce the benchmark's `mask` to an integer mask code:
//   0 = no mask, 1 = top-left causal, 2 = bottom-right causal, 3 = any other window.
// window_generic is not expected here: it trips the assert, and when asserts are
// compiled out (NDEBUG) the lambda falls through to returning 0.
auto get_mask_type = [&]() {
    if(mask.type == mask_enum::no_mask)
    {
        return 0;
    }
    if(mask.type == mask_enum::window_generic)
    {
        assert(false);
        return 0;
    }

    // left == -1 and right == 0 selects the causal codes; anything else is code 3.
    const bool causal_window = (mask.left == -1) && (mask.right == 0);
    if(!causal_window)
    {
        return 3;
    }
    return (mask.type == mask_enum::mask_top_left) ? 1 : 2;
};

auto mha_args = [&]() {
assert(nhead % nhead_k == 0);
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
Expand Down Expand Up @@ -590,8 +564,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}();

return aiter::mha_bwd_args{get_mask_type(),
bwd_v3,
return aiter::mha_bwd_args{bwd_v3,
v3_atomic_fp32,
v3_bf16_cvt,
v3_api_check,
Expand Down
17 changes: 8 additions & 9 deletions op_tests/cpp/mha/smoke_test_bwd_v3.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2018-2026, Advanced Micro Devices, Inc. All rights reserved.
#!/bin/sh
EXE="$(find . -name bwd.exe -type f | head -n 1)"
KNAME=1
Expand Down Expand Up @@ -59,8 +59,7 @@ run_swa_tests() {
for hdim in 72 96 128 ; do
for mask in "t:-1,10" "t:15,-1" "t:15,15" "t:190,187" "b:-1,10" "b:15,-1" "b:15,15" "b:190,187" ; do

$EXE -prec=$prec -b=2 -h=4 -h_k=2 -d=$hdim -s=$seqlen_q -s_k=$seqlen_k -iperm=$perm -operm=$perm -mask=$mask -bwd_v3=1 -mode=0 -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=3 -h_k=1 -d=$hdim -s=$seqlen_q -s_k=$seqlen_k -iperm=$perm -operm=$perm -mask=$mask -bwd_v3=1 -mode=0 -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=$seqlen_q -s_k=$seqlen_k -iperm=$perm -operm=$perm -mask=$mask -bwd_v3=1 -mode=0 -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=$seqlen_q -s_k=$seqlen_k -iperm=$perm -operm=$perm -mask=$mask -bwd_v3=1 -mode=0 -kname=$KNAME $COMMON_ARGS
Comment thread
slippedJim marked this conversation as resolved.

done
Expand Down Expand Up @@ -175,11 +174,11 @@ run_gfx950_hd192_128_bwd_v3() {
done
}

# run_batch_mode_tests
# run_group_mode_tests
# run_swa_tests
run_gfx950_group_bwd_v3
run_gfx950_bwd_v3
run_batch_mode_tests
run_group_mode_tests
run_swa_tests
# run_gfx950_group_bwd_v3
Comment thread
slippedJim marked this conversation as resolved.
# run_gfx950_bwd_v3

# hdim 192+128 tests
run_gfx950_hd192_128_bwd_v3
# run_gfx950_hd192_128_bwd_v3
Loading