diff --git a/csrc/kernels/get_mla_metadata.cu b/csrc/kernels/get_mla_metadata.cu index 6b78f9b..d7f6bf4 100644 --- a/csrc/kernels/get_mla_metadata.cu +++ b/csrc/kernels/get_mla_metadata.cu @@ -37,7 +37,7 @@ get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) { num_splits_shared[0] = 0; for (int i = 0; i < num_sm_parts; ++i) { int tile_scheduler_metadata0[4], tile_scheduler_metadata1; - tile_scheduler_metadata0[0] = now_idx; + tile_scheduler_metadata0[0] = (now_idx >= batch_size ? -1 : now_idx); tile_scheduler_metadata0[1] = now_block * block_size_n; tile_scheduler_metadata1 = now_n_split_idx; int remain_payload = payload; diff --git a/csrc/kernels/mla_combine.cu b/csrc/kernels/mla_combine.cu index b6ba8f8..225c321 100644 --- a/csrc/kernels/mla_combine.cu +++ b/csrc/kernels/mla_combine.cu @@ -26,7 +26,7 @@ flash_fwd_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params const int end_split_idx = __ldg(params.num_splits_ptr + batch_idx + 1); const int my_num_splits = end_split_idx - start_split_idx; FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS); - if (my_num_splits == 1) { + if (my_num_splits <= 1) { return; } diff --git a/csrc/kernels/splitkv_mla.cu b/csrc/kernels/splitkv_mla.cu index ff29305..03a208b 100644 --- a/csrc/kernels/splitkv_mla.cu +++ b/csrc/kernels/splitkv_mla.cu @@ -1022,7 +1022,7 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params int begin_seqlen = tile_scheduler_metadata.y; int end_idx = tile_scheduler_metadata.z; int end_seqlen = tile_scheduler_metadata.w; - if (begin_idx >= params.b) return; + if (begin_idx >= params.b || begin_idx < 0) return; int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4); // Copy the first Q