Merged
82 commits
6b1d059
Support ROCM builds from source distribution, and improve error handl…
mgorny Jan 18, 2025
cd393e0
[Build] Update version of setuptools used to generate core package (#…
tmm1 Jan 29, 2025
bb135af
Don't compile for CUDA 11, compile for official pytorch 2.6.0
tridao Jan 29, 2025
979702c
Bump to v2.7.4
tridao Jan 29, 2025
5231d95
Drop Pytorch 2.1
tridao Jan 29, 2025
454ce31
[FA3] Compile with nvcc 12.8 instead of 12.3
tridao Jan 29, 2025
803f609
Fix comment in assert
tridao Jan 30, 2025
02541ac
[CE] Assert logit_scale > 0
tridao Jan 30, 2025
2a20412
Implement HeadDim_V != HeadDim_QK, support hdimQK=192, hdimV=128
tridao Feb 3, 2025
6d199aa
Fix shape_O in epilogue params when kHeadDimV != kHeadDim
tridao Feb 4, 2025
86bcd05
Remove old combine.h
tridao Feb 4, 2025
e3b2400
Fix loading paged V when kHeadDimV != kHeadDim
tridao Feb 4, 2025
9e07d6d
Fix shape_V for storing new KV when kHeadDimV != kHeadDim
tridao Feb 4, 2025
f0f2523
Implement the case of LargeHeadDimV
tridao Feb 4, 2025
4c8819d
Rename Mma0->MmaQK, Mma1->MmaPV, use Cluster only if hdimV >= 192
tridao Feb 7, 2025
dd87691
Pass _1 or _0 to cute::aligned_struct
tridao Feb 8, 2025
ed53b5f
Fix compilation for FP8 when kHeadDimV != kHeadDim
tridao Feb 8, 2025
4e8496a
Support Qv
tridao Feb 8, 2025
893a22a
Test varlen_q=True by default for kvcache
tridao Feb 8, 2025
5fab938
Fix num_splits heuristic being called before get_pack_gqa
tridao Feb 8, 2025
5fc5ebf
Fix num_splits heuristic again when PackGQA
tridao Feb 8, 2025
5378bc3
Tile fwd_combine kernel along headdim, don't need kBlockM > 128
tridao Feb 8, 2025
db8ca79
Use bf16 instead of fp16 in benchmark_gemm.py
tridao Feb 9, 2025
982c480
Update Cutlass to 3.7
tridao Feb 9, 2025
58ebfa5
Use nvcc 12.6 but ptxas 12.8
tridao Feb 9, 2025
ed435c6
cicc uses the same version as ptxas
tridao Feb 9, 2025
8668823
Split hdimdiff into a separate translation unit
tridao Feb 9, 2025
b2fc79d
Update benchmark script
tridao Feb 9, 2025
c091545
Update Cutlass to 3.8
tridao Feb 9, 2025
5e39b10
Adjust tile size for hdim 64
tridao Feb 9, 2025
1a7f4df
Adjust ninja build file
tridao Feb 10, 2025
15cf7ee
Rename collective_mainloop -> mainloop, move tile_scheduler variable
tridao Feb 11, 2025
9f313c7
Move functions getting number of m/n blocks to a separate file
tridao Feb 12, 2025
eafd53c
Update cutlass 3.8 to fix error w cudaGetDriverEntryPointByVersion
tridao Feb 12, 2025
fa445ff
Fix FP8 test
tridao Feb 12, 2025
a09abcd
make seqused optional on top level interface (#1497)
vasqu Feb 16, 2025
40cbd52
Temporarily change package name of FA3 to allow FA2 & FA3 install
tridao Feb 18, 2025
91917b4
Update benchmark_split_kv.py to work w new API
tridao Feb 18, 2025
ea3ecea
Add tp_degree to benchmark_split_kv
tridao Feb 18, 2025
74dfa43
Fix divide by 0 in causal tile_scheduler for large seqlen
tridao Feb 19, 2025
b36ad4e
Use split for super long sequences that don't fit into L2
tridao Feb 19, 2025
ecdb528
Make rotary test optional in FA3
tridao Feb 22, 2025
06e34f6
Enable MLA flag in FA3 (rope=64, latent=512) (#1504)
tzadouri Feb 23, 2025
6aed835
Add simple script to benchmark MLA decode
tridao Feb 24, 2025
6752d62
Add dynamic splits
tridao Feb 24, 2025
cdda5bf
Update to Cutlass 3.8.0 tag
tridao Feb 24, 2025
9505c74
Adjust seqlen_q in MLA decode benchmark script
tridao Feb 24, 2025
3b5047d
Fix loop in prepare_scheduler.cu (h/t Jay Shah)
tridao Feb 25, 2025
dec83a1
fix: add "typename" prior to dependent type name (#1517)
zhiweij1 Feb 28, 2025
08f4c80
Add FLOPS to MLA decode benchmark
tridao Feb 28, 2025
085ce58
Change margin in prepare_scheduler.cu from 20% to 10%
tridao Feb 28, 2025
39e7197
Fix cuda 12.1 build (#1511)
LucasWilkinson Mar 1, 2025
20b84d6
Don't use IntraWGOverlap for hdim 64,512
tridao Mar 2, 2025
5458c78
Remove sink token
tridao Mar 2, 2025
6865e60
fix: prompt index to type longlong to avoid numerical overflow (#1500)
xin-w8023 Mar 2, 2025
45c48af
Add option for WG1 to use RS MMA but WG2 using SS MMA
tridao Mar 4, 2025
3edf7e0
Add kwargs to _write_ninja_file for compatibility with new torch
tridao Mar 4, 2025
4f0640d
Move writing P to smem as separate function
tridao Mar 5, 2025
d82bbf2
Fix causal scheduler not considering hdim_v != hdim
tridao Mar 5, 2025
9c036e4
Always split fwd_combine_kernel on batch
tridao Mar 7, 2025
81643fa
For each batch, if num_splits=1, write to O instead of O_partial
tridao Mar 8, 2025
1d30bb4
Enable TMA when page size is a multiple of kBlockN
tridao Mar 9, 2025
a3a9cc5
Update ptxas to 12.8.93 (i.e. 12.8.1)
tridao Mar 9, 2025
322bec9
Use tile size 192 x 128 for hdim 64 causal
tridao Mar 9, 2025
5639b9d
Update benchmark_mla_decode.py
tridao Mar 9, 2025
48b3acb
Benchmark MHA, GQA, MQA, MLA in the same script
tridao Mar 11, 2025
d904855
Benchmark FlashMLA if it's available
tridao Mar 11, 2025
cdaf2de
Run all 4 attn variants in benchmark
tridao Mar 12, 2025
cf1b809
Move scheduler.get_next_work to before the epilogue
tridao Mar 12, 2025
3cf8998
Enable Cluster for hdim128 back
tridao Mar 12, 2025
6063dc5
Move tOrO init in mainloop
tridao Mar 12, 2025
430954a
Adjust heuristic for get_pagedkv_tma
tridao Mar 12, 2025
000090d
Enable PDL
tridao Mar 13, 2025
46e1d4a
Simplify prepare_varlen_num_blocks_kernel, restrict to batch <= 992
tridao Mar 13, 2025
897c845
Fix: num_splits_dynamic_ptr needs to be set before get_num_splits
tridao Mar 14, 2025
90f27a2
Loop on num_splits instead of parameterizing it in kvcache test
tridao Mar 15, 2025
fa60e7c
Add option to precompute scheduler metadata
tridao Mar 15, 2025
6c87fac
Update MLA decode benchmark to use get_scheduler_metadata
tridao Mar 15, 2025
4b5eeab
Fix FP8 test to quantize KV cache for reference impl as well
tridao Mar 15, 2025
27f501d
Dynamic autotune configs for devices with warp size != 32 (#1534)
schung-amd Mar 15, 2025
1012435
Merge remote-tracking branch 'upstream/main' into lwilkinson/upstream…
LucasWilkinson Mar 20, 2025
e0faa9a
update binding
LucasWilkinson Mar 20, 2025
7 changes: 5 additions & 2 deletions CMakeLists.txt
@@ -1,6 +1,6 @@
 cmake_minimum_required(VERSION 3.26)

-project(vllm_flash_attn LANGUAGES CXX)
+project(vllm_flash_attn LANGUAGES CXX CUDA)
 set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_EXTENSIONS OFF)
@@ -213,7 +213,9 @@ if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)
     SRCS "${FA3_GEN_SRCS}"
     CUDA_ARCHS "${FA3_ARCHS}")
   set_gencode_flags_for_srcs(
-    SRCS "hopper/flash_fwd_combine.cu"
+    SRCS
+      hopper/flash_fwd_combine.cu
+      hopper/flash_prepare_scheduler.cu
     CUDA_ARCHS "${FA3_ARCHS}")
 endif()

@@ -223,6 +225,7 @@ if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)
     LANGUAGE ${VLLM_GPU_LANG}
     SOURCES
       hopper/flash_fwd_combine.cu
+      hopper/flash_prepare_scheduler.cu
       hopper/flash_api.cpp
       hopper/flash_api_torch_lib.cpp
       ${FA3_GEN_SRCS}
2 changes: 1 addition & 1 deletion csrc/flash_attn/src/flash_bwd_kernel.h
@@ -118,7 +118,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
         + (m_block_max - 1) * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
     const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
-        + ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded
+        + ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128ll * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded
         // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.
         + (!params.deterministic ? 0 : blockIdx.x * params.dq_accum_split_stride);
     const index_t row_offset_lse = (params.unpadded_lse? bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb): (bidb * params.h + bidh) * params.seqlen_q) + (m_block_max - 1) * kBlockM;
4 changes: 2 additions & 2 deletions csrc/flash_attn/src/flash_bwd_preprocess_kernel.h
@@ -79,7 +79,7 @@ inline __device__ void compute_dot_do_o(const Params &params) {
     const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
         + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
     const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
-        + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
+        + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128ll * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
     // Regarding 128 * params.b see a comment in mha_varlen_bwd about padding of dq_accum and softmax_d
     const index_t row_offset_dpsum = (params.unpadded_lse ? (bidh * (params.total_q + 128 * params.b) + binfo.q_offset(params.seqlen_q_rounded, 1, bidb) + 128 * bidb): (bidb * params.h + bidh) * params.seqlen_q_rounded) + m_block * kBlockM;

@@ -205,7 +205,7 @@ inline __device__ void convert_dQ(const Params &params, const int nsplits) {
     const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
         + m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
     const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
-        + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
+        + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128ll * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;

     Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
                              Shape<Int<kBlockM>, Int<kHeadDim>>{},
8 changes: 4 additions & 4 deletions csrc/flash_attn/src/flash_fwd_kernel.h
@@ -362,7 +362,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi

     // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
     // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
-    Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+    Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs<typename Kernel_traits::TiledMma>(rP.layout()));
     // if (cute::thread0()) { print(tOrP); }
     FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
     // if (cute::thread0()) { print(scores); }
@@ -424,7 +424,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi

     // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
     // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
-    Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+    Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs<typename Kernel_traits::TiledMma>(rP.layout()));
     FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
 }

@@ -942,7 +942,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     Tensor rP = FLASH_NAMESPACE::convert_type<Element>(acc_s);
     // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
     // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
-    Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+    Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs<typename Kernel_traits::TiledMma>(rP.layout()));

     FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);

@@ -1002,7 +1002,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     Tensor rP = FLASH_NAMESPACE::convert_type<Element>(acc_s);
     // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
     // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
-    Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+    Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs<typename Kernel_traits::TiledMma>(rP.layout()));

     FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
 }
31 changes: 15 additions & 16 deletions flash_attn/ops/triton/layer_norm.py
@@ -15,6 +15,19 @@
 import triton
 import triton.language as tl

+def triton_autotune_configs():
+    # Return configs with a valid warp count for the current device
+    configs=[]
+    # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024
+    max_threads_per_block=1024
+    # Default to warp size 32 if not defined by device
+    warp_size=getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32)
+    # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit
+    warp_count=1
+    while warp_count*warp_size <= max_threads_per_block:
+        configs.append(triton.Config({}, num_warps=warp_count))
+        warp_count*=2
+    return configs

def layer_norm_ref(
x,
@@ -126,14 +139,7 @@ def rms_norm_ref(


 @triton.autotune(
-    configs=[
-        triton.Config({}, num_warps=1),
-        triton.Config({}, num_warps=2),
-        triton.Config({}, num_warps=4),
-        triton.Config({}, num_warps=8),
-        triton.Config({}, num_warps=16),
-        triton.Config({}, num_warps=32),
-    ],
+    configs=triton_autotune_configs(),
     key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
 )
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
@@ -393,14 +399,7 @@ def _layer_norm_fwd(


 @triton.autotune(
-    configs=[
-        triton.Config({}, num_warps=1),
-        triton.Config({}, num_warps=2),
-        triton.Config({}, num_warps=4),
-        triton.Config({}, num_warps=8),
-        triton.Config({}, num_warps=16),
-        triton.Config({}, num_warps=32),
-    ],
+    configs=triton_autotune_configs(),
     key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
 )
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
28 changes: 18 additions & 10 deletions hopper/benchmark_attn.py
@@ -1,6 +1,7 @@
 from collections import namedtuple
 from functools import partial
 import math
+import os
 from typing import NamedTuple
 import torch
 import torch.nn as nn
@@ -34,6 +35,8 @@
     triton_attention = None
     triton_attention = None

+DISABLE_BACKWARD = os.getenv("FLASH_ATTENTION_DISABLE_BACKWARD", "FALSE") == "TRUE"
+

def time_fwd(func, *args, repeats=30, verbose=True, desc="", **kwargs):
# # Warmup
@@ -53,7 +56,7 @@ def time_fwd(func, *args, repeats=30, verbose=True, desc="", **kwargs):
# time_f = benchmark_forward(lambda: graph.replay(), repeats=repeats, verbose=verbose, desc=desc)
# # return time_f[1].mean
# return time_f[1]
-    return Timing(do_bench(lambda: func(*args, **kwargs), warmup=5, rep=repeats) * 1e-3)
+    return Timing(do_bench(lambda: func(*args, **kwargs), warmup=3, rep=repeats) * 1e-3)


def flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(-1, -1)):
Expand Down Expand Up @@ -250,21 +253,24 @@ def run(*args, **kwargs):
# for headdim in [64, 96, 128, 192, 256]:
for headdim in [128]:
nheads = dim // headdim
# nheads = 128
# headdim = 64
# batch_size = 64
# seqlen = 512
# nheads = 8
# headdim = 128
nheads_kv = nheads
# nheads_kv = nheads // 4
# nheads_kv = 1
headdim_v = headdim
# headdim_v = 128
# headdim_v = 512
has_qv = headdim == 64 and headdim_v == 512
# has_qv = False

for batch_size, seqlen in bs_seqlen_vals:
num_splits = 0
window_size = (-1, -1)
# window_size = (seqlen // 2 - 1, 0)
sink_token_length = 0
pack_gqa = None
# seqlen_q = 64
seqlen_q = seqlen
@@ -276,6 +282,7 @@
q, k, v = [x.detach().to(dtype).requires_grad_() for x in [q, k, v]]
v_colmajor = v.detach().transpose(-1, -3).contiguous().transpose(-1, -3).requires_grad_()
v_fa3 = v if not V_colmajor else v_colmajor
+qv = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen) if has_qv else None
# q = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype)
# k = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype)
# v = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim_v), device=device, dtype=torch.int32).to(dtype)
@@ -303,7 +310,7 @@ def run(*args, **kwargs):
for causal in [False, True]:
# for causal in [True]:
print(f"\n### {headdim = }, {causal = }, {seqlen = } ###")
-nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim, headdim_v, causal=causal, window_size=window_size)
+nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim if not has_qv else headdim + headdim_v, headdim_v, causal=causal, window_size=window_size)
if cudnn is not None:
# if False:
if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v:
@@ -351,17 +358,17 @@ def run(*args, **kwargs):

time.sleep(1)
if not varlen:
-# m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, cache_leftpad = leftpad_k, page_table=page_table, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
-m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
+# m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, cache_leftpad = leftpad_k, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
+m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, qv=qv, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
# pytorch_profiler(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa)
else:
m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
# pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits)
time_f[(causal, headdim, batch_size, seqlen), "Flash3"] = m1.mean
-if dtype != torch.float8_e4m3fn and headdim == headdim_v:
+if dtype != torch.float8_e4m3fn and headdim == headdim_v and not DISABLE_BACKWARD:
time.sleep(1)
if not varlen:
-_, m1b = benchmark_backward(flash_attn_func_v3, q, k, v, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, deterministic=deterministic,
+_, m1b = benchmark_backward(flash_attn_func_v3, q, k, v, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic,
repeats=repeats, verbose=False, desc='Fav3')
else:
_, m1b = benchmark_backward(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic,
@@ -387,7 +394,7 @@ def run(*args, **kwargs):
print(f'CuDNN fwd: {m2.mean * 1e3:.3f}ms, {(nFLOPS / m2.mean * 1e-12):.1f} TFLOPS')
print(f'CuDNN bwd: {m2b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m2b.mean * 1e-12):.1f} TFLOPS')
print(f'Fav3 fwd: {m1.mean * 1e3:.3f}ms, {(nFLOPS / m1.mean * 1e-12):.1f} TFLOPS')
-if dtype != torch.float8_e4m3fn and headdim == headdim_v:
+if dtype != torch.float8_e4m3fn and headdim == headdim_v and not DISABLE_BACKWARD:
print(f'Fav3 bwd: {m1b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m1b.mean * 1e-12):.1f} TFLOPS')
# benchmark_forward(torch.square, k)
# print(f'cuBLAS: {m5.mean * 1e3:.3f}ms, {(nFLOPS_matmul / m5.mean * 1e-12):.1f} TFLOPS')
@@ -397,7 +404,8 @@ def run(*args, **kwargs):
# import pickle
# # with open(f'flash3_attn_time_h100_hdim{headdim}_causal.plk', 'wb') as fp:
# # with open(f'flash3_attn_time_h100_cudnn_triton_20241208.plk', 'wb') as fp:
-# with open(f'flash3_attn_time_h100_fa3_20241208.plk', 'wb') as fp:
+# with open(f'flash3_attn_time_h100_fa3_20250313.plk', 'wb') as fp:
+# # with open(f'flash3_attn_time_h100_fa3_fp8_20250313.plk', 'wb') as fp:
# # with open(f'flash3_attn_time_h100_fp8_hdim{headdim}.plk', 'wb') as fp:
# # with open(f'flash3_attn_time_h100_hdim{headdim}_1031.plk', 'wb') as fp:
# pickle.dump((time_f, time_b), fp, protocol=pickle.HIGHEST_PROTOCOL)