bugfix: remove 2x2 warp layout introduced in #518 (#523)
#518 adds a 2x2 warp layout, which causes numerical errors for query lengths in (16, 32].

This PR removes the 2x2 warp layout configuration. We might add it back after we figure out where the bug is.

cc @abcdabcd987
yzh119 authored Oct 11, 2024
1 parent 0aa4726 commit d0a1d0d
Showing 5 changed files with 7 additions and 18 deletions.
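
In effect, on Ampere (SM80) or newer the candidate query tile sizes shrink from {128, 64, 32, 16} to {128, 64, 16}: a query chunk whose packed length falls in (16, 32] no longer gets the 32-row tile with the 2x2 warp layout and instead falls through to the 64-row tile. A minimal host-side sketch of the selection after this patch, simplified from the prefill.cuh and scheduler.cuh hunks below (the helper name, the 128-tile condition, and the pre-Ampere fallback are assumptions, not the exact library code):

#include <cstdint>
#include <cstdio>

// Sketch only: mirrors the cta_tile_q selection after this patch. The real
// logic lives in SinglePrefillWithKVCacheDispatched / PrefillSplitQOKVIndptr.
uint32_t pick_cta_tile_q(uint32_t packed_qo_len, bool sm80_or_newer) {
  if (packed_qo_len > 64) {
    return 128;  // assumed large-query branch, outside the hunks shown here
  }
  if (sm80_or_newer) {
    // Ampere or newer: the 32-row tile (2x2 warp layout) is removed, so any
    // packed length above 16 now uses the 64-row tile.
    return packed_qo_len > 16 ? 64 : 16;
  }
  return 64;  // assumed pre-Ampere fallback, also outside the hunks shown here
}

int main() {
  const uint32_t lens[] = {8, 16, 24, 32, 48, 128};
  for (uint32_t len : lens) {
    std::printf("packed_qo_len=%u -> cta_tile_q=%u\n", len,
                pick_cta_tile_q(len, /*sm80_or_newer=*/true));
  }
  return 0;
}
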
2 changes: 1 addition & 1 deletion flashinfer-aot/generate_batch_paged_prefill_inst.py
@@ -36,7 +36,7 @@ def get_cu_file_str(
dtype_out,
idtype,
):
- cta_tile_q_choice = [128, 64, 32, 16]
+ cta_tile_q_choice = [128, 64, 16]

def get_insts(attention_variant, dtype_out):
return "\n".join(
2 changes: 1 addition & 1 deletion flashinfer-aot/generate_batch_ragged_prefill_inst.py
@@ -35,7 +35,7 @@ def get_cu_file_str(
dtype_out,
idtype,
):
- cta_tile_q_choice = [128, 64, 32, 16]
+ cta_tile_q_choice = [128, 64, 16]
def get_insts(attention_variant, dtype_out):
return "\n".join(
[
10 changes: 3 additions & 7 deletions include/flashinfer/attention/prefill.cuh
@@ -44,10 +44,8 @@ using mma::MMAMode;
constexpr uint32_t WARP_SIZE = 32;

constexpr uint32_t get_num_warps_q(const uint32_t cta_tile_q) {
- if (cta_tile_q > 32) {
+ if (cta_tile_q > 16) {
return 4;
- } else if (cta_tile_q > 16) {
- return 2;
} else {
return 1;
}
@@ -1324,11 +1322,9 @@ cudaError_t SinglePrefillWithKVCacheDispatched(typename AttentionVariant::Params
auto compute_capacity = GetCudaComputeCapability();
if (compute_capacity.first >= 8) {
// Ampere or newer
- if (unpacked_qo_len > 32) {
+ if (unpacked_qo_len > 16) {
// avg_packed_qo_len <= 64
cta_tile_q = 64;
- } else if (unpacked_qo_len > 16) {
- // avg_packed_qo_len <= 32
- cta_tile_q = 32;
} else {
// avg_packed_qo_len <= 16
cta_tile_q = 16;
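
With the 2-warp branch gone, get_num_warps_q reduces to a two-way split. Restated as a standalone snippet for clarity (same logic as the hunk above; the surrounding header context is omitted):

#include <cstdint>

// After this patch only two warp counts remain along the query dimension of a
// CTA tile; the intermediate 2-warp case that backed the removed
// cta_tile_q == 32 configuration (the 2x2 warp layout) is gone.
constexpr uint32_t get_num_warps_q(const uint32_t cta_tile_q) {
  if (cta_tile_q > 16) {
    return 4;  // cta_tile_q == 64 or 128
  } else {
    return 1;  // cta_tile_q == 16
  }
}

static_assert(get_num_warps_q(64) == 4, "64-row tiles keep 4 warps along q");
static_assert(get_num_warps_q(16) == 1, "16-row tiles use a single warp along q");
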
6 changes: 2 additions & 4 deletions include/flashinfer/attention/scheduler.cuh
@@ -379,11 +379,9 @@ PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_
auto compute_capacity = GetCudaComputeCapability();
if (compute_capacity.first >= 8) {
// Ampere or newer
- if (avg_packed_qo_len > 32) {
+ if (avg_packed_qo_len > 16) {
// avg_packed_qo_len <= 64
cta_tile_q = 64;
- } else if (avg_packed_qo_len > 16) {
- // avg_packed_qo_len <= 32
- cta_tile_q = 32;
} else {
// avg_packed_qo_len <= 16
cta_tile_q = 16;
5 changes: 0 additions & 5 deletions include/flashinfer/utils.cuh
@@ -107,11 +107,6 @@
__VA_ARGS__ \
break; \
} \
- case 32: { \
- constexpr uint32_t CTA_TILE_Q = 32; \
- __VA_ARGS__ \
- break; \
- } \
case 16: { \
constexpr uint32_t CTA_TILE_Q = 16; \
__VA_ARGS__ \
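
The dispatch macro in utils.cuh loses its case 32 arm accordingly. An approximate sketch of the resulting switch shape (the macro name and the default arm are assumptions for illustration; only fragments of the real macro are visible in this hunk):

// Approximate shape of the CTA_TILE_Q dispatch after this patch: the switch
// covers only the 128, 64 and 16 tile sizes that remain instantiated.
#define DISPATCH_CTA_TILE_Q_SKETCH(cta_tile_q, CTA_TILE_Q, ...) \
  switch (cta_tile_q) {                                         \
    case 128: {                                                 \
      constexpr uint32_t CTA_TILE_Q = 128;                      \
      __VA_ARGS__                                               \
      break;                                                    \
    }                                                           \
    case 64: {                                                  \
      constexpr uint32_t CTA_TILE_Q = 64;                       \
      __VA_ARGS__                                               \
      break;                                                    \
    }                                                           \
    case 16: {                                                  \
      constexpr uint32_t CTA_TILE_Q = 16;                       \
      __VA_ARGS__                                               \
      break;                                                    \
    }                                                           \
    default: {                                                  \
      /* unsupported cta_tile_q value */                        \
      break;                                                    \
    }                                                           \
  }
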
