diff --git a/hopper/tile_size.h b/hopper/tile_size.h index 555ff06f698..2e437559fc5 100644 --- a/hopper/tile_size.h +++ b/hopper/tile_size.h @@ -51,24 +51,43 @@ constexpr std::tuple tile_size_fwd_sm90( return {128, is_local ? 64 : 80, true, true}; // 128 x 80 hits the limit of smem } } else { - if (headdim <= 64) { - if (use_one_mma_wg) { + // FP8 path + if (use_one_mma_wg) { + // Decode tiles — independent of two-level accumulation setting + if (headdim <= 96) { return {64, 128, true, true}; } else { - return {192, 160, true, true}; - } - } else if (headdim <= 96) { - return {192, 128, true, true}; - } else if (headdim <= 128) { - if (use_one_mma_wg) { return {64, 96, true, true}; - } else{ - return {128, paged_kv_non_TMA ? 160 : (v_colmajor || (softcap && is_local) ? 192 : 224), true, true}; } - } else if (headdim <= 192) { - return {128, (paged_kv_non_TMA || softcap) && is_local ? 128 : 160, true, true}; } else { - return {128, is_local ? 64 : 128, true, !paged_kv_non_TMA}; // PagedKV uses more registers so we disabled IntraWGOverlap + // Prefill tiles — two-level accumulation needs smaller tiles to reduce + // register pressure from the separate fp32 accumulator (tOrO_accum). + // Currently just optimized for causal case. +#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION + if (headdim <= 64) { + return {192, 128, true, true}; + } else if (headdim <= 96) { + return {128, 128, true, true}; + } else if (headdim <= 128) { + return {128, 192, true, true}; + } else if (headdim <= 192) { + return {128, 96, true, true}; + } else { + return {128, is_local ? 64 : 128, true, !paged_kv_non_TMA}; // TODO: FP8 prefill ~0.54x of BF16 with two-level accum at hd256 + } +#else + if (headdim <= 64) { + return {192, 160, true, true}; + } else if (headdim <= 96) { + return {192, 128, true, true}; + } else if (headdim <= 128) { + return {128, paged_kv_non_TMA ? 160 : (v_colmajor || (softcap && is_local) ? 192 : 224), true, true}; + } else if (headdim <= 192) { + return {128, (paged_kv_non_TMA || softcap) && is_local ? 128 : 160, true, true}; + } else { + return {128, is_local ? 64 : 128, true, !paged_kv_non_TMA}; + } +#endif } } }