diff --git a/aiter/mla.py b/aiter/mla.py index 6f4cd2150a..1e09e1bd51 100644 --- a/aiter/mla.py +++ b/aiter/mla.py @@ -162,6 +162,8 @@ def mla_decode_fwd( q_scale=None, kv_scale=None, intra_batch_mode=False, + return_logits=False, + return_lse=False, ): device = q.device assert logit_cap <= 0, f"{logit_cap=} is not support yet" @@ -271,7 +273,7 @@ def mla_decode_fwd( ): # Natively support cases pass - elif nhead in range(32, 128 + 1, 16) and persistent_mode and max_seqlen_q == 1: + elif nhead in range(32, 128 + 1, 16) and persistent_mode: # we use nhead=16 to simulate such cases by customized metadata # metadata also views qo's tensor as shape (total_s * (nhead // 16), 16, ...) total_s = ori_total_s * (ori_nhead // 16) @@ -292,7 +294,11 @@ def mla_decode_fwd( dtype=dtypes.fp32, device=device, ) - final_lse = torch.empty((total_s, nhead), dtype=dtypes.fp32, device=device) + final_lse = ( + torch.empty((total_s, nhead), dtype=dtypes.fp32, device=device) + if return_lse + else None + ) aiter.mla_decode_stage1_asm_fwd( q, @@ -326,10 +332,9 @@ def mla_decode_fwd( ) if io_transformed: - if persistent_mode: + if return_logits: logits = logits.view(-1, 1, ori_nhead, v_head_dim) - else: - logits = logits.view(ori_total_s, num_kv_splits, ori_nhead, v_head_dim) + q = q.view(ori_total_s, ori_nhead, -1) o = o.view(ori_total_s, ori_nhead, -1) diff --git a/aiter/ops/attention.py b/aiter/ops/attention.py index 20101480eb..a433bd213a 100644 --- a/aiter/ops/attention.py +++ b/aiter/ops/attention.py @@ -629,7 +629,8 @@ def get_mla_metadata_info_v1( max_qo_tiles_per_batch = ( int(math.ceil(max_seqlen_qo * num_head_qo / 128)) - if num_head_qo == 16 or (num_head_qo == 128 and kv_dtype == dtypes.fp8) + if num_head_qo == 16 + or (num_head_qo == 128 and kv_dtype == dtypes.fp8 and q_dtype == dtypes.fp8) else int(math.ceil(max_seqlen_qo * num_head_qo / 16)) ) batch_size = batch_size * max_seqlen_qo if is_sparse else batch_size diff --git a/csrc/kernels/mla/metadata/v1_2_device.cuh b/csrc/kernels/mla/metadata/v1_2_device.cuh index ad64bce238..b96051874d 100644 --- a/csrc/kernels/mla/metadata/v1_2_device.cuh +++ b/csrc/kernels/mla/metadata/v1_2_device.cuh @@ -28,12 +28,34 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__ { using QoState = QoState; + const int32_t ori_seqlen_qo = [&]() { + if constexpr (Traits::kIsSparse) + { + return params.p_seqlens_qo_indptr[1] - params.p_seqlens_qo_indptr[0]; + } + else + { + return params.ori_seqlen_qo; + } + }(); + + const int32_t num_batches = [&]() { + if constexpr (Traits::kIsSparse) + { + return params.num_batches * ori_seqlen_qo; + } + else + { + return params.num_batches; + } + }(); + extern __shared__ uint8_t p_smem[]; int32_t* p_lds_seqlens_qo = reinterpret_cast(p_smem); - int32_t* p_lds_seqlens_kv = p_lds_seqlens_qo + (QoState::is_unique() ? 0 : params.num_batches); + int32_t* p_lds_seqlens_kv = p_lds_seqlens_qo + (QoState::is_unique() ? 0 : num_batches); QoState qo_state( - params.uni_seqlen_qo, params.ori_seqlen_qo, p_lds_seqlens_qo, params.p_seqlens_qo_indptr); + params.uni_seqlen_qo, ori_seqlen_qo, p_lds_seqlens_qo, params.p_seqlens_qo_indptr); auto get_num_qo_tiles = [&](const int32_t batch_idx) { if constexpr(Traits::kQoSplits) @@ -53,10 +75,10 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__ MlaWorkInfo* p_work_info_set = reinterpret_cast(params.p_work_info_set_raw); int32_t sum_blocks = 0; - for(int32_t bid = lane_idx; bid < params.num_batches; bid += ck_tile::get_warp_size()) + for(int32_t bid = lane_idx; bid < num_batches; bid += ck_tile::get_warp_size()) { const int32_t bid_ori = Traits::kIsSparse - ? (bid / params.ori_seqlen_qo / params.qk_batch_ratio) + ? (bid / ori_seqlen_qo / params.qk_batch_ratio) : (bid / params.qk_batch_ratio); const int32_t kv_end = params.p_seqlens_kv_indptr[bid_ori + 1]; const int32_t seqlen_kv = @@ -119,7 +141,7 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__ for(int32_t cid = 0; cid < params.num_cu; ++cid) { int32_t remain_payload = payload; - while(curr_batch < params.num_batches) + while(curr_batch < num_batches) { const int32_t num_qo_tiles = get_num_qo_tiles(curr_batch); const int32_t qo_tile_size = @@ -143,9 +165,17 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__ work_info.qo_end = ck_tile::min(work_info.qo_start + qo_tile_size, qo_state.get_end(curr_batch)); work_info.kv_start = curr_kv_begin + (curr_kv_block * params.kv_granularity); + int32_t batch_tail = (num_qo_tiles - 1 - curr_qo_tile_idx); + if constexpr(!Traits::kIsSparse) + { + if (params.qk_batch_ratio != 1) + { + batch_tail = num_qo_tiles - (work_info.qo_start / params.qk_batch_ratio) % ori_seqlen_qo - 1; + } + } work_info.kv_end = ck_tile::min( work_info.kv_start + (remain_kv_blocks * params.kv_granularity), - curr_kv_end - (num_qo_tiles - 1 - curr_qo_tile_idx)); + curr_kv_end - batch_tail); work_info.kv_offset = curr_kv_end - work_info.kv_end; // split related info @@ -202,7 +232,7 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__ curr_sub_head_idx = (curr_sub_head_idx == (params.qk_batch_ratio - 1)) ? 0 : (curr_sub_head_idx + 1); - if(curr_batch < params.num_batches) + if(curr_batch < num_batches) { if(curr_sub_head_idx == 0) { @@ -213,7 +243,7 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__ else { const int32_t bid_ori = Traits::kIsSparse - ? (curr_batch / params.ori_seqlen_qo / + ? (curr_batch / ori_seqlen_qo / params.qk_batch_ratio) : (curr_batch / params.qk_batch_ratio); curr_kv_seqlen = params.p_seqlens_kv_indptr[bid_ori + 1] - @@ -251,9 +281,17 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__ qo_state.get_end(curr_batch)); work_info.kv_start = curr_kv_begin + (curr_kv_block * params.kv_granularity); + int32_t batch_tail = (num_qo_tiles - 1 - curr_qo_tile_idx); + if constexpr(!Traits::kIsSparse) + { + if (params.qk_batch_ratio != 1) + { + batch_tail = num_qo_tiles - (work_info.qo_start / params.qk_batch_ratio) % ori_seqlen_qo - 1; + } + } work_info.kv_end = ck_tile::min( work_info.kv_start + (consuming_blks * params.kv_granularity), - curr_kv_end - (num_qo_tiles - 1 - curr_qo_tile_idx)); + curr_kv_end - batch_tail); work_info.kv_offset = curr_kv_end - work_info.kv_end; work_info.partial_qo_loc = partial_idx; p_work_info_set[num_works] = work_info; @@ -365,12 +403,6 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba num_batches *= qk_batch_ratio; } - if(is_sparse) - { - num_batches *= uni_seqlen_qo; - uni_seqlen_qo = 1; - } - TORCH_CHECK((num_heads == 16) || (num_heads == 128), __func__, ": only supports #heads in [16, 128], or (#head, uni_seqlen_qo) = (16*N, 1) where " diff --git a/op_tests/test_mla.py b/op_tests/test_mla.py index efe8b47f71..0307082441 100644 --- a/op_tests/test_mla.py +++ b/op_tests/test_mla.py @@ -19,6 +19,12 @@ # qdtype fp8, kdtype fp8: nhead16, nhead128 +def check_support(dtype, kv_dtype, nhead): + if dtype == dtypes.fp8 and kv_dtype == dtypes.bf16: + return False + return True + + def cal_diff( x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool = False ) -> None: @@ -445,11 +451,11 @@ def test_absorb_decode_fp8(): err = None us_asm_decode = 1e12 - if (dtype == torch.bfloat16 and kvtype == torch.bfloat16) and nhead in [16, 128]: + if dtype == torch.bfloat16 and nhead in [16, 128]: err, us_asm_decode = test_absorb_decode_bf16() - elif kvtype == dtypes.fp8 and nhead in [16, 128]: err, us_asm_decode = test_absorb_decode_fp8() + ret["decode:err"] = err ret["decode:asm_576"] = us_asm_decode @@ -599,22 +605,23 @@ def test_absorb_decode_fp8(): for dtype, kvtype, ctx_len, batch_size, split_per_batch in itertools.product( list_dtype, l_kv_dtype, args.ctxLen, args.batchSize, args.split_per_batch ): - ret = test_mla( - ctx_len, - batch_size, - nhead, - args.kv_lora_rank, - args.qk_nope_head_dim, - args.qk_rope_head_dim, - args.v_head_dim, - dtype, - kvtype, - args.block_size, - varlen=args.varlen, - decode_qlen=decode_qlen, - split_per_batch=split_per_batch, - ) - df.append(ret) + if check_support(dtype, kvtype, nhead): + ret = test_mla( + ctx_len, + batch_size, + nhead, + args.kv_lora_rank, + args.qk_nope_head_dim, + args.qk_rope_head_dim, + args.v_head_dim, + dtype, + kvtype, + args.block_size, + varlen=args.varlen, + decode_qlen=decode_qlen, + split_per_batch=split_per_batch, + ) + df.append(ret) df = pd.DataFrame(df) # df.to_csv(f"mla_nhead{nhead}decode_qlen{decode_qlen}.csv") aiter.logger.info(f"summary:\n{df}") diff --git a/op_tests/test_mla_persistent.py b/op_tests/test_mla_persistent.py index 03b8695b91..68d558048a 100644 --- a/op_tests/test_mla_persistent.py +++ b/op_tests/test_mla_persistent.py @@ -18,6 +18,12 @@ # qdtype fp8, kdtype bf16: nhead16 +def check_support(dtype, kv_dtype, nhead): + if dtype == dtypes.fp8 and kv_dtype == dtypes.bf16: + return False + return True + + def cal_diff( x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool = False ) -> None: @@ -401,15 +407,9 @@ def test_absorb_decode_fp8(): err = None us_asm_decode = 1e12 - if (dtype == torch.bfloat16 and kvtype == torch.bfloat16) and ( - (nhead in [16]) or (decode_qlen == 1 and nhead in range(32, 128 + 1, 16)) - ): + if dtype == torch.bfloat16: err, us_asm_decode = test_absorb_decode_bf16() - elif kvtype == dtypes.fp8 and ( - (dtype == dtypes.fp8 and nhead in [16, 128]) - or (dtype == dtypes.bf16 and nhead in [16]) - or (decode_qlen == 1 and nhead in range(32, 128 + 1, 16)) - ): + elif kvtype == dtypes.fp8: err, us_asm_decode = test_absorb_decode_fp8() ret["decode:err"] = err ret["decode:asm_576"] = us_asm_decode @@ -566,23 +566,24 @@ def test_absorb_decode_fp8(): for dtype, kvtype, ctx_len, batch_size, max_split_per_batch in itertools.product( list_dtype, l_kv_dtype, args.ctxLen, args.batchSize, args.max_split_per_batch ): - ret = test_mla( - ctx_len, - batch_size, - nhead, - args.kv_lora_rank, - args.qk_nope_head_dim, - args.qk_rope_head_dim, - args.v_head_dim, - dtype, - kvtype, - args.block_size, - varlen=args.varlen, - decode_qlen=decode_qlen, - max_split_per_batch=max_split_per_batch, - non_persistent_mode=args.non_persistent_mode, - ) - df.append(ret) + if check_support(dtype, kvtype, nhead): + ret = test_mla( + ctx_len, + batch_size, + nhead, + args.kv_lora_rank, + args.qk_nope_head_dim, + args.qk_rope_head_dim, + args.v_head_dim, + dtype, + kvtype, + args.block_size, + varlen=args.varlen, + decode_qlen=decode_qlen, + max_split_per_batch=max_split_per_batch, + non_persistent_mode=args.non_persistent_mode, + ) + df.append(ret) df = pd.DataFrame(df) # df.to_csv(f"mla_nhead{nhead}decode_qlen{decode_qlen}.csv") aiter.logger.info(f"summary:\n{df}") diff --git a/op_tests/test_mla_sparse.py b/op_tests/test_mla_sparse.py index c93170b0c5..6372f9e9de 100644 --- a/op_tests/test_mla_sparse.py +++ b/op_tests/test_mla_sparse.py @@ -20,6 +20,12 @@ # qdtype fp8, kdtype bf16: nhead16 +def check_support(dtype, kv_dtype, nhead): + if dtype == dtypes.fp8 and kv_dtype == dtypes.bf16: + return False + return True + + def cal_diff( x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool = False ) -> None: @@ -450,8 +456,8 @@ def test_mla( reduce_final_map, reduce_partial_map, kv_granularity=max(page_size, 16), - max_seqlen_qo=int(max_seqlen_qo), - uni_seqlen_qo=decode_qlen, + max_seqlen_qo=1, + uni_seqlen_qo=1, fast_mode=True, max_split_per_batch=max_split_per_batch, topk=2048, @@ -525,7 +531,7 @@ def test_sparse_mla_bf16(): ) return err, us_asm_decode - def test_absorb_decode_fp8(): + def test_sparse_mla_fp8(): if dtype != dtypes.fp8 and nhead == 128: aiter.logger.info("don't support this case:\n") return None, 1e12 @@ -597,16 +603,10 @@ def test_absorb_decode_fp8(): err = None us_asm_decode = 1e12 - if (dtype == torch.bfloat16 and kvtype == torch.bfloat16) and ( - (nhead in [16]) or (max_seqlen_qo == 1 and nhead in range(32, 128 + 1, 16)) - ): + if dtype == torch.bfloat16: err, us_asm_decode = test_sparse_mla_bf16() - elif kvtype == dtypes.fp8 and ( - (dtype == dtypes.fp8 and nhead in [16, 128]) - or (dtype == dtypes.bf16 and nhead in [16]) - or (decode_qlen == 1 and nhead in range(32, 128 + 1, 16)) - ): - err, us_asm_decode = test_absorb_decode_fp8() + elif kvtype == dtypes.fp8: + err, us_asm_decode = test_sparse_mla_fp8() ret["decode:err"] = err ret["decode:asm_576"] = us_asm_decode @@ -684,7 +684,7 @@ def test_absorb_decode_fp8(): type=str, choices=["bf16", "fp8"], nargs="*", - default=["bf16"], + default=["bf16", "fp8"], help="""Data type of Q. e.g.: -d bf16""", ) @@ -694,7 +694,7 @@ def test_absorb_decode_fp8(): type=str, choices=["bf16", "fp8"], nargs="*", - default=["bf16"], + default=["bf16", "fp8"], help="""Data type of KV. e.g.: -kvd bf16""", ) @@ -731,7 +731,7 @@ def test_absorb_decode_fp8(): "--max_split_per_batch", type=int, nargs="*", - default=[16], + default=[32], help="""kv seqlens max split num for per batch. e.g.: -ms 32""", ) @@ -755,22 +755,23 @@ def test_absorb_decode_fp8(): for dtype, kvtype, ctx_len, batch_size, max_split_per_batch in itertools.product( list_dtype, l_kv_dtype, args.ctxLen, args.batchSize, args.max_split_per_batch ): - ret = test_mla( - ctx_len, - batch_size, - nhead, - args.kv_lora_rank, - args.qk_nope_head_dim, - args.qk_rope_head_dim, - args.v_head_dim, - dtype, - kvtype, - args.block_size, - varlen=args.varlen, - decode_qlen=decode_qlen, - max_split_per_batch=max_split_per_batch, - ) - df.append(ret) + if check_support(dtype, kvtype, nhead): + ret = test_mla( + ctx_len, + batch_size, + nhead, + args.kv_lora_rank, + args.qk_nope_head_dim, + args.qk_rope_head_dim, + args.v_head_dim, + dtype, + kvtype, + args.block_size, + varlen=args.varlen, + decode_qlen=decode_qlen, + max_split_per_batch=max_split_per_batch, + ) + df.append(ret) df = pd.DataFrame(df) # df.to_csv(f"mla_nhead{nhead}decode_qlen{decode_qlen}.csv") aiter.logger.info(f"summary:\n{df}")