@@ -3815,124 +3815,126 @@ def enumerate_qgmma_flash_warpspec_kernels(specs,
     combinations = product([False, True], \
                            [InputLayout.PACKED_QKV, InputLayout.CONTIGUOUS_Q_KV,
                             InputLayout.Q_PAGED_KV, InputLayout.SEPARATE_Q_K_V],
-                           [False, True])
-    for (alibi, input_layout, enable_attn_logit_softcapping) in combinations:
+                           [False, True], [False, True])
+    for (alibi, input_layout, enable_attn_logit_softcapping,
+         return_softmax) in combinations:
         # alibi and bmm1_tanh_scale shouldn't be used together.
         if alibi and enable_attn_logit_softcapping:
             continue
-        # D <= 64: KV_STEP = 256
-        specs.append(
-            kernel_spec(
-                sm=sm,
-                sm_mma=90,
-                dtype=dtype,
-                seq_len=0,  # support any sequence length
-                head_size=[32, 40, 48, 64],
-                warps_m=4,  #4x1 warpgroups
-                warps_n=1,
-                version=2,
-                interleaved=False,
-                ldgsts_q=
-                False,  # for Hopper kernels, ldgsts = False signals TMA usage.
-                ldgsts_k=False,
-                ldgsts_v=False,
-                share_smem_k_v=False,
-                loop_step=64,
-                q_tile_buffers=1,  # only used by warp specialized kernels
-                has_noloop=0,
-                noloop_step=64,
-                kv_loop_step=256,
-                kv_tile_buffers=4,  # only used by warp specialized kernels
-                unroll_threshold=1,
-                has_scale_max=False,
-                flash_attention=True,
-                warp_specialization=True,
-                alibi=alibi,
-                enable_attn_logit_softcapping=enable_attn_logit_softcapping,
-                return_softmax_stats=
-                False,  # return softmax stats is not supported for fp8 now
-                scheduling_mode=scheduling_mode,
-                input_layout=input_layout,
-                sage_block_sizes=sage_block_sizes,
-                output_dtype=output_dtype))
-
-        # 64 < D <=128: KV_STEP = 128
-        specs.append(
-            kernel_spec(
-                sm=sm,
-                sm_mma=90,
-                dtype=dtype,
-                seq_len=0,  # support any sequence length
-                head_size=[80, 96, 104, 128],
-                warps_m=4,  #4x1 warpgroups
-                warps_n=1,
-                version=2,
-                interleaved=False,
-                ldgsts_q=
-                False,  # for Hopper kernels, ldgsts = False signals TMA usage.
-                ldgsts_k=False,
-                ldgsts_v=False,
-                share_smem_k_v=False,
-                loop_step=64,
-                q_tile_buffers=1,  # only used by warp specialized kernels
-                has_noloop=0,
-                noloop_step=64,
-                kv_loop_step=256,
-                kv_tile_buffers=2,  # only used by warp specialized kernels
-                unroll_threshold=1,
-                has_scale_max=False,
-                flash_attention=True,
-                warp_specialization=True,
-                alibi=alibi,
-                enable_attn_logit_softcapping=enable_attn_logit_softcapping,
-                return_softmax_stats=
-                False,  # return softmax stats is not supported for fp8 now
-                scheduling_mode=scheduling_mode,
-                input_layout=input_layout,
-                sage_block_sizes=sage_block_sizes,
-                output_dtype=output_dtype))
-
-        # 128 < D <=256: KV_STEP = 128
-        specs.append(
-            kernel_spec(
-                sm=sm,
-                sm_mma=90,
-                dtype=dtype,
-                seq_len=0,  # support any sequence length
-                head_size=[160, 192, 256],
-                warps_m=4,  #4x1 warpgroups
-                warps_n=1,
-                version=2,
-                interleaved=False,
-                ldgsts_q=
-                False,  # for Hopper kernels, ldgsts = False signals TMA usage.
-                ldgsts_k=False,
-                ldgsts_v=False,
-                share_smem_k_v=False,
-                loop_step=64,
-                q_tile_buffers=1,  # only used by warp specialized kernels
-                has_noloop=0,
-                noloop_step=64,
-                kv_loop_step=
-                128,  # use 128 kv step size to avoid register spilling
-                kv_tile_buffers=2,  # only used by warp specialized kernels
-                unroll_threshold=1,
-                has_scale_max=False,
-                flash_attention=True,
-                warp_specialization=True,
-                alibi=alibi,
-                enable_attn_logit_softcapping=enable_attn_logit_softcapping,
-                return_softmax_stats=
-                False,  # return softmax stats is not supported for fp8 now
-                scheduling_mode=scheduling_mode,
-                input_layout=input_layout,
-                sage_block_sizes=sage_block_sizes,
-                output_dtype=output_dtype))
-
-        # context MLA (192x128)
-        # we could use param 'output_dtype' of enumerate_qgmma_flash_warpspec_kernels(),
-        # but it will generate many unnecessary kernels and they are not easy to filter out.
-        for output_type in [None, 'bf16']:
+        # for normal attention, we do not need return softmax for ws fp8 kernels currently.
+        # also fp8 input and bf16 output is only needed for MLA kernel.
+        skip_combination = return_softmax or (output_dtype is not None)
+        # for context mla, we need separate qkv as input layout when returning softmax.
+        skip_mla_combination = return_softmax and input_layout != InputLayout.SEPARATE_Q_K_V
+        if not skip_combination:
+            # D <= 64: KV_STEP = 256
+            specs.append(
+                kernel_spec(
+                    sm=sm,
+                    sm_mma=90,
+                    dtype=dtype,
+                    seq_len=0,  # support any sequence length
+                    head_size=[32, 40, 48, 64],
+                    warps_m=4,  #4x1 warpgroups
+                    warps_n=1,
+                    version=2,
+                    interleaved=False,
+                    ldgsts_q=
+                    False,  # for Hopper kernels, ldgsts = False signals TMA usage.
+                    ldgsts_k=False,
+                    ldgsts_v=False,
+                    share_smem_k_v=False,
+                    loop_step=64,
+                    q_tile_buffers=1,  # only used by warp specialized kernels
+                    has_noloop=0,
+                    noloop_step=64,
+                    kv_loop_step=256,
+                    kv_tile_buffers=4,  # only used by warp specialized kernels
+                    unroll_threshold=1,
+                    has_scale_max=False,
+                    flash_attention=True,
+                    warp_specialization=True,
+                    alibi=alibi,
+                    enable_attn_logit_softcapping=enable_attn_logit_softcapping,
+                    return_softmax_stats=return_softmax,
+                    scheduling_mode=scheduling_mode,
+                    input_layout=input_layout,
+                    sage_block_sizes=sage_block_sizes,
+                    output_dtype=output_dtype))
+
+            # 64 < D <=128: KV_STEP = 128
+            specs.append(
+                kernel_spec(
+                    sm=sm,
+                    sm_mma=90,
+                    dtype=dtype,
+                    seq_len=0,  # support any sequence length
+                    head_size=[80, 96, 104, 128],
+                    warps_m=4,  #4x1 warpgroups
+                    warps_n=1,
+                    version=2,
+                    interleaved=False,
+                    ldgsts_q=
+                    False,  # for Hopper kernels, ldgsts = False signals TMA usage.
+                    ldgsts_k=False,
+                    ldgsts_v=False,
+                    share_smem_k_v=False,
+                    loop_step=64,
+                    q_tile_buffers=1,  # only used by warp specialized kernels
+                    has_noloop=0,
+                    noloop_step=64,
+                    kv_loop_step=256,
+                    kv_tile_buffers=2,  # only used by warp specialized kernels
+                    unroll_threshold=1,
+                    has_scale_max=False,
+                    flash_attention=True,
+                    warp_specialization=True,
+                    alibi=alibi,
+                    enable_attn_logit_softcapping=enable_attn_logit_softcapping,
+                    return_softmax_stats=return_softmax,
+                    scheduling_mode=scheduling_mode,
+                    input_layout=input_layout,
+                    sage_block_sizes=sage_block_sizes,
+                    output_dtype=output_dtype))
+
+            # 128 < D <=256: KV_STEP = 128
+            specs.append(
+                kernel_spec(
+                    sm=sm,
+                    sm_mma=90,
+                    dtype=dtype,
+                    seq_len=0,  # support any sequence length
+                    head_size=[160, 192, 256],
+                    warps_m=4,  #4x1 warpgroups
+                    warps_n=1,
+                    version=2,
+                    interleaved=False,
+                    ldgsts_q=
+                    False,  # for Hopper kernels, ldgsts = False signals TMA usage.
+                    ldgsts_k=False,
+                    ldgsts_v=False,
+                    share_smem_k_v=False,
+                    loop_step=64,
+                    q_tile_buffers=1,  # only used by warp specialized kernels
+                    has_noloop=0,
+                    noloop_step=64,
+                    kv_loop_step=
+                    128,  # use 128 kv step size to avoid register spilling
+                    kv_tile_buffers=2,  # only used by warp specialized kernels
+                    unroll_threshold=1,
+                    has_scale_max=False,
+                    flash_attention=True,
+                    warp_specialization=True,
+                    alibi=alibi,
+                    enable_attn_logit_softcapping=enable_attn_logit_softcapping,
+                    return_softmax_stats=return_softmax,
+                    scheduling_mode=scheduling_mode,
+                    input_layout=input_layout,
+                    sage_block_sizes=sage_block_sizes,
+                    output_dtype=output_dtype))
+
+        if not skip_mla_combination:
+            # context MLA (192x128)
             specs.append(
                 kernel_spec(
                     sm=sm,
@@ -3962,12 +3964,11 @@ def enumerate_qgmma_flash_warpspec_kernels(specs,
                     warp_specialization=True,
                     alibi=alibi,
                     enable_attn_logit_softcapping=enable_attn_logit_softcapping,
-                    return_softmax_stats=
-                    False,  # return softmax stats is not supported for fp8 now
+                    return_softmax_stats=return_softmax,
                     scheduling_mode=scheduling_mode,
                     input_layout=input_layout,
                     sage_block_sizes=sage_block_sizes,
-                    output_dtype=output_type))
+                    output_dtype=output_dtype))


 def enumerate_igmma_kernels(specs, sm=90):
@@ -6215,6 +6216,10 @@ def enumerate_kernels():
     enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16')
     enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='bf16')
     enumerate_qgmma_flash_warpspec_kernels(specs, sm=90, dtype='e4m3')
+    enumerate_qgmma_flash_warpspec_kernels(specs,
+                                           sm=90,
+                                           dtype='e4m3',
+                                           output_dtype="bf16")

     # For now SageAttention only needs BF16
     # block_size_q should be divisible by 64
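Net effect of the change: the fp8 warp-specialized generator now enumerates a fourth `return_softmax` axis and prunes it with the `skip_combination` / `skip_mla_combination` predicates, replacing the old per-MLA `output_type` loop. The snippet below is a minimal, self-contained sketch of that pruning logic only; the layout strings and the `enumerate_combinations()` helper are illustrative stand-ins, not the real `InputLayout` enum or `kernel_spec` plumbing.

from itertools import product

# Stand-in for InputLayout; values are placeholders for illustration only.
INPUT_LAYOUTS = ["PACKED_QKV", "CONTIGUOUS_Q_KV", "Q_PAGED_KV", "SEPARATE_Q_K_V"]


def enumerate_combinations(output_dtype=None):
    kept_normal, kept_mla = 0, 0
    for alibi, input_layout, softcapping, return_softmax in product(
            [False, True], INPUT_LAYOUTS, [False, True], [False, True]):
        # alibi and logit softcapping are mutually exclusive.
        if alibi and softcapping:
            continue
        # Normal attention: softmax stats are not emitted by the fp8 WS path,
        # and a bf16 output dtype is reserved for the MLA kernels.
        skip_combination = return_softmax or (output_dtype is not None)
        # Context MLA only returns softmax stats with separate Q/K/V inputs.
        skip_mla_combination = return_softmax and input_layout != "SEPARATE_Q_K_V"
        kept_normal += not skip_combination
        kept_mla += not skip_mla_combination
    return kept_normal, kept_mla


print(enumerate_combinations())        # fp8 in / fp8 out
print(enumerate_combinations("bf16"))  # fp8 in / bf16 out: MLA specs only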