From b03e0c9b1bbaa79856471b442bfe3a74f0a2c8e1 Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Mon, 17 Mar 2025 13:41:34 -0700
Subject: [PATCH 1/7] slide window

---
 .../contrib_ops/cuda/bert/attention_impl.cu   |  2 +-
 .../bert/cutlass_fmha/fmha_launch_template.h  |  1 +
 .../cuda/bert/cutlass_fmha/kernel_forward.h   | 67 +++++++++++++++++--
 .../cutlass_fmha/memory_efficient_attention.h | 55 +++++++--------
 .../cuda/bert/group_query_attention.cc        |  7 +-
 .../cuda/bert/group_query_attention_impl.cu   |  1 +
 .../bert/packed_multihead_attention_impl.cu   |  4 +-
 .../transformers/test_flash_attn_cuda.py      | 66 ++++++++++--------
 8 files changed, 131 insertions(+), 72 deletions(-)

diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
index 0209183f46425..84a7cc19f1576 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -512,7 +512,7 @@ Status EfficientAttention(
     p.seqstart_q_ptr = nullptr;
     p.seqstart_k_ptr = nullptr;
   } else {
-    p.seqlen_k_ptr = const_cast<int32_t*>(reinterpret_cast<const int32_t*>(data.mask_index));
+    p.seqlen_k_ptr = reinterpret_cast<const int32_t*>(data.mask_index);
     p.seqstart_q_ptr = p.seqlen_k_ptr + parameters.batch_size;
     p.seqstart_k_ptr = p.seqlen_k_ptr + 2 * parameters.batch_size + 1;
   }
diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h
index 9b3ba73254d73..100ab0e0a2fdc 100644
--- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h
+++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h
@@ -222,6 +222,7 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
     }
 
     p.use_smooth_softmax = params.use_smooth_softmax;
+    p.window_size = params.local_window_size;
   }
 
   auto kernel_fn = attention_kernel_batched_impl<Attention>;
diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h
index 8dff521da48d1..ca0b3a0fddfe6 100644
--- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h
+++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h
@@ -174,10 +174,9 @@ struct AttentionKernel {
     scalar_t* key_ptr = nullptr;        // [num_keys, num_heads, head_dim]
     scalar_t* value_ptr = nullptr;      // [num_keys, num_heads, head_dim_value]
    scalar_t* attn_bias_ptr = nullptr;  // [num_heads, num_queries, num_keys]
-    int32_t* seqstart_q_ptr = nullptr;
-    int32_t* seqstart_k_ptr = nullptr;
-
-    int32_t* seqlen_k_ptr = nullptr;
+    const int32_t* seqstart_q_ptr = nullptr;
+    const int32_t* seqstart_k_ptr = nullptr;
+    const int32_t* seqlen_k_ptr = nullptr;
     uint32_t causal_diagonal_offset = 0;
 
     // Output tensors
@@ -187,6 +186,8 @@ struct AttentionKernel {
     // [num_heads, num_queries] - can be null
     lse_scalar_t* logsumexp_ptr = nullptr;
 
+    int32_t window_size = -1;
+
     // Scale
     accum_t scale = 0.0;
 
@@ -651,6 +652,12 @@ struct AttentionKernel {
     XFORMERS_CHECK(
         p.custom_mask_type < NumCustomMaskTypes,
         "invalid value for `custom_mask_type`");
+    if (p.window_size > 0) {
+      XFORMERS_CHECK(
+          p.custom_mask_type == CausalFromTopLeft ||
+              p.custom_mask_type == CausalFromBottomRight,
+          "invalid value for custom_mask_type");
+    }
     return true;
   }
 
@@ -726,6 +733,13 @@ struct AttentionKernel {
     // Iterate through keys
     for (int32_t iter_key_start = 0; iter_key_start < p.num_keys;
          iter_key_start += kKeysPerBlock) {
+      if (p.window_size > 0) {
+        // don't compute anything if below attention band
+        if (iter_key_start + kKeysPerBlock <
+            int32_t(query_start + p.causal_diagonal_offset) - p.window_size) {
+          continue;
+        }
+      }
       int32_t problem_size_0_m =
           cutlass::fast_min((int32_t)kQueriesPerBlock, p.num_queries);
       int32_t problem_size_0_n = cutlass::fast_min(
@@ -894,6 +908,40 @@ struct AttentionKernel {
             },
             [&](int accum_m) {});
       }
+
+      // Mask out lower left corner of block if window_size > 0.
+      // Only required if current block intersects with the lower left corner:
+      // block starts at x_lowerleft = iter_key_start; at y = query_start +
+      // kQueriesPerBlock, the first non-masked value at this y is x_first =
+      // query_start + kQueriesPerBlock - window_size; mask if x_first >
+      // x_lowerleft
+
+      if (p.window_size > 0 &&
+          (query_start + p.causal_diagonal_offset +
+               cutlass::fast_min(
+                   int32_t(kQueriesPerBlock), int32_t(p.num_queries)) -
+               p.window_size >
+           iter_key_start)) {
+        auto query_start = blockIdx.x * kQueriesPerBlock;
+        auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
+            my_lane_id, my_warp_id, iteratorC_tile_offset);
+        int32_t first_col;
+        const int32_t offset = query_start + p.causal_diagonal_offset -
+                               p.window_size - iter_key_start;
+        MM0::AccumLambdaIterator::iterateRows(
+            lane_offset,
+            [&](int accum_m) { first_col = accum_m + offset; },
+            [&](int accum_m, int accum_n, int idx) {
+              if (accum_n < first_col) {
+                accum[idx] =
+                    -cutlass::platform::numeric_limits<accum_t>::infinity();
+              }
+            },
+            [&](int accum_m) {});
+        // print_warp_accum(accum, lane_offset, 12,
+        // 12);
+      }
+
       // Update `mi` from accum stored in registers
       // Also does accum[i] <- exp(accum[i] - mi)
       iterative_softmax(
@@ -1036,9 +1084,18 @@ struct AttentionKernel {
     }
 
     if (!kKeepOutputInRF) {
+      int first_key = 0;
+      if (p.window_size > 0) {
+        first_key = (cutlass::fast_max(
+                         int(query_start + p.causal_diagonal_offset) -
+                             p.window_size + 1,
+                         0) /
+                     kKeysPerBlock) *
+                    kKeysPerBlock;
+      }
       MM1::Mma::drain_cp_asyncs();
       DISPATCH_BOOL(
-          iter_key_start == 0, kIsFirst, ([&] {
+          iter_key_start == first_key, kIsFirst, ([&] {
            DISPATCH_BOOL(
                (iter_key_start + kKeysPerBlock) >= p.num_keys,
                kIsLast,
diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h
index 9fe66c6fe992e..287413bf5acde 100644
--- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h
+++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h
@@ -14,42 +14,39 @@ namespace cuda {
 constexpr int kEfficientAttentionMaxHeadSize = 1024;
 
 struct MemoryEfficientAttentionParams {
-  int32_t sm;
-  bool is_half;
+  int32_t sm = 50;
+  bool is_half = false;
   bool is_kv_bsnh = true;
-  int32_t batch_size;
-  int32_t num_heads;
-  int32_t sequence_length;
-  int32_t kv_sequence_length;
-  int32_t max_sequence_length;
-  int32_t qk_head_size;
-  int32_t v_head_size;
-  bool causal;
-  bool use_smooth_softmax;
-
-  float scale;
+  int32_t batch_size = 0;
+  int32_t num_heads = 0;
+  int32_t sequence_length = 0;
+  int32_t kv_sequence_length = 0;
+  int32_t max_sequence_length = 0;
+  int32_t qk_head_size = 0;
+  int32_t v_head_size = 0;
+  int32_t local_window_size = -1;
+  bool causal = false;
+  bool use_smooth_softmax = false;
+  bool broadcast_attn_bias_dim_0 = false;
+  bool broadcast_attn_bias_dim_1 = false;
+  bool has_custom_right_padding = false;
+  float scale = 1.0f;
   float softcap = 0.0;
 
-  int32_t* seqstart_q_ptr;
-  int32_t* seqstart_k_ptr;
-  int32_t* seqlen_k_ptr;
-
-  const void* query;      // [B, S, N, H]
-  const void* key;        // [B, L, N, H], where L is kv_sequence_length
-  const void* value;      // [B, L, N, H_v]
-  const void* attn_bias;  // [B or 1, N or 1, S, L] or null
-  bool broadcast_attn_bias_dim_0;
-  bool broadcast_attn_bias_dim_1;
-
-  void* output;     // [B, S, N, H_v]
-  void* workspace;  // [B, S, N, H_v] when kNeedsOutputAccumulatorBuffer, nullptr otherwise
-  cudaStream_t stream;
+  cudaStream_t stream = nullptr;
+  const int32_t* seqstart_q_ptr = nullptr;  // [B + 1], cumulated sequence lengths of queries
+  const int32_t* seqstart_k_ptr = nullptr;  // [B + 1], cumulated sequence lengths of keys
+  const int32_t* seqlen_k_ptr = nullptr;    // [B], sequence lengths of keys
+  const void* query = nullptr;              // [B, S, N, H]
+  const void* key = nullptr;                // [B, L, N, H], where L is kv_sequence_length
+  const void* value = nullptr;              // [B, L, N, H_v]
+  const void* attn_bias = nullptr;          // [B or 1, N or 1, S, L] or null
+  void* workspace = nullptr;                // [B, S, N, H_v] when kNeedsOutputAccumulatorBuffer, nullptr otherwise
+  void* output = nullptr;                   // [B, S, N, H_v]
 
   static bool need_workspace(size_t v_head_size, bool is_float) {
     return (v_head_size > 128 && !is_float);
   }
-
-  bool has_custom_right_padding = false;
 };
 
 void run_memory_efficient_attention(const MemoryEfficientAttentionParams& params);
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
index 8b63b363d8863..9f1bc46ee297d 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -156,13 +156,8 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
   bool use_memory_efficient_attention =
       !use_flash_attention &&
       !disable_memory_efficient_attention_ &&
-      local_window_size_ == -1 &&
-      (sizeof(T) == 2 || parameters.sequence_length >= this->kernel_options_->MinSeqLenForEfficientAttentionFp32()) &&
       has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.head_size);
-  if (!use_flash_attention && !use_memory_efficient_attention && local_window_size_ != -1) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "Local attention UNSUPPORTED for sm < 80 on CUDA.");
-  }
+
   // allocate buffers
   size_t kv_buffer_bytes = 0;
   // need a buffer if we must ungroup kv
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
index dbbee87238d0c..2d1b49033003d 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
@@ -635,6 +635,7 @@ Status EfficientAttention(
   p.stream = stream;
   p.has_custom_right_padding = true;
   p.use_smooth_softmax = parameters.use_smooth_softmax;
+  p.local_window_size = parameters.local_window_size;
   run_memory_efficient_attention(p);
 
   DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size);
diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu
index f3b9fd310f46f..846d2be7bf2e1 100644
--- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu
@@ -698,8 +698,8 @@ Status FusedAttentionCutlass(
   p.scale = parameters.scale == 0.0f ?
1.f / sqrt(static_cast(qk_head_size)) : parameters.scale; p.seqlen_k_ptr = nullptr; - p.seqstart_q_ptr = const_cast(data.cumulative_sequence_length); - p.seqstart_k_ptr = const_cast(data.cumulative_sequence_length); + p.seqstart_q_ptr = data.cumulative_sequence_length; + p.seqstart_k_ptr = data.cumulative_sequence_length; p.query = data.no_qkv_workspace ? data.query : data.workspace; p.key = data.no_qkv_workspace ? data.key : (data.workspace + elements_qk); p.value = data.no_qkv_workspace ? data.value : (data.workspace + elements_qk + elements_qk); diff --git a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py index a74d5389e9047..0cf68cb191385 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py +++ b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py @@ -2103,18 +2103,20 @@ def gqa_no_past_memory_efficient_test_cases(): for sq, skv in seqs: for n, n2 in num_h: for h in h_sizes: - for rotary, rotary_interleaved in rotary_options_for_current_os(): - for packed in [False, True]: - for softcap in [0.0, 50.0]: - config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) - yield ( - str(config) + f"{rotary}_{rotary_interleaved}_{packed}", - config, - rotary, - rotary_interleaved, - packed, - softcap, - ) + for local in [False, True]: + for rotary, rotary_interleaved in rotary_options_for_current_os(): + for packed in [False, True]: + for softcap in [0.0, 50.0]: + config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) + yield ( + str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}", + config, + local, + rotary, + rotary_interleaved, + packed, + softcap, + ) def gqa_no_past_flash_attention_test_cases(): @@ -2146,7 +2148,7 @@ def gqa_no_past_flash_attention_test_cases(): for softcap in [0.0, 50.0]: config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) yield ( - str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}", + str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}_{softcap}", config, local, rotary, @@ -2183,19 +2185,21 @@ def gqa_past_memory_efficient_test_cases(): for s, s2 in seqs: for n, n2 in num_h: for h in h_sizes: - for rotary, rotary_interleaved in rotary_options_for_current_os(): - for packed in [False, True]: - for softcap in [0.0, 50.0]: - sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 - config = Config(b, s, s2, sp, n, n2, h) - yield ( - str(config) + f"{rotary}_{rotary_interleaved}_{packed}", - config, - rotary, - rotary_interleaved, - packed, - softcap, - ) + for local in [False, True]: + for rotary, rotary_interleaved in rotary_options_for_current_os(): + for packed in [False, True]: + for softcap in [0.0, 50.0]: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config(b, s, s2, sp, n, n2, h) + yield ( + str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}_{softcap}", + config, + local, + rotary, + rotary_interleaved, + packed, + softcap, + ) def gqa_past_flash_attention_test_cases(): @@ -2232,7 +2236,7 @@ def gqa_past_flash_attention_test_cases(): sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 config = Config(b, s, s2, sp, n, n2, h) yield ( - str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}", + str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}_{softcap}", config, local, rotary, @@ -2410,12 +2414,13 @@ def test_gqa_interactive_one_batch_flash_attention(self, _, config, local, rotar @unittest.skipIf(not has_memory_efficient(), reason="Memory efficient FMHA is not available, 
skipping tests.") class TestMemoryEfficientGQA(unittest.TestCase): @parameterized.expand(gqa_no_past_memory_efficient_test_cases()) - def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap): + def test_gqa_no_past_memory_efficient(self, _, config, local, rotary, rotary_interleaved, packed, softcap): os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" print("------- MEMORY EFFICIENT ATTENTION (PROMPT CASE) ---------") parity_check_gqa_prompt( config, + local=local, rtol=5e-3, atol=5e-3, past_format=Formats.BNSH, @@ -2427,6 +2432,7 @@ def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleave ) parity_check_gqa_prompt_no_buff( config, + local=local, rtol=5e-3, atol=5e-3, past_format=Formats.BNSH, @@ -2438,12 +2444,13 @@ def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleave ) @parameterized.expand(gqa_past_memory_efficient_test_cases()) - def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap): + def test_gqa_past_memory_efficient(self, _, config, local, rotary, rotary_interleaved, packed, softcap): os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" print("-------- MEMORY EFFICIENT (TOKEN GEN) --------") parity_check_gqa_past( config, + local=local, past_format=Formats.BNSH, rtol=1e-3, atol=1e-3, @@ -2455,6 +2462,7 @@ def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved, ) parity_check_gqa_past_no_buff( config, + local=local, past_format=Formats.BNSH, rtol=1e-3, atol=1e-3, From 1c2825b66a2fbd23d09f40816cbbdd4241cdc58b Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 17 Mar 2025 15:30:31 -0700 Subject: [PATCH 2/7] exclude some tests --- .../python/transformers/test_flash_attn_cuda.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py index 0cf68cb191385..40e34e51b44ab 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py +++ b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py @@ -2108,6 +2108,8 @@ def gqa_no_past_memory_efficient_test_cases(): for packed in [False, True]: for softcap in [0.0, 50.0]: config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) + if rotary and h % 16 > 0: + continue yield ( str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}", config, @@ -2146,6 +2148,9 @@ def gqa_no_past_flash_attention_test_cases(): for rotary, rotary_interleaved in rotary_options_for_current_os(): for packed in [False, True]: for softcap in [0.0, 50.0]: + if rotary and h % 16 > 0: + continue + config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) yield ( str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}_{softcap}", @@ -2189,6 +2194,9 @@ def gqa_past_memory_efficient_test_cases(): for rotary, rotary_interleaved in rotary_options_for_current_os(): for packed in [False, True]: for softcap in [0.0, 50.0]: + if rotary and h % 16 > 0: + continue + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 config = Config(b, s, s2, sp, n, n2, h) yield ( @@ -2233,6 +2241,9 @@ def gqa_past_flash_attention_test_cases(): for rotary, rotary_interleaved in rotary_options_for_current_os(): for packed in [False, True]: for softcap in [0.0, 50.0]: + if rotary and h % 16 > 0: + continue + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 config = Config(b, s, s2, sp, n, n2, h) yield ( @@ -2276,6 +2287,9 @@ def gqa_interactive_one_batch_flash_attention_test_cases(): 
for local in [False, True]: for rotary, rotary_interleaved in rotary_options_for_current_os(): for packed in [False, True]: + if rotary and h % 16 > 0: + continue + config = Config(b, s, s2, -1, n, n2, h) yield ( str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}", @@ -2316,6 +2330,9 @@ def gqa_interactive_one_batch_memory_efficient_attention_test_cases(): for h in h_sizes: for rotary, rotary_interleaved in rotary_options_for_current_os(): for packed in [False, True]: + if rotary and h % 16 > 0: + continue + config = Config(b, s, s2, -1, n, n2, h) yield ( str(config) + f"{rotary}_{rotary_interleaved}_{packed}", From 31656bfb7baf2a03076da94fe3fe8810e050b118 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 17 Mar 2025 17:17:43 -0700 Subject: [PATCH 3/7] rename python test files --- ...st_flash_attn_cuda.py => test_gqa_cuda.py} | 494 +---------------- ...st_flash_attn_rocm.py => test_gqa_rocm.py} | 3 +- .../transformers/test_mha_flash_attn.py | 510 ++++++++++++++++++ 3 files changed, 523 insertions(+), 484 deletions(-) rename onnxruntime/test/python/transformers/{test_flash_attn_cuda.py => test_gqa_cuda.py} (81%) rename onnxruntime/test/python/transformers/{test_flash_attn_rocm.py => test_gqa_rocm.py} (98%) create mode 100644 onnxruntime/test/python/transformers/test_mha_flash_attn.py diff --git a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py b/onnxruntime/test/python/transformers/test_gqa_cuda.py similarity index 81% rename from onnxruntime/test/python/transformers/test_flash_attn_cuda.py rename to onnxruntime/test/python/transformers/test_gqa_cuda.py index 40e34e51b44ab..2936432e72541 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py +++ b/onnxruntime/test/python/transformers/test_gqa_cuda.py @@ -17,7 +17,6 @@ import numpy import torch -from bert_padding import pad_input, unpad_input from einops import rearrange, repeat from onnx import TensorProto, helper from packaging import version @@ -39,20 +38,18 @@ class Formats: class Config: batch_size = 0 sequence_length = 0 - kv_sequence_length = 0 - past_sequence_length = 0 + kv_sequence_length = 0 # this is past sequence length when there is past state. 
num_heads = 0 kv_num_heads = 0 head_size = 0 ep = "CUDAExecutionProvider" def __init__( - self, batch_size, sequence_length, kv_sequence_length, past_sequence_length, num_heads, kv_num_heads, head_size + self, batch_size, sequence_length, kv_sequence_length, num_heads, kv_num_heads, head_size ): self.batch_size = batch_size self.sequence_length = sequence_length self.kv_sequence_length = kv_sequence_length - self.past_sequence_length = past_sequence_length self.num_heads = num_heads self.kv_num_heads = kv_num_heads self.head_size = head_size @@ -61,7 +58,7 @@ def __repr__(self): short_ep = self.ep[: -len("ExecutionProvider")].lower() return ( f"Config(batch_size={self.batch_size}, sequence_length={self.sequence_length}, " - f"kv_sequence_length={self.kv_sequence_length}, past_sequence_length={self.past_sequence_length}, " + f"kv_sequence_length={self.kv_sequence_length}, " f"num_heads={self.num_heads}, kv_num_heads={self.kv_num_heads}, head_size={self.head_size}, ep={short_ep})" ) @@ -103,118 +100,6 @@ def __repr__(self): ) -def create_packed_multihead_attention_graph(config): - nodes = [ - helper.make_node( - "PackedMultiHeadAttention", - [ - "query", - "", - "", - "", - "token_offset", - "cumulative_sequence_length", - ], - ["output"], - "PackedMultiHeadAttention_0", - num_heads=config.num_heads, - domain="com.microsoft", - ), - ] - - graph = helper.make_graph( - nodes, - "PackedMultiHeadAttention_Graph", - [ - helper.make_tensor_value_info( - "query", - TensorProto.FLOAT16, - [ - -1, - config.num_heads, - 3, - config.head_size, - ], - ), - helper.make_tensor_value_info( - "token_offset", TensorProto.INT32, [config.batch_size, config.sequence_length] - ), - helper.make_tensor_value_info("cumulative_sequence_length", TensorProto.INT32, [config.batch_size + 1]), - ], - [ - helper.make_tensor_value_info( - "output", - TensorProto.FLOAT16, - [-1, config.num_heads * config.head_size], - ), - ], - ) - - model = helper.make_model(graph) - return model.SerializeToString() - - -def create_multihead_attention_graph(config): - nodes = [ - helper.make_node( - "MultiHeadAttention", - [ - "query", - "key", - "value", - ], - ["output"], - "MultiHeadAttention_0", - num_heads=config.num_heads, - domain="com.microsoft", - ), - ] - - graph = helper.make_graph( - nodes, - "MultiHeadAttention_Graph", - [ - helper.make_tensor_value_info( - "query", - TensorProto.FLOAT16, - [ - config.batch_size, - config.sequence_length, - config.num_heads * config.head_size, - ], - ), - helper.make_tensor_value_info( - "key", - TensorProto.FLOAT16, - [ - config.batch_size, - config.kv_sequence_length, - config.num_heads * config.head_size, - ], - ), - helper.make_tensor_value_info( - "value", - TensorProto.FLOAT16, - [ - config.batch_size, - config.kv_sequence_length, - config.num_heads * config.head_size, - ], - ), - ], - [ - helper.make_tensor_value_info( - "output", - TensorProto.FLOAT16, - [config.batch_size, config.sequence_length, config.num_heads * config.head_size], - ), - ], - ) - - model = helper.make_model(graph) - return model.SerializeToString() - - def create_group_query_attention_graph_prompt( config, past_kv_format=Formats.BSNH, @@ -575,204 +460,6 @@ def create_group_query_attention_graph_past( return model.SerializeToString() -def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): - assert mode in ["full", "random", "third"] - if mode == "full": - lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) - elif mode == "random": - lengths = 
torch.randint(max(1, max_seqlen - 20), max_seqlen, (batch_size, 1), device=device) - else: - lengths = torch.randint(max_seqlen // 3, max_seqlen, (batch_size, 1), device=device) - padding_mask = repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths - return padding_mask - - -def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False): - """ - Arguments: - q: (batch_size, seqlen_q, nheads, d) - k: (batch_size, seqlen_k, nheads_k, d) - v: (batch_size, seqlen_k, nheads_k, d) - query_padding_mask: (batch_size, seqlen), bool - key_padding_mask: (batch_size, seqlen), bool - """ - assert not (kvpacked and qkvpacked) - batch_size, seqlen_q, nheads, d = q.shape - _, seqlen_k, nheads_k, _ = k.shape - assert k.shape == (batch_size, seqlen_k, nheads_k, d) - assert v.shape == (batch_size, seqlen_k, nheads_k, d) - - if query_padding_mask is not None: - q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask) - - def output_pad_fn(output_unpad): - return pad_input(output_unpad, indices_q, batch_size, seqlen_q) - - else: - q_unpad = rearrange(q, "b s h d -> (b s) h d") - cu_seqlens_q = torch.arange( - 0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device - ) - max_seqlen_q = seqlen_q - - def output_pad_fn(output_unpad): - return rearrange(output_unpad, "(b s) h d -> b s h d", b=batch_size) - - if key_padding_mask is not None: - k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) - v_unpad, _, _, _ = unpad_input(v, key_padding_mask) - else: - k_unpad = rearrange(k, "b s h d -> (b s) h d") - v_unpad = rearrange(v, "b s h d -> (b s) h d") - cu_seqlens_k = torch.arange( - 0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device - ) - max_seqlen_k = seqlen_k - - if qkvpacked: - assert (query_padding_mask == key_padding_mask).all() - assert nheads == nheads_k - qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) - qkv = torch.stack([q, k, v], dim=2) - if query_padding_mask is not None: - - def dqkv_pad_fn(dqkv_unpad): - return pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q) - - else: - - def dqkv_pad_fn(dqkv_unpad): - return rearrange(dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size) - - return ( - qkv_unpad.detach().requires_grad_(), - cu_seqlens_q, - max_seqlen_q, - qkv.detach().requires_grad_(), - output_pad_fn, - dqkv_pad_fn, - ) - elif kvpacked: - kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) - kv = torch.stack([k, v], dim=2) - dq_pad_fn = output_pad_fn - if key_padding_mask is not None: - - def dkv_pad_fn(dkv_unpad): - return pad_input(dkv_unpad, indices_k, batch_size, seqlen_k) - - else: - - def dkv_pad_fn(dkv_unpad): - return rearrange(dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size) - - return ( - q_unpad.detach().requires_grad_(), - kv_unpad.detach().requires_grad_(), - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - q.detach().requires_grad_(), - kv.detach().requires_grad_(), - output_pad_fn, - dq_pad_fn, - dkv_pad_fn, - ) - else: - dq_pad_fn = output_pad_fn - if key_padding_mask is not None: - - def dk_pad_fn(dk_unpad): - return pad_input(dk_unpad, indices_k, batch_size, seqlen_k) - - else: - - def dk_pad_fn(dk_unpad): - return rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size) - - return ( - q_unpad.detach().requires_grad_(), - k_unpad.detach().requires_grad_(), - v_unpad.detach().requires_grad_(), - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - 
max_seqlen_k, - q.detach().requires_grad_(), - k.detach().requires_grad_(), - v.detach().requires_grad_(), - output_pad_fn, - dq_pad_fn, - dk_pad_fn, - ) - - -def create_inputs(config: Config, kv_packed=False, qkv_packed=True): - qkv = torch.randn( - config.batch_size, - config.sequence_length, - 3, - config.num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - key_padding_mask = generate_random_padding_mask( - config.sequence_length, config.batch_size, device="cuda", mode="random" - ) - qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv( - *qkv.unbind(dim=2), key_padding_mask, key_padding_mask, kv_packed, qkv_packed - ) - return qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn, key_padding_mask - - -def generate_token_offset(cu_seqlens, max_seqlen): - token_offset = [] - token_padset = [] # These are the indices that contain padding tokens - for i in range(1, len(cu_seqlens)): - start = i - 1 - pre_seqlen = cu_seqlens[i - 1] - seqlen = cu_seqlens[i] - token_offset += range(start * max_seqlen, (start * max_seqlen) + (seqlen - pre_seqlen)) - token_padset += range((start * max_seqlen) + (seqlen - pre_seqlen), i * max_seqlen) - return numpy.asarray(token_offset + token_padset, dtype=numpy.int32) - - -def flash_attn_varlen_qkvpacked_func(qkv_unpad, cu_seqlens, token_offset, config, causal=False): - onnx_model_str = create_packed_multihead_attention_graph(config) - qkv_unpad = torch.swapdims(qkv_unpad, 1, 2) - ort_inputs = { - "query": qkv_unpad.detach().cpu().numpy(), - "token_offset": token_offset, - "cumulative_sequence_length": cu_seqlens.cpu().numpy(), - } - sess_options = SessionOptions() - ort_session = InferenceSession(onnx_model_str, sess_options, providers=[config.ep]) - ort_output = ort_session.run(None, ort_inputs) - output = torch.tensor(ort_output) - return output - - -def mha_func(q, k, v, config): - onnx_model_str = create_multihead_attention_graph(config) - q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) - k = torch.reshape(k, (config.batch_size, config.kv_sequence_length, -1)) - v = torch.reshape(v, (config.batch_size, config.kv_sequence_length, -1)) - ort_inputs = { - "query": q.detach().cpu().numpy(), - "key": k.detach().cpu().numpy(), - "value": v.detach().cpu().numpy(), - } - sess_options = SessionOptions() - ort_session = InferenceSession(onnx_model_str, sess_options, providers=[config.ep]) - ort_output = ort_session.run(None, ort_inputs) - ort_output = numpy.array(ort_output) - output = torch.tensor(ort_output) - return output - - def rotary_options_for_current_os(): # Reference implementation of rotary uses triton, which is not available in Windows. # So we only test rotary in Linux right now. 
@@ -1009,14 +696,6 @@ def gqa_past_func( return output, present_k, present_v -def construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, device=None): - row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") - col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) - sk = seqlen_k if key_padding_mask is None else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") - sq = seqlen_q if query_padding_mask is None else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") - return col_idx > row_idx + sk - sq - - def construct_local_mask( seqlen_q, seqlen_k, @@ -1127,93 +806,6 @@ def attention_ref( return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) -def attention_qkvpacked_ref( - qkv, - key_padding_mask=None, - dropout_p=0.0, - dropout_mask=None, - causal=False, - upcast=True, - reorder_ops=False, - use_smooth_softmax=False, -): - return attention_ref( - qkv[:, :, 0], - qkv[:, :, 1], - qkv[:, :, 2], - key_padding_mask, - key_padding_mask, - dropout_p, - dropout_mask, - upcast=upcast, - causal=causal, - reorder_ops=reorder_ops, - use_smooth_softmax=use_smooth_softmax, - ) - - -def parity_check_mha( - config, - packed, - rtol=1e-3, - atol=1e-3, -): - if packed: - qkv_unpad, cu_seqlens, _, qkv, output_pad_fn, _, key_padding_mask = create_inputs(config) - token_offset = generate_token_offset(cu_seqlens, config.sequence_length).reshape( - (config.batch_size, config.sequence_length) - ) - # ORT Flash - out_unpad = flash_attn_varlen_qkvpacked_func(qkv_unpad, cu_seqlens, token_offset, config, causal=False) - out_unpad = torch.squeeze(out_unpad, 0) - out = torch.reshape( - output_pad_fn(out_unpad), (config.batch_size, config.sequence_length, config.num_heads, config.head_size) - ) - out = out.detach().cpu().numpy() - # Pytorch to compare - out_ref, _ = attention_qkvpacked_ref(qkv, key_padding_mask, 0.0, None, causal=False) - out_ref = out_ref.detach().cpu().numpy() - else: - q = torch.randn( - config.batch_size, - config.sequence_length, - config.num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - k = torch.randn( - config.batch_size, - config.kv_sequence_length, - config.kv_num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - v = torch.randn( - config.batch_size, - config.kv_sequence_length, - config.kv_num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - out = mha_func(q, k, v, config) - out = torch.squeeze(out, 0) - out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) - out = out.detach().cpu().numpy() - # Pytorch to compare - out_ref, _ = attention_ref(q, k, v, None, None, 0.0, None, causal=False) - out_ref = out_ref.detach().cpu().numpy() - - numpy.testing.assert_allclose( - out, out_ref, rtol=rtol, atol=atol, equal_nan=True, err_msg=f" with {config} packed={packed}" - ) - - def rotary_embedding(*args, **kwargs): # Use local import since triton is not available in Windows. 
from rotary_flash import apply_rotary_emb @@ -1222,7 +814,7 @@ def rotary_embedding(*args, **kwargs): def parity_check_gqa_prompt( - config, + config: PromptConfig, causal=True, local=False, past_format=Formats.BNSH, @@ -1420,7 +1012,7 @@ def parity_check_gqa_prompt( def parity_check_gqa_prompt_no_buff( - config, + config: PromptConfig, causal=True, local=False, past_format=Formats.BNSH, @@ -1595,7 +1187,7 @@ def parity_check_gqa_prompt_no_buff( def parity_check_gqa_past( - config, + config: Config, causal=True, local=False, past_format=Formats.BNSH, @@ -1788,7 +1380,7 @@ def parity_check_gqa_past( def parity_check_gqa_past_no_buff( - config, + config: Config, causal=True, local=False, past_format=Formats.BNSH, @@ -2019,67 +1611,6 @@ def has_memory_efficient(): return True -def packed_mha_test_cases(): - batches = [2] if pipeline_mode else [1, 5] - seqs = [1024, 1025] if pipeline_mode else [1024, 1025, 2048] - num_h = [1, 3] if pipeline_mode else [1, 6, 16] - h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] - - for b in batches: - for s in seqs: - for n in num_h: - for h in h_sizes: - config = Config(b, s, s, 0, n, n, h) - yield str(config), config - - -def mha_test_cases(): - batches = [2] if pipeline_mode else [1, 5] - seqs = ( - [(1, 128), (113, 211), (2048, 2048)] - if pipeline_mode - else [ - (113, 203), - (128, 217), - (113, 211), - (108, 256), - (256, 512), - (512, 256), - (1024, 1024), - (1023, 1024), - (1024, 1023), - (2048, 2048), - ] - ) - num_h = [3] if pipeline_mode else [1, 6, 16] - h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] - - for b in batches: - for s, s2 in seqs: - for n in num_h: - for h in h_sizes: - config = Config(b, s, s2, 0, n, n, h) - yield str(config), config - - -class TestMHA(unittest.TestCase): - @parameterized.expand(packed_mha_test_cases()) - def test_packed_mha(self, _, config): - if not has_flash_attention(): - return - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" - print("-------- TEST PACKED MHA ---------") - parity_check_mha(config, True) - - @parameterized.expand(mha_test_cases()) - def test_mha(self, _, config): - if not has_flash_attention(): - return - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" - print("-------- TEST MHA ---------") - parity_check_mha(config, False) - - def gqa_no_past_memory_efficient_test_cases(): batches = [3] if pipeline_mode else [1, 3, 5] seqs = ( @@ -2196,9 +1727,7 @@ def gqa_past_memory_efficient_test_cases(): for softcap in [0.0, 50.0]: if rotary and h % 16 > 0: continue - - sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 - config = Config(b, s, s2, sp, n, n2, h) + config = Config(b, s, s2, n, n2, h) yield ( str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}_{softcap}", config, @@ -2244,8 +1773,7 @@ def gqa_past_flash_attention_test_cases(): if rotary and h % 16 > 0: continue - sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 - config = Config(b, s, s2, sp, n, n2, h) + config = Config(b, s, s2, n, n2, h) yield ( str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}_{softcap}", config, @@ -2290,7 +1818,7 @@ def gqa_interactive_one_batch_flash_attention_test_cases(): if rotary and h % 16 > 0: continue - config = Config(b, s, s2, -1, n, n2, h) + config = Config(b, s, s2, n, n2, h) yield ( str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}", config, @@ -2333,7 +1861,7 @@ def gqa_interactive_one_batch_memory_efficient_attention_test_cases(): if rotary and h % 16 > 0: continue - config = Config(b, s, s2, 
-1, n, n2, h) + config = Config(b, s, s2, n, n2, h) yield ( str(config) + f"{rotary}_{rotary_interleaved}_{packed}", config, diff --git a/onnxruntime/test/python/transformers/test_flash_attn_rocm.py b/onnxruntime/test/python/transformers/test_gqa_rocm.py similarity index 98% rename from onnxruntime/test/python/transformers/test_flash_attn_rocm.py rename to onnxruntime/test/python/transformers/test_gqa_rocm.py index a5910c28c2975..29ae1b6e44a78 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn_rocm.py +++ b/onnxruntime/test/python/transformers/test_gqa_rocm.py @@ -3,7 +3,7 @@ import torch from parameterized import parameterized -from test_flash_attn_cuda import ( +from test_gqa_cuda import ( Formats, gqa_no_past_flash_attention_test_cases, gqa_past_flash_attention_test_cases, @@ -38,6 +38,7 @@ def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_inte rtol=0.001, atol=0.005, ) + parity_check_gqa_prompt_no_buff( config, local=local, diff --git a/onnxruntime/test/python/transformers/test_mha_flash_attn.py b/onnxruntime/test/python/transformers/test_mha_flash_attn.py new file mode 100644 index 0000000000000..cf73a7b810942 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_mha_flash_attn.py @@ -0,0 +1,510 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +import unittest +import os +from parameterized import parameterized +import numpy +import torch +from bert_padding import pad_input, unpad_input +from einops import rearrange, repeat +from onnx import TensorProto, helper +from test_gqa_cuda import attention_ref, has_flash_attention +from onnxruntime import InferenceSession, SessionOptions + +torch.manual_seed(0) + +pipeline_mode = True # Reduces number of tests so pipeline doesn't time out + +class Formats: + BSNH = 0 + BNSH = 1 + +class Config: + batch_size = 0 + sequence_length = 0 + kv_sequence_length = 0 # this is past sequence length when there is past state. 
+ num_heads = 0 + kv_num_heads = 0 + head_size = 0 + ep = "CUDAExecutionProvider" + + def __init__( + self, batch_size, sequence_length, kv_sequence_length, num_heads, kv_num_heads, head_size + ): + self.batch_size = batch_size + self.sequence_length = sequence_length + self.kv_sequence_length = kv_sequence_length + self.num_heads = num_heads + self.kv_num_heads = kv_num_heads + self.head_size = head_size + + def __repr__(self): + short_ep = self.ep[: -len("ExecutionProvider")].lower() + return ( + f"Config(batch_size={self.batch_size}, sequence_length={self.sequence_length}, " + f"kv_sequence_length={self.kv_sequence_length}, " + f"num_heads={self.num_heads}, kv_num_heads={self.kv_num_heads}, head_size={self.head_size}, ep={short_ep})" + ) + + +def create_packed_multihead_attention_graph(config:Config): + nodes = [ + helper.make_node( + "PackedMultiHeadAttention", + [ + "query", + "", + "", + "", + "token_offset", + "cumulative_sequence_length", + ], + ["output"], + "PackedMultiHeadAttention_0", + num_heads=config.num_heads, + domain="com.microsoft", + ), + ] + + graph = helper.make_graph( + nodes, + "PackedMultiHeadAttention_Graph", + [ + helper.make_tensor_value_info( + "query", + TensorProto.FLOAT16, + [ + -1, + config.num_heads, + 3, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "token_offset", TensorProto.INT32, [config.batch_size, config.sequence_length] + ), + helper.make_tensor_value_info("cumulative_sequence_length", TensorProto.INT32, [config.batch_size + 1]), + ], + [ + helper.make_tensor_value_info( + "output", + TensorProto.FLOAT16, + [-1, config.num_heads * config.head_size], + ), + ], + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +def create_multihead_attention_graph(config:Config): + nodes = [ + helper.make_node( + "MultiHeadAttention", + [ + "query", + "key", + "value", + ], + ["output"], + "MultiHeadAttention_0", + num_heads=config.num_heads, + domain="com.microsoft", + ), + ] + + graph = helper.make_graph( + nodes, + "MultiHeadAttention_Graph", + [ + helper.make_tensor_value_info( + "query", + TensorProto.FLOAT16, + [ + config.batch_size, + config.sequence_length, + config.num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "key", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length, + config.num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "value", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length, + config.num_heads * config.head_size, + ], + ), + ], + [ + helper.make_tensor_value_info( + "output", + TensorProto.FLOAT16, + [config.batch_size, config.sequence_length, config.num_heads * config.head_size], + ), + ], + ) + + model = helper.make_model(graph) + return model.SerializeToString() + +def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): + assert mode in ["full", "random", "third"] + if mode == "full": + lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) + elif mode == "random": + lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen, (batch_size, 1), device=device) + else: + lengths = torch.randint(max_seqlen // 3, max_seqlen, (batch_size, 1), device=device) + padding_mask = repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths + return padding_mask + + +def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, d) + 
k: (batch_size, seqlen_k, nheads_k, d) + v: (batch_size, seqlen_k, nheads_k, d) + query_padding_mask: (batch_size, seqlen), bool + key_padding_mask: (batch_size, seqlen), bool + """ + assert not (kvpacked and qkvpacked) + batch_size, seqlen_q, nheads, d = q.shape + _, seqlen_k, nheads_k, _ = k.shape + assert k.shape == (batch_size, seqlen_k, nheads_k, d) + assert v.shape == (batch_size, seqlen_k, nheads_k, d) + + if query_padding_mask is not None: + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask) + + def output_pad_fn(output_unpad): + return pad_input(output_unpad, indices_q, batch_size, seqlen_q) + + else: + q_unpad = rearrange(q, "b s h d -> (b s) h d") + cu_seqlens_q = torch.arange( + 0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device + ) + max_seqlen_q = seqlen_q + + def output_pad_fn(output_unpad): + return rearrange(output_unpad, "(b s) h d -> b s h d", b=batch_size) + + if key_padding_mask is not None: + k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) + v_unpad, _, _, _ = unpad_input(v, key_padding_mask) + else: + k_unpad = rearrange(k, "b s h d -> (b s) h d") + v_unpad = rearrange(v, "b s h d -> (b s) h d") + cu_seqlens_k = torch.arange( + 0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device + ) + max_seqlen_k = seqlen_k + + if qkvpacked: + assert (query_padding_mask == key_padding_mask).all() + assert nheads == nheads_k + qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) + qkv = torch.stack([q, k, v], dim=2) + if query_padding_mask is not None: + + def dqkv_pad_fn(dqkv_unpad): + return pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q) + + else: + + def dqkv_pad_fn(dqkv_unpad): + return rearrange(dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size) + + return ( + qkv_unpad.detach().requires_grad_(), + cu_seqlens_q, + max_seqlen_q, + qkv.detach().requires_grad_(), + output_pad_fn, + dqkv_pad_fn, + ) + elif kvpacked: + kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) + kv = torch.stack([k, v], dim=2) + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + + def dkv_pad_fn(dkv_unpad): + return pad_input(dkv_unpad, indices_k, batch_size, seqlen_k) + + else: + + def dkv_pad_fn(dkv_unpad): + return rearrange(dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size) + + return ( + q_unpad.detach().requires_grad_(), + kv_unpad.detach().requires_grad_(), + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + kv.detach().requires_grad_(), + output_pad_fn, + dq_pad_fn, + dkv_pad_fn, + ) + else: + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + + def dk_pad_fn(dk_unpad): + return pad_input(dk_unpad, indices_k, batch_size, seqlen_k) + + else: + + def dk_pad_fn(dk_unpad): + return rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size) + + return ( + q_unpad.detach().requires_grad_(), + k_unpad.detach().requires_grad_(), + v_unpad.detach().requires_grad_(), + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + k.detach().requires_grad_(), + v.detach().requires_grad_(), + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) + + +def create_inputs(config: Config, kv_packed=False, qkv_packed=True): + qkv = torch.randn( + config.batch_size, + config.sequence_length, + 3, + config.num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + key_padding_mask = generate_random_padding_mask( + 
config.sequence_length, config.batch_size, device="cuda", mode="random" + ) + qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv( + *qkv.unbind(dim=2), key_padding_mask, key_padding_mask, kv_packed, qkv_packed + ) + return qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn, key_padding_mask + + +def generate_token_offset(cu_seqlens, max_seqlen): + token_offset = [] + token_padset = [] # These are the indices that contain padding tokens + for i in range(1, len(cu_seqlens)): + start = i - 1 + pre_seqlen = cu_seqlens[i - 1] + seqlen = cu_seqlens[i] + token_offset += range(start * max_seqlen, (start * max_seqlen) + (seqlen - pre_seqlen)) + token_padset += range((start * max_seqlen) + (seqlen - pre_seqlen), i * max_seqlen) + return numpy.asarray(token_offset + token_padset, dtype=numpy.int32) + + +def flash_attn_varlen_qkvpacked_func(qkv_unpad, cu_seqlens, token_offset, config): + onnx_model_str = create_packed_multihead_attention_graph(config) + qkv_unpad = torch.swapdims(qkv_unpad, 1, 2) + ort_inputs = { + "query": qkv_unpad.detach().cpu().numpy(), + "token_offset": token_offset, + "cumulative_sequence_length": cu_seqlens.cpu().numpy(), + } + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=[config.ep]) + ort_output = ort_session.run(None, ort_inputs) + output = torch.tensor(ort_output) + return output + + +def mha_func(q, k, v, config): + onnx_model_str = create_multihead_attention_graph(config) + q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) + k = torch.reshape(k, (config.batch_size, config.kv_sequence_length, -1)) + v = torch.reshape(v, (config.batch_size, config.kv_sequence_length, -1)) + ort_inputs = { + "query": q.detach().cpu().numpy(), + "key": k.detach().cpu().numpy(), + "value": v.detach().cpu().numpy(), + } + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=[config.ep]) + ort_output = ort_session.run(None, ort_inputs) + ort_output = numpy.array(ort_output) + output = torch.tensor(ort_output) + return output + + +def attention_qkvpacked_ref( + qkv, + key_padding_mask=None, + dropout_p=0.0, + dropout_mask=None, + causal=False, + upcast=True, + reorder_ops=False, + use_smooth_softmax=False, +): + return attention_ref( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + key_padding_mask, + key_padding_mask, + dropout_p, + dropout_mask, + upcast=upcast, + causal=causal, + reorder_ops=reorder_ops, + use_smooth_softmax=use_smooth_softmax, + ) + + +def parity_check_mha( + config, + packed, + rtol=1e-3, + atol=1e-3, +): + if packed: + qkv_unpad, cu_seqlens, _, qkv, output_pad_fn, _, key_padding_mask = create_inputs(config) + token_offset = generate_token_offset(cu_seqlens, config.sequence_length).reshape( + (config.batch_size, config.sequence_length) + ) + # ORT Flash + out_unpad = flash_attn_varlen_qkvpacked_func(qkv_unpad, cu_seqlens, token_offset, config) + out_unpad = torch.squeeze(out_unpad, 0) + out = torch.reshape( + output_pad_fn(out_unpad), (config.batch_size, config.sequence_length, config.num_heads, config.head_size) + ) + out = out.detach().cpu().numpy() + # Pytorch to compare + out_ref, _ = attention_qkvpacked_ref(qkv, key_padding_mask, 0.0, None, causal=False) + out_ref = out_ref.detach().cpu().numpy() + else: + q = torch.randn( + config.batch_size, + config.sequence_length, + config.num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + k = torch.randn( 
+ config.batch_size, + config.kv_sequence_length, + config.kv_num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + v = torch.randn( + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + out = mha_func(q, k, v, config) + out = torch.squeeze(out, 0) + out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) + out = out.detach().cpu().numpy() + # Pytorch to compare + out_ref, _ = attention_ref(q, k, v, None, None, 0.0, None, causal=False) + out_ref = out_ref.detach().cpu().numpy() + + numpy.testing.assert_allclose( + out, out_ref, rtol=rtol, atol=atol, equal_nan=True, err_msg=f" with {config} packed={packed}" + ) + + +def packed_mha_test_cases(): + batches = [2] if pipeline_mode else [1, 5] + seqs = [1024, 1025] if pipeline_mode else [1024, 1025, 2048] + num_h = [1, 3] if pipeline_mode else [1, 6, 16] + h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + + for b in batches: + for s in seqs: + for n in num_h: + for h in h_sizes: + config = Config(b, s, s, n, n, h) + yield str(config), config + + +def mha_test_cases(): + batches = [2] if pipeline_mode else [1, 5] + seqs = ( + [(1, 128), (113, 211), (2048, 2048)] + if pipeline_mode + else [ + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (2048, 2048), + ] + ) + num_h = [3] if pipeline_mode else [1, 6, 16] + h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + + for b in batches: + for s, s2 in seqs: + for n in num_h: + for h in h_sizes: + config = Config(b, s, s2, n, n, h) + yield str(config), config + + +class TestMHA(unittest.TestCase): + @parameterized.expand(packed_mha_test_cases()) + def test_packed_mha(self, _, config): + if not has_flash_attention(): + return + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" + print("-------- TEST PACKED MHA ---------") + parity_check_mha(config, True) + + @parameterized.expand(mha_test_cases()) + def test_mha(self, _, config): + if not has_flash_attention(): + return + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" + print("-------- TEST MHA ---------") + parity_check_mha(config, False) + +if __name__ == "__main__": + unittest.main() From 25ee73ff37335de89f2438470b353efc5eae9bf2 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 18 Mar 2025 22:14:08 +0000 Subject: [PATCH 4/7] format --- .../test/python/transformers/test_gqa_cuda.py | 4 +--- .../python/transformers/test_mha_flash_attn.py | 18 +++++++++++------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/onnxruntime/test/python/transformers/test_gqa_cuda.py b/onnxruntime/test/python/transformers/test_gqa_cuda.py index 2936432e72541..3923b229a0bff 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cuda.py +++ b/onnxruntime/test/python/transformers/test_gqa_cuda.py @@ -44,9 +44,7 @@ class Config: head_size = 0 ep = "CUDAExecutionProvider" - def __init__( - self, batch_size, sequence_length, kv_sequence_length, num_heads, kv_num_heads, head_size - ): + def __init__(self, batch_size, sequence_length, kv_sequence_length, num_heads, kv_num_heads, head_size): self.batch_size = batch_size self.sequence_length = sequence_length self.kv_sequence_length = kv_sequence_length diff --git a/onnxruntime/test/python/transformers/test_mha_flash_attn.py 
b/onnxruntime/test/python/transformers/test_mha_flash_attn.py index cf73a7b810942..cdc0bcb04dba7 100644 --- a/onnxruntime/test/python/transformers/test_mha_flash_attn.py +++ b/onnxruntime/test/python/transformers/test_mha_flash_attn.py @@ -3,25 +3,29 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # ------------------------------------------------------------------------- -import unittest import os -from parameterized import parameterized +import unittest + import numpy import torch from bert_padding import pad_input, unpad_input from einops import rearrange, repeat from onnx import TensorProto, helper +from parameterized import parameterized from test_gqa_cuda import attention_ref, has_flash_attention + from onnxruntime import InferenceSession, SessionOptions torch.manual_seed(0) pipeline_mode = True # Reduces number of tests so pipeline doesn't time out + class Formats: BSNH = 0 BNSH = 1 + class Config: batch_size = 0 sequence_length = 0 @@ -31,9 +35,7 @@ class Config: head_size = 0 ep = "CUDAExecutionProvider" - def __init__( - self, batch_size, sequence_length, kv_sequence_length, num_heads, kv_num_heads, head_size - ): + def __init__(self, batch_size, sequence_length, kv_sequence_length, num_heads, kv_num_heads, head_size): self.batch_size = batch_size self.sequence_length = sequence_length self.kv_sequence_length = kv_sequence_length @@ -50,7 +52,7 @@ def __repr__(self): ) -def create_packed_multihead_attention_graph(config:Config): +def create_packed_multihead_attention_graph(config: Config): nodes = [ helper.make_node( "PackedMultiHeadAttention", @@ -101,7 +103,7 @@ def create_packed_multihead_attention_graph(config:Config): return model.SerializeToString() -def create_multihead_attention_graph(config:Config): +def create_multihead_attention_graph(config: Config): nodes = [ helper.make_node( "MultiHeadAttention", @@ -161,6 +163,7 @@ def create_multihead_attention_graph(config:Config): model = helper.make_model(graph) return model.SerializeToString() + def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): assert mode in ["full", "random", "third"] if mode == "full": @@ -506,5 +509,6 @@ def test_mha(self, _, config): print("-------- TEST MHA ---------") parity_check_mha(config, False) + if __name__ == "__main__": unittest.main() From eab2202e609ecba2e8c16bae28a28c61a0293bb1 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 18 Mar 2025 16:04:44 -0700 Subject: [PATCH 5/7] clean up --- .../transformers/test_mha_flash_attn.py | 148 +++++------------- 1 file changed, 43 insertions(+), 105 deletions(-) diff --git a/onnxruntime/test/python/transformers/test_mha_flash_attn.py b/onnxruntime/test/python/transformers/test_mha_flash_attn.py index cdc0bcb04dba7..f87370e37d21a 100644 --- a/onnxruntime/test/python/transformers/test_mha_flash_attn.py +++ b/onnxruntime/test/python/transformers/test_mha_flash_attn.py @@ -176,7 +176,7 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): return padding_mask -def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False): +def generate_packed_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None): """ Arguments: q: (batch_size, seqlen_q, nheads, d) @@ -185,7 +185,6 @@ def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None, kvpack query_padding_mask: (batch_size, seqlen), bool key_padding_mask: (batch_size, seqlen), bool """ - assert not (kvpacked and qkvpacked) batch_size, 
seqlen_q, nheads, d = q.shape _, seqlen_k, nheads_k, _ = k.shape assert k.shape == (batch_size, seqlen_k, nheads_k, d) @@ -208,96 +207,37 @@ def output_pad_fn(output_unpad): return rearrange(output_unpad, "(b s) h d -> b s h d", b=batch_size) if key_padding_mask is not None: - k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) + k_unpad, _, _, _ = unpad_input(k, key_padding_mask) v_unpad, _, _, _ = unpad_input(v, key_padding_mask) else: k_unpad = rearrange(k, "b s h d -> (b s) h d") v_unpad = rearrange(v, "b s h d -> (b s) h d") - cu_seqlens_k = torch.arange( - 0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device - ) - max_seqlen_k = seqlen_k - - if qkvpacked: - assert (query_padding_mask == key_padding_mask).all() - assert nheads == nheads_k - qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) - qkv = torch.stack([q, k, v], dim=2) - if query_padding_mask is not None: - - def dqkv_pad_fn(dqkv_unpad): - return pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q) - - else: - - def dqkv_pad_fn(dqkv_unpad): - return rearrange(dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size) - - return ( - qkv_unpad.detach().requires_grad_(), - cu_seqlens_q, - max_seqlen_q, - qkv.detach().requires_grad_(), - output_pad_fn, - dqkv_pad_fn, - ) - elif kvpacked: - kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) - kv = torch.stack([k, v], dim=2) - dq_pad_fn = output_pad_fn - if key_padding_mask is not None: - - def dkv_pad_fn(dkv_unpad): - return pad_input(dkv_unpad, indices_k, batch_size, seqlen_k) - else: + assert (query_padding_mask == key_padding_mask).all() + assert nheads == nheads_k + qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) + qkv = torch.stack([q, k, v], dim=2) + if query_padding_mask is not None: - def dkv_pad_fn(dkv_unpad): - return rearrange(dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size) + def dqkv_pad_fn(dqkv_unpad): + return pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q) - return ( - q_unpad.detach().requires_grad_(), - kv_unpad.detach().requires_grad_(), - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - q.detach().requires_grad_(), - kv.detach().requires_grad_(), - output_pad_fn, - dq_pad_fn, - dkv_pad_fn, - ) else: - dq_pad_fn = output_pad_fn - if key_padding_mask is not None: - - def dk_pad_fn(dk_unpad): - return pad_input(dk_unpad, indices_k, batch_size, seqlen_k) - - else: - def dk_pad_fn(dk_unpad): - return rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size) + def dqkv_pad_fn(dqkv_unpad): + return rearrange(dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size) - return ( - q_unpad.detach().requires_grad_(), - k_unpad.detach().requires_grad_(), - v_unpad.detach().requires_grad_(), - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - q.detach().requires_grad_(), - k.detach().requires_grad_(), - v.detach().requires_grad_(), - output_pad_fn, - dq_pad_fn, - dk_pad_fn, - ) + return ( + qkv_unpad.detach().requires_grad_(), + cu_seqlens_q, + max_seqlen_q, + qkv.detach().requires_grad_(), + output_pad_fn, + dqkv_pad_fn, + ) -def create_inputs(config: Config, kv_packed=False, qkv_packed=True): +def create_inputs(config: Config): qkv = torch.randn( config.batch_size, config.sequence_length, @@ -308,13 +248,11 @@ def create_inputs(config: Config, kv_packed=False, qkv_packed=True): dtype=torch.float16, requires_grad=False, ) - key_padding_mask = generate_random_padding_mask( - config.sequence_length, config.batch_size, device="cuda", mode="random" - ) - 
qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv( - *qkv.unbind(dim=2), key_padding_mask, key_padding_mask, kv_packed, qkv_packed + padding_mask = generate_random_padding_mask(config.sequence_length, config.batch_size, device="cuda", mode="random") + qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_packed_qkv( + *qkv.unbind(dim=2), padding_mask, padding_mask ) - return qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn, key_padding_mask + return qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn, padding_mask def generate_token_offset(cu_seqlens, max_seqlen): @@ -450,22 +388,22 @@ def parity_check_mha( def packed_mha_test_cases(): - batches = [2] if pipeline_mode else [1, 5] - seqs = [1024, 1025] if pipeline_mode else [1024, 1025, 2048] - num_h = [1, 3] if pipeline_mode else [1, 6, 16] - h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] - - for b in batches: - for s in seqs: - for n in num_h: - for h in h_sizes: + batch_sizes = [2] if pipeline_mode else [1, 5] + sequence_lengths = [1024, 1025] if pipeline_mode else [1024, 1025, 2048] + num_heads = [1, 3] if pipeline_mode else [1, 6, 16] + head_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + + for b in batch_sizes: + for s in sequence_lengths: + for n in num_heads: + for h in head_sizes: config = Config(b, s, s, n, n, h) yield str(config), config def mha_test_cases(): - batches = [2] if pipeline_mode else [1, 5] - seqs = ( + batch_sizes = [2] if pipeline_mode else [1, 5] + sequence_lengths = ( [(1, 128), (113, 211), (2048, 2048)] if pipeline_mode else [ @@ -481,14 +419,14 @@ def mha_test_cases(): (2048, 2048), ] ) - num_h = [3] if pipeline_mode else [1, 6, 16] - h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] - - for b in batches: - for s, s2 in seqs: - for n in num_h: - for h in h_sizes: - config = Config(b, s, s2, n, n, h) + num_heads = [3] if pipeline_mode else [1, 6, 16] + head_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + + for b in batch_sizes: + for s, kv_sequence_length in sequence_lengths: + for n in num_heads: + for h in head_sizes: + config = Config(b, s, kv_sequence_length, n, n, h) yield str(config), config From 7c9ce52270b31507c8b7fe93ebc499aba11d7ff1 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 19 Mar 2025 08:21:37 +0000 Subject: [PATCH 6/7] Fix local_window_size=0 --- onnxruntime/contrib_ops/cpu/bert/attention_parameters.h | 6 +++--- onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h | 3 ++- .../cuda/bert/cutlass_fmha/fmha_launch_template.h | 4 +++- .../contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h | 4 ++-- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h index 417865bb166ec..c3d5128948c6f 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h @@ -87,9 +87,9 @@ struct GroupQueryAttentionParameters : AttentionParameters { int seqlen_present_kv_cache; // sequence length of present kv tensor int kv_hidden_size; int kv_num_heads; - int num_splits; // number of splits for splitkv - int rotary_dim; // rotary embedding dimension - int local_window_size; + int num_splits; // number of splits for splitkv + int rotary_dim; // rotary embedding dimension + int local_window_size; 
// The window size excludes current token. It only includes tokens on the left side. bool kv_share_buffer; bool is_packed_qkv; bool is_subsequent_prompt; // indicates whether we have past context and seqlen > 1 diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index c8c66c880852f..c79508cbae273 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -270,7 +270,8 @@ class GQAAttentionBase { for (size_t seq = 0; seq < sequence_length; seq++) { size_t seq_causal_length = past_seqlen + seq + 1; - const bool should_apply_local_window = local_window_size_ > 0 && + // local_window_size does not include the current query token, while window_size includes it. + const bool should_apply_local_window = local_window_size_ >= 0 && seq_causal_length > static_cast(local_window_size_) + 1; const size_t start_offset = should_apply_local_window ? seq_causal_length - local_window_size_ - 1 : 0; diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h index 100ab0e0a2fdc..8d8f735e3ed34 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h @@ -222,7 +222,9 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) { } p.use_smooth_softmax = params.use_smooth_softmax; - p.window_size = params.local_window_size; + + // local_windows_size in GQA does not include current query token, while windows_size in this kernel includes it. + p.window_size = params.local_window_size + 1; } auto kernel_fn = attention_kernel_batched_impl; diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h index ca0b3a0fddfe6..4b2a527f148ac 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h @@ -920,7 +920,7 @@ struct AttentionKernel { (query_start + p.causal_diagonal_offset + cutlass::fast_min( int32_t(kQueriesPerBlock), int32_t(p.num_queries)) - - p.window_size > + p.window_size >= iter_key_start)) { auto query_start = blockIdx.x * kQueriesPerBlock; auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset( @@ -932,7 +932,7 @@ struct AttentionKernel { lane_offset, [&](int accum_m) { first_col = accum_m + offset; }, [&](int accum_m, int accum_n, int idx) { - if (accum_n < first_col) { + if (accum_n <= first_col) { accum[idx] = -cutlass::platform::numeric_limits::infinity(); } From dbce7241ccfac1eeaec7796df55d0a3c03aa5bce Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 19 Mar 2025 18:31:55 +0000 Subject: [PATCH 7/7] review feedback --- .../contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h index 4b2a527f148ac..f35d6c2e6c8dc 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h @@ -736,7 +736,7 @@ struct AttentionKernel { if (p.window_size > 0) { // don't compute anything if below attention band if (iter_key_start + kKeysPerBlock < - int32_t(query_start + p.causal_diagonal_offset) - p.window_size) { + 
            int32_t(query_start + p.causal_diagonal_offset) - p.window_size) {
+            static_cast<int32_t>(query_start + p.causal_diagonal_offset) - p.window_size) {
          continue;
        }
      }
@@ -919,7 +919,7 @@ struct AttentionKernel {
      if (p.window_size > 0 &&
          (query_start + p.causal_diagonal_offset +
               cutlass::fast_min(
-                  int32_t(kQueriesPerBlock), int32_t(p.num_queries)) -
+                  static_cast<int32_t>(kQueriesPerBlock), static_cast<int32_t>(p.num_queries)) -
              p.window_size >=
          iter_key_start)) {
        auto query_start = blockIdx.x * kQueriesPerBlock;
@@ -932,7 +932,7 @@ struct AttentionKernel {
            lane_offset,
            [&](int accum_m) { first_col = accum_m + offset; },
            [&](int accum_m, int accum_n, int idx) {
              if (accum_n <= first_col) {
                accum[idx] =
                    -cutlass::platform::numeric_limits<accum_t>::infinity();
              }
            },
            [&](int accum_m) {});
-        // print_warp_accum(accum, lane_offset, 12,
-        // 12);
      }

      // Update `mi` from accum stored in registers
@@ -1087,7 +1085,7 @@ struct AttentionKernel {
      int first_key = 0;
      if (p.window_size > 0) {
        first_key = (cutlass::fast_max(
-                        int(query_start + p.causal_diagonal_offset) -
+                        static_cast<int>(query_start + p.causal_diagonal_offset) -
                             p.window_size + 1,
                         0) /
                     kKeysPerBlock) *
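Aside: the kernel changes above hinge on two conventions and one alignment step. GQA's local_window_size counts only tokens to the left of the current query token, while the kernel's window_size also counts the current token, hence p.window_size = params.local_window_size + 1 and the >= / <= comparisons; and first_key rounds the left edge of the window down to a key-block boundary so the kIsFirst dispatch fires on the first key block the window actually visits instead of block 0. The Python sketch below mirrors that arithmetic with illustrative names only; it is not the operator or the kernel itself.

import numpy as np

def gqa_local_attention_mask(total_seqlen, past_seqlen, local_window_size):
    # True means "query attends to key". Mirrors the CPU should_apply_local_window rule:
    # a query with causal length L sees keys [L - local_window_size - 1, L), i.e. itself
    # plus local_window_size tokens on its left.
    new_tokens = total_seqlen - past_seqlen
    mask = np.zeros((new_tokens, total_seqlen), dtype=bool)
    for seq in range(new_tokens):
        seq_causal_length = past_seqlen + seq + 1
        apply_window = local_window_size >= 0 and seq_causal_length > local_window_size + 1
        start = seq_causal_length - local_window_size - 1 if apply_window else 0
        mask[seq, start:seq_causal_length] = True
    return mask

def first_key_block(query_start, causal_diagonal_offset, window_size, keys_per_block):
    # Leftmost key the first query of the block may attend to, clamped at 0, then rounded
    # down to a key-block boundary (mirrors the cutlass::fast_max / division above).
    if window_size <= 0:
        return 0
    leftmost = max(query_start + causal_diagonal_offset - window_size + 1, 0)
    return (leftmost // keys_per_block) * keys_per_block

# local_window_size=2 (GQA convention) is window_size=3 in the kernel: self + 2 left tokens.
print(gqa_local_attention_mask(total_seqlen=5, past_seqlen=0, local_window_size=2).astype(int))
# With window_size=16 and 32 keys per block, queries starting at 96 first touch key block 64;
# key blocks 0 and 32 lie entirely below the attention band and are skipped.
print(first_key_block(query_start=96, causal_diagonal_offset=0, window_size=16, keys_per_block=32))  # 64

Under the CPU rule, local_window_size=0 leaves each query attending only to itself, which is the case enabled by relaxing the check from local_window_size_ > 0 to local_window_size_ >= 0.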