19 changes: 18 additions & 1 deletion aiter/ops/mha.py
@@ -965,6 +965,7 @@ def cmdGenFunc_mha_batch_prefill(
is_causal: bool,
window_size_left: int,
window_size_right: int,
sink_size: int,
return_softmax_lse: bool,
return_dropout_randval: bool,
out: Optional[Tensor] = None,
@@ -1046,6 +1047,17 @@ def cmdGenFunc_mha_batch_prefill(
# PERTENSOR: per-tensor quantization
md_name += "_pertensor"
filter_fwd += "_pertensor*"
# Sink only applies when there is a causal/window mask; full attention
# (window_size_left==-1 and window_size_right==-1) ignores sink_size.
has_effective_sink = sink_size > 0 and (
causal or not (window_size_left == -1 and window_size_right == -1)
)
if has_effective_sink:
md_name += "_sink"
filter_fwd += "_sink*"
else:
md_name += "_nsink"
filter_fwd += "_nsink*"
blob_gen_cmd = [
f"{CK_DIR}/example/ck_tile/01_fmha/generate.py -d batch_prefill "
"--receipt 200 --filter {} --output_dir {{}}".format(filter_fwd)
@@ -2739,6 +2751,7 @@ def mha_batch_prefill_fake_tensors(
is_causal: bool,
window_size_left: int,
window_size_right: int,
sink_size: int,
return_softmax_lse: bool,
return_dropout_randval: bool,
out: Optional[torch.Tensor] = None,
@@ -2823,6 +2836,7 @@ def mha_batch_prefill(
is_causal: bool,
window_size_left: int,
window_size_right: int,
sink_size: int,
return_softmax_lse: bool,
return_dropout_randval: bool,
out: Optional[Tensor] = None,
@@ -2857,6 +2871,7 @@ def _mha_batch_prefill(
logits_soft_cap: float = 0.0,
window_size_left: int = -1,
window_size_right: int = -1,
sink_size: int = 0,
bias: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
return_lse: bool = False,
@@ -2892,6 +2907,7 @@ def _mha_batch_prefill(
causal,
window_size_left,
window_size_right,
sink_size,
return_lse,
return_softmax,
out,
@@ -2906,7 +2922,6 @@ def _mha_batch_prefill(
seqlen_k,
sink_ptr,
None,
# custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd},
)
return out, softmax_lse, S_dmask, rng_state

@@ -2938,6 +2953,7 @@ def mha_batch_prefill_func(
v_descale=None,
kv_block_descale=None, # [num_block, num_kv_head, 2] per-page K/V descales
sink_ptr=None,
sink_size: int = 0,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
Expand Down Expand Up @@ -2990,6 +3006,7 @@ def mha_batch_prefill_func(
logits_soft_cap=logits_soft_cap,
window_size_left=window_size[0],
window_size_right=window_size[1],
sink_size=sink_size,
alibi_slopes=alibi_slopes,
return_lse=return_lse,
return_softmax=return_attn_probs and dropout_p > 0,
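For callers, a hypothetical invocation showing how the new sink_size keyword pairs with the existing sink_ptr one. Only the sink-related arguments and window_size are taken from this diff; everything else is a placeholder gathered in other_kwargs, not the actual signature:

    # Hypothetical call sketch; sink only takes effect under a causal/local mask.
    out = mha_batch_prefill_func(
        q, k, v,
        window_size=(-1, 0),       # causal-style window, so the sink is honored
        sink_ptr=sink_logits,      # optional per-head sink logits (may be None)
        sink_size=4,               # forwarded to the kernel mask as mask.sink
        **other_kwargs,            # cu_seqlens, paging tables, scales, ...
    )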
11 changes: 8 additions & 3 deletions csrc/cpp_itfs/mha_fwd_batch_prefill.cu
@@ -16,7 +16,8 @@ get_mha_batch_prefill_traits(int head_size_q,
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kv_memory_layout,
ck_tile::BlockAttentionKVCacheLookupTableEnum kv_lookup_table,
int page_size,
bool skip_min_seqlen_q = false)
bool skip_min_seqlen_q = false,
bool has_sink = false)
{
return mha_batch_prefill_traits(head_size_q,
head_size_v,
@@ -29,6 +30,7 @@ get_mha_batch_prefill_traits(int head_size_q,
has_dropout,
qscale_type,
skip_min_seqlen_q,
has_sink,
kv_memory_layout,
kv_lookup_table,
page_size);
@@ -47,13 +49,14 @@ float mha_batch_prefill(mha_batch_prefill_args args,
int head_size_q = args.hdim_q;
int head_size_v = args.hdim_v;
bool has_dropout = args.p_drop > 0.f;
bool has_sink = args.sink_size > 0;

// The kUseGlobalLoad decision (>2GB KV cache → use `global_load_lds_*`
// instead of SRD `buffer_load_*`) is made per-arm inside the auto-generated
// dispatcher in fmha_batch_prefill_api.cpp, where each arm knows its own
// compile-time bn0 and dtype element size. The wrapper just forwards args;
// no runtime trait field for it.
auto traits = get_mha_batch_prefill_traits(head_size_q,
auto traits = get_mha_batch_prefill_traits(head_size_q,
head_size_v,
q_dtype_str,
is_group_mode,
Expand All @@ -65,7 +68,9 @@ float mha_batch_prefill(mha_batch_prefill_args args,
qscale_type,
args.kv_memory_layout,
args.kv_lookup_table,
args.page_block_size);
args.page_block_size,
/*skip_min_seqlen_q=*/false,
has_sink);
return fmha_batch_prefill(traits, args, stream_config);
}

3 changes: 2 additions & 1 deletion csrc/include/mha_fwd.h
@@ -63,6 +63,7 @@ struct mha_batch_prefill_traits : public fmha_batch_prefill_traits
bool has_dropout,
quant_scale_enum qscale_type,
bool skip_min_seqlen_q,
bool has_sink,
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kv_memory_layout,
ck_tile::BlockAttentionKVCacheLookupTableEnum kv_lookup_table,
int page_size)
@@ -78,7 +79,7 @@ struct mha_batch_prefill_traits : public fmha_batch_prefill_traits
has_dropout,
qscale_type,
skip_min_seqlen_q,
false, // has_sink
has_sink,
kv_memory_layout,
kv_lookup_table,
page_size}
1 change: 1 addition & 0 deletions csrc/include/rocm_ops.hpp
@@ -1097,6 +1097,7 @@ namespace py = pybind11;
py::arg("is_causal"), \
py::arg("window_size_left"), \
py::arg("window_size_right"), \
py::arg("sink_size"), \
py::arg("return_softmax_lse"), \
py::arg("return_dropout_randval"), \
py::arg("out") = std::nullopt, \
1 change: 1 addition & 0 deletions csrc/include/torch/mha_batch_prefill.h
@@ -21,6 +21,7 @@ mha_batch_prefill(at::Tensor& q, // [total_q, hq, d]
bool is_causal,
int window_size_left,
int window_size_right,
int sink_size,
bool return_softmax_lse,
bool return_dropout_randval,
std::optional<at::Tensor> out_, // [total_q, hq, d]
13 changes: 9 additions & 4 deletions csrc/py_itfs_ck/mha_batch_prefill_kernels.cu
@@ -304,14 +304,17 @@ get_ck_fmha_batch_prefill_args(bool has_lse,
kv_last_page_lens_ptr = kv_last_page_lens.data_ptr();
}

fmha_batch_prefill_args args;
fmha_batch_prefill_args args{}; // zero-initialize all fields

args.q_ptr = q.data_ptr();
args.k_ptr = k.data_ptr();
args.v_ptr = v.data_ptr();
args.q_descale_ptr = q_descale.has_value() ? q_descale.value().data_ptr() : nullptr;
args.k_descale_ptr = k_descale.has_value() ? k_descale.value().data_ptr() : nullptr;
args.v_descale_ptr = v_descale.has_value() ? v_descale.value().data_ptr() : nullptr;
// sink_ptr is independent of sink_size: when provided, the kernel always reads
// it as per-head logit values for the virtual sink token (sink_value = *ptr / scale_s).
// When null, sink_value defaults to -inf (virtual token excluded from softmax).
args.sink_ptr = sink_ptr_.has_value() ? sink_ptr_.value().data_ptr() : nullptr;
args.bias_ptr = bias_ptr;
args.rand_val_ptr = has_dropout_randval ? dropout_randval.data_ptr() : nullptr;
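To make the comment above concrete, a small standalone sketch (Python, hypothetical helper, not the kernel code) of how a virtual sink logit drains probability mass from the real tokens without contributing a value vector; passing no sink logit matches the -inf default, where the sink is excluded entirely:

    import math

    def softmax_with_sink(logits, sink_logit=None):
        # The sink participates as one extra logit; its weight is then discarded,
        # which is equivalent to the virtual token having a zero value vector.
        xs = list(logits) + ([sink_logit] if sink_logit is not None else [])
        m = max(xs)
        exps = [math.exp(x - m) for x in xs]
        denom = sum(exps)
        return [e / denom for e in exps[: len(logits)]]

    print(softmax_with_sink([2.0, 1.0, 0.5]))                  # ordinary softmax
    print(softmax_with_sink([2.0, 1.0, 0.5], sink_logit=1.5))  # mass drained to the sink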
@@ -363,6 +366,7 @@ get_ck_fmha_batch_prefill_args(bool has_lse,
args.batch_stride_o = batch_stride_o;
args.window_size_left = mask.left;
args.window_size_right = mask.right;
args.sink_size = mask.sink;
args.mask_type = static_cast<ck_tile::index_t>(mask.type);
args.p_drop = p_dropout;
args.s_randval = has_dropout_randval;
@@ -416,6 +420,7 @@ mha_batch_prefill(at::Tensor& q, // [total_q, hq, d]
bool is_causal,
int window_size_left,
int window_size_right,
int sink_size,
bool return_softmax_lse,
bool return_dropout_randval,
std::optional<at::Tensor> out_, // [total_q, hq, d]
@@ -609,18 +614,18 @@ mha_batch_prefill(at::Tensor& q, // [total_q, hq, d]
{
// Causal is the special case where window_size_right == 0 and window_size_left < 0.
window_size_right = 0;
std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0";
std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0" + "," + std::to_string(sink_size);
mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // causal
}
else if(window_size_left == -1 && window_size_right == -1)
{
mask = mask_info::decode("0", max_seqlen_q, max_seqlen_k); // no mask
mask = mask_info::decode("0", max_seqlen_q, max_seqlen_k); // no mask; sink N/A for full attention
}
else
{
// Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
std::string mask_identify =
"b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right);
"b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right) + "," + std::to_string(sink_size);
mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // local
}
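A quick sketch (Python, illustrative values only) of the mask strings these three branches hand to mask_info::decode, with the sink count now appended as the third field:

    def mask_identify(is_causal, left, right, sink):
        # Mirrors the branches above: "b:<left>,<right>,<sink>" or "0".
        if is_causal:
            return f"b:{left},0,{sink}"
        if left == -1 and right == -1:
            return "0"                       # full attention: sink is not encoded
        return f"b:{left},{right},{sink}"

    assert mask_identify(True, -1, 7, 4) == "b:-1,0,4"
    assert mask_identify(False, -1, -1, 4) == "0"
    assert mask_identify(False, 128, 0, 4) == "b:128,0,4"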
