Merged
33 commits
6b1d059
Support ROCM builds from source distribution, and improve error handl…
mgorny Jan 18, 2025
cd393e0
[Build] Update version of setuptools used to generate core package (#…
tmm1 Jan 29, 2025
bb135af
Don't compile for CUDA 11, compile for official pytorch 2.6.0
tridao Jan 29, 2025
979702c
Bump to v2.7.4
tridao Jan 29, 2025
5231d95
Drop Pytorch 2.1
tridao Jan 29, 2025
454ce31
[FA3] Compile with nvcc 12.8 instead of 12.3
tridao Jan 29, 2025
803f609
Fix comment in assert
tridao Jan 30, 2025
02541ac
[CE] Assert logit_scale > 0
tridao Jan 30, 2025
2a20412
Implement HeadDim_V != HeadDim_QK, support hdimQK=192, hdimV=128
tridao Feb 3, 2025
6d199aa
Fix shape_O in epilogue params when kHeadDimV != kHeadDim
tridao Feb 4, 2025
86bcd05
Remove old combine.h
tridao Feb 4, 2025
e3b2400
Fix loading paged V when kHeadDimV != kHeadDim
tridao Feb 4, 2025
9e07d6d
Fix shape_V for storing new KV when kHeadDimV != kHeadDim
tridao Feb 4, 2025
f0f2523
Implement the case of LargeHeadDimV
tridao Feb 4, 2025
4c8819d
Rename Mma0->MmaQK, Mma1->MmaPV, use Cluster only if hdimV >= 192
tridao Feb 7, 2025
dd87691
Pass _1 or _0 to cute::aligned_struct
tridao Feb 8, 2025
ed53b5f
Fix compilation for FP8 when kHeadDimV != kHeadDim
tridao Feb 8, 2025
4e8496a
Support Qv
tridao Feb 8, 2025
893a22a
Test varlen_q=True by default for kvcache
tridao Feb 8, 2025
5fab938
Fix num_splits heuristic being called before get_pack_gqa
tridao Feb 8, 2025
5fc5ebf
Fix num_splits heuristic again when PackGQA
tridao Feb 8, 2025
5378bc3
Tile fwd_combine kernel along headdim, don't need kBlockM > 128
tridao Feb 8, 2025
db8ca79
Use bf16 instead of fp16 in benchmark_gemm.py
tridao Feb 9, 2025
982c480
Update Cutlass to 3.7
tridao Feb 9, 2025
58ebfa5
Use nvcc 12.6 but ptxas 12.8
tridao Feb 9, 2025
ed435c6
cicc uses the same version as ptxas
tridao Feb 9, 2025
8668823
Split hdimdiff into a separate translation unit
tridao Feb 9, 2025
b2fc79d
Update benchmark script
tridao Feb 9, 2025
c091545
Update Cutlass to 3.8
tridao Feb 9, 2025
5e39b10
Adjust tile size for hdim 64
tridao Feb 9, 2025
1a7f4df
Adjust ninja build file
tridao Feb 10, 2025
f12ed13
Merge remote-tracking branch 'upstream/main' into lwilkinson/upstream…
LucasWilkinson Feb 11, 2025
21fcade
build head diff + fix build errors
LucasWilkinson Feb 11, 2025
21 changes: 9 additions & 12 deletions .github/workflows/publish.yml
@@ -44,21 +44,16 @@ jobs:
# manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
os: [ubuntu-20.04]
python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
- torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0.dev20241001']
- cuda-version: ['11.8.0', '12.3.2']
+ torch-version: ['2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0']
+ cuda-version: ['12.4.1']
# We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
# Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
# Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
# when building without C++11 ABI and using it on nvcr images.
cxx11_abi: ['FALSE', 'TRUE']
exclude:
# see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
- # Pytorch < 2.2 does not support Python 3.12
- - torch-version: '2.1.2'
-   python-version: '3.12'
  # Pytorch < 2.5 does not support Python 3.13
- - torch-version: '2.1.2'
-   python-version: '3.13'
  - torch-version: '2.2.2'
    python-version: '3.13'
  - torch-version: '2.3.1'
@@ -113,7 +108,7 @@ jobs:
run: |
pip install --upgrade pip
# For some reason torch 2.2.0 on python 3.12 errors saying no setuptools
- pip install setuptools==68.0.0
+ pip install setuptools==75.8.0
# With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error
# AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable
pip install typing-extensions==4.12.2
@@ -122,8 +117,8 @@ jobs:
# see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
# This code is ugly, maybe there's a better way to do this.
export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
- minv = {'2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118}[env['MATRIX_TORCH_VERSION']]; \
- maxv = {'2.1': 121, '2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124, '2.6': 124}[env['MATRIX_TORCH_VERSION']]; \
+ minv = {'2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118}[env['MATRIX_TORCH_VERSION']]; \
+ maxv = {'2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124, '2.6': 124}[env['MATRIX_TORCH_VERSION']]; \
print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \
)
if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
@@ -149,7 +144,7 @@ jobs:
# We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6
# https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810
# However this still fails so I'm using a newer version of setuptools
- pip install setuptools==68.0.0
+ pip install setuptools==75.8.0
pip install ninja packaging wheel
export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
@@ -203,7 +198,9 @@ jobs:

- name: Install dependencies
run: |
- pip install ninja packaging setuptools wheel twine
+ pip install ninja packaging wheel twine
+ # Install latest setuptools with support for pypi metadata 2.2 (improved compat w/ uv)
+ pip install setuptools==75.8.0
# We don't want to download anything CUDA-related here
pip install torch --index-url https://download.pytorch.org/whl/cpu

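Note on the `TORCH_CUDA_VERSION` step in this workflow: the inline `python -c` one-liner looks up, per torch minor version, the lowest supported CUDA build when the matrix CUDA version is below 12.0 and the highest one otherwise. A minimal sketch of the same lookup as a standalone function follows. It assumes `MATRIX_TORCH_VERSION` is a `major.minor` string and `MATRIX_CUDA_VERSION` is a compact integer such as `124`; the function and constant names are hypothetical and not part of the workflow.

```python
# Min/max CUDA builds per torch minor version, copied from the updated workflow.
MIN_CUDA = {"2.2": 118, "2.3": 118, "2.4": 118, "2.5": 118, "2.6": 118}
MAX_CUDA = {"2.2": 121, "2.3": 121, "2.4": 124, "2.5": 124, "2.6": 124}


def pick_torch_cuda_version(torch_minor: str, matrix_cuda: int) -> int:
    """Return the CUDA build of the torch wheel to install (hypothetical helper)."""
    if torch_minor not in MIN_CUDA:
        # torch 2.1 was dropped from the tables in this PR, so it is rejected here.
        raise KeyError(f"unsupported torch version: {torch_minor}")
    # CUDA 11.x runners get the lowest supported build, CUDA 12.x the highest.
    return MIN_CUDA[torch_minor] if matrix_cuda < 120 else MAX_CUDA[torch_minor]


if __name__ == "__main__":
    for torch_minor in MIN_CUDA:
        print(torch_minor, pick_torch_cuda_version(torch_minor, 124))
```

If the only remaining matrix entry `cuda-version: '12.4.1'` is mapped to `124` (an assumption about the hidden part of the workflow), only the maximum table is consulted in practice.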
6 changes: 6 additions & 0 deletions CMakeLists.txt
@@ -176,12 +176,18 @@ if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)
# BF16 source files
file(GLOB FA3_BF16_GEN_SRCS
"hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
+ file(GLOB FA3_BF16_GEN_SRCS_
+ "hopper/instantiations/flash_fwd_hdimdiff_bf16*_sm90.cu")
+ list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})
file(GLOB FA3_BF16_GEN_SRCS_
"hopper/instantiations/flash_fwd_*_bf16_*_sm80.cu")
list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})
# FP16 source files
file(GLOB FA3_FP16_GEN_SRCS
"hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu")
+ file(GLOB FA3_FP16_GEN_SRCS_
+ "hopper/instantiations/flash_fwd_hdimdiff_fp16*_sm90.cu")
+ list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})
file(GLOB FA3_FP16_GEN_SRCS_
"hopper/instantiations/flash_fwd_*_fp16_*_sm80.cu")
list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})
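The CMake change adds a second glob so that the `hdimdiff` instantiations are compiled alongside the `hdimall` ones. As a rough illustration of which files each pattern picks up, here is a small Python sketch using shell-style matching. The candidate filenames are hypothetical (the real generated names are not shown in this diff), and `collect_sources` is an illustrative helper, not part of the build.

```python
from fnmatch import fnmatchcase

# The three BF16 glob patterns from CMakeLists.txt (directory prefix omitted).
BF16_PATTERNS = [
    "flash_fwd_hdimall_bf16*_sm90.cu",
    "flash_fwd_hdimdiff_bf16*_sm90.cu",  # pattern added in this PR
    "flash_fwd_*_bf16_*_sm80.cu",
]


def collect_sources(filenames, patterns):
    """Return the sorted filenames that match at least one glob pattern."""
    return sorted(
        name for name in filenames if any(fnmatchcase(name, p) for p in patterns)
    )


# Hypothetical filenames, chosen only to exercise each pattern.
candidates = [
    "flash_fwd_hdimall_bf16_sm90.cu",
    "flash_fwd_hdimdiff_bf16_paged_sm90.cu",
    "flash_fwd_hdim128_bf16_split_sm80.cu",
    "flash_fwd_hdimall_fp16_sm90.cu",
]

bf16_sources = collect_sources(candidates, BF16_PATTERNS)
print(bf16_sources)
```

Without the new `hdimdiff` pattern, the second candidate would match neither of the other two globs, which is consistent with the follow-up commit "build head diff + fix build errors".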
1 change: 1 addition & 0 deletions MANIFEST.in
@@ -3,6 +3,7 @@ recursive-include csrc *.h
recursive-include csrc *.cuh
recursive-include csrc *.cpp
recursive-include csrc *.hpp
+ recursive-include csrc *.py

recursive-include vllm_flash_attn *.cu
recursive-include vllm_flash_attn *.h
6 changes: 3 additions & 3 deletions README.md
@@ -44,7 +44,7 @@ Currently released:

Requirements: H100 / H800 GPU, CUDA >= 12.3.

- For now, we highly recommend CUDA 12.3 for best performance.
+ We highly recommend CUDA 12.8 for best performance.

To install:
```sh
@@ -65,7 +65,7 @@ flash_attn_interface.flash_attn_func()
## Installation and features
**Requirements:**
- CUDA toolkit or ROCm toolkit
- - PyTorch 1.12 and above.
+ - PyTorch 2.2 and above.
- `packaging` Python package (`pip install packaging`)
- `ninja` Python package (`pip install ninja`) *
- Linux. Might work for Windows starting v2.3.2 (we've seen a few positive [reports](https://github.com/Dao-AILab/flash-attention/issues/595)) but Windows compilation still requires more testing. If you have ideas on how to set up prebuilt CUDA wheels for Windows, please reach out via Github issue.
@@ -98,7 +98,7 @@ MAX_JOBS=4 pip install flash-attn --no-build-isolation

### NVIDIA CUDA Support
**Requirements:**
- - CUDA 11.7 and above.
+ - CUDA 12.0 and above.

We recommend the
[Pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch)
2 changes: 1 addition & 1 deletion benchmarks/benchmark_gemm.py
@@ -26,7 +26,7 @@ def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, **kwinputs

torch.manual_seed(0)
repeats = 30
- dtype = torch.float16
+ dtype = torch.bfloat16
device = 'cuda'
verbose = False
m, n = 8192, 8192
2 changes: 1 addition & 1 deletion csrc/cutlass
Submodule cutlass updated 2220 files
2 changes: 1 addition & 1 deletion flash_attn/__init__.py
@@ -1,4 +1,4 @@
- __version__ = "2.7.3"
+ __version__ = "2.7.4.post1"

from flash_attn.flash_attn_interface import (
flash_attn_func,
1 change: 1 addition & 0 deletions flash_attn/ops/triton/cross_entropy.py
@@ -166,6 +166,7 @@ def forward(
if labels.dtype == torch.long and labels.data_ptr() % 16 != 0:
labels = F.pad(labels, (0, 1))[..., :-1]
assert labels.data_ptr() % 16 == 0
+ assert logit_scale > 0.0
n_rows, n_cols = logits.shape
assert labels.shape == (n_rows,)
world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
60 changes: 19 additions & 41 deletions hopper/benchmark_attn.py
@@ -56,7 +56,7 @@ def time_fwd(func, *args, repeats=30, verbose=True, desc="", **kwargs):
return Timing(do_bench(lambda: func(*args, **kwargs), warmup=5, rep=repeats) * 1e-3)


- def flops(batch, nheads, seqlen_q, seqlen_k, headdim, causal=False, window_size=(-1, -1)):
+ def flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(-1, -1)):
if causal:
avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2
else:
@@ -67,7 +67,7 @@ def flops(batch, nheads, seqlen_q, seqlen_k, headdim, causal=False, window_size=
col_left = torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0))
col_right = torch.minimum(row_idx + seqlen_k - seqlen_q - window_size[1], torch.tensor(seqlen_k - 1))
avg_seqlen = (col_right - col_left + 1).float().mean().item()
- return batch * nheads * 2 * seqlen_q * avg_seqlen * headdim * 2
+ return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v)


def convert_to_cudnn_type(torch_type):
@@ -242,21 +242,6 @@ def run(*args, **kwargs):
time_f = {}
time_b = {}

- # tflops_matmul = {}
- # m, n = 8192, 8192
- # for k in [512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4608, 5120, 5632, 6144, 6656, 7168, 7680, 8192]:
- # a = torch.randn(m, k, device=device, dtype=dtype)
- # b = torch.randn(n, k, device=device, dtype=dtype).transpose(-1, -2)
- # nFLOPS_matmul = 2 * m * n * k
- # m5 = time_fwd(torch.matmul, a, b, desc='cuBLAS')
- # print(f'cuBLAS: {m5.mean * 1e3:.3f}ms, {(nFLOPS_matmul / m5.mean * 1e-12):.1f} TFLOPS')
- # tflops_matmul[k] = nFLOPS_matmul / m5.mean * 1e-12
- # # import pickle
- # # # with open(f'flash3_attn_time_h100_hdim{headdim}_causal.plk', 'wb') as fp:
- # # with open(f'flash3_matmul_tflops_h100.plk', 'wb') as fp:
- # # pickle.dump(tflops_matmul, fp, protocol=pickle.HIGHEST_PROTOCOL)
- # exit(0)
-
# for headdim in [64, 128, 256]:
# for headdim in [64, 96, 128, 192]:
# for headdim in [64, 96, 128, 192, 256]:
@@ -272,9 +257,11 @@ def run(*args, **kwargs):
# headdim = 128
nheads_kv = nheads
# nheads_kv = nheads // 4
+ headdim_v = headdim
+ # headdim_v = 128

for batch_size, seqlen in bs_seqlen_vals:
- num_splits = 1
+ num_splits = 0
window_size = (-1, -1)
# window_size = (seqlen // 2 - 1, 0)
sink_token_length = 0
@@ -285,20 +272,16 @@ def run(*args, **kwargs):
# leftpad_k = torch.full((batch_size,), 0, device=device, dtype=torch.int32)
q = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True)
k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=True)
- v = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=True)
+ v = torch.randn(batch_size, seqlen, nheads_kv, headdim_v, device=device, dtype=dtype_gen, requires_grad=True)
q, k, v = [x.detach().to(dtype).requires_grad_() for x in [q, k, v]]
v_colmajor = v.detach().transpose(-1, -3).contiguous().transpose(-1, -3).requires_grad_()
v_fa3 = v if not V_colmajor else v_colmajor
# q = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype)
# k = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype)
- # v = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype)
- g = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True)
- o = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True)
+ # v = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim_v), device=device, dtype=torch.int32).to(dtype)
+ g = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen, requires_grad=True)
+ o = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen, requires_grad=True)
stats = torch.randn(batch_size, seqlen_q, nheads, 1, device=device, dtype=torch.float32)
- a = torch.randn(batch_size, seqlen, seqlen, device=device, dtype=dtype_gen)
- b = torch.randn(batch_size, dim * 2, seqlen, device=device, dtype=dtype_gen).transpose(-1, -2)
- # x = torch.randn(batch_size * seqlen, 4096, device=device, dtype=dtype)
- # w = torch.randn(4096 * 2, 4096, device=device, dtype=dtype).transpose(-1, -2)
if varlen:
q_unpad, k_unpad, v_unpad = [rearrange(x.detach(), "b s h d -> (b s) h d").requires_grad_() for x in [q, k, v]]
cu_seqlens_q = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen_q
@@ -318,16 +301,16 @@ def run(*args, **kwargs):
page_table = None

for causal in [False, True]:
- # for causal in [False]:
+ # for causal in [True]:
print(f"\n### {headdim = }, {causal = }, {seqlen = } ###")
- nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim, causal=causal, window_size=window_size)
+ nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim, headdim_v, causal=causal, window_size=window_size)
if cudnn is not None:
# if False:
- if headdim <= 256 and dtype != torch.float8_e4m3fn:
+ if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v:
cudnn_spda = cudnn_spda_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=causal, window_size_left=window_size[0])
cudnn_spda_bwd = cudnn_spda_bwd_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), o.transpose(1, 2), g.transpose(1, 2), stats.transpose(1, 2), causal=causal, window_size_left=window_size[0])
# _, m0 = benchmark_forward(flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=verbose, desc='Fav2')
- if dtype != torch.float8_e4m3fn:
+ if dtype != torch.float8_e4m3fn and headdim == headdim_v:
# if False:
if not varlen:
m0 = time_fwd(flash_attn_func, q, k, v, dropout_p, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav2')
@@ -343,7 +326,7 @@ def run(*args, **kwargs):
repeats=repeats, verbose=False, desc='Fav2')
time_b[(causal, headdim, batch_size, seqlen), "Flash2"] = m0b.mean
# pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=True)
- if headdim <= 256 and dtype != torch.float8_e4m3fn:
+ if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v:
if triton_attention is not None:
qt, kt, vt = [x.detach().transpose(1, 2).contiguous().requires_grad_() for x in [q, k, v]]
time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
@@ -356,7 +339,7 @@ def run(*args, **kwargs):
# # pytorch_profiler(triton_attention, q.transpose(1, 2).contiguous(), k.transpose(1, 2).contiguous(), v.transpose(1, 2).contiguous(), causal, 1 / math.sqrt(headdim), backward=True)
if cudnn is not None:
# if False:
- if headdim <= 256 and dtype != torch.float8_e4m3fn:
+ if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v:
time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
m2 = time_fwd(cudnn_spda, repeats=repeats, verbose=verbose, desc='CuDNN')
time_f[(causal, headdim, batch_size, seqlen), "cuDNN"] = m2.mean
@@ -375,12 +358,7 @@ def run(*args, **kwargs):
m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, None, None, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
# pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits)
time_f[(causal, headdim, batch_size, seqlen), "Flash3"] = m1.mean
- # time.sleep(1)
- # m5 = time_fwd(torch.bmm, a, b, desc='cuBLAS', repeats=repeats, verbose=False)
- # nFLOPS_matmul = nFLOPS
- # nFLOPS_matmul = 2 * x.shape[0] * x.shape[1] * w.shape[1]
- # m5 = time_fwd(torch.matmul, x, w, desc='cuBLAS')
- if dtype != torch.float8_e4m3fn:
+ if dtype != torch.float8_e4m3fn and headdim == headdim_v:
time.sleep(1)
if not varlen:
_, m1b = benchmark_backward(flash_attn_func_v3, q, k, v, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, deterministic=deterministic,
@@ -396,11 +374,11 @@ def run(*args, **kwargs):
# pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, deterministic=deterministic, backward=True)
# benchmark_forward(torch.clone, k, repeats=repeats, verbose=verbose, desc='Memcpy')

- if dtype != torch.float8_e4m3fn:
+ if dtype != torch.float8_e4m3fn and headdim == headdim_v:
# if False:
print(f'Fav2 fwd: {m0.mean * 1e3:.3f}ms, {(nFLOPS / m0.mean * 1e-12):.1f} TFLOPS')
print(f'Fav2 bwd: {m0b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m0b.mean * 1e-12):.1f} TFLOPS')
- if headdim <= 256 and dtype != torch.float8_e4m3fn:
+ if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v:
if triton_attention is not None:
print(f'Triton fwd: {m3.mean * 1e3:.3f}ms, {(nFLOPS / m3.mean * 1e-12):.1f} TFLOPS')
# if causal:
@@ -409,7 +387,7 @@ def run(*args, **kwargs):
print(f'CuDNN fwd: {m2.mean * 1e3:.3f}ms, {(nFLOPS / m2.mean * 1e-12):.1f} TFLOPS')
print(f'CuDNN bwd: {m2b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m2b.mean * 1e-12):.1f} TFLOPS')
print(f'Fav3 fwd: {m1.mean * 1e3:.3f}ms, {(nFLOPS / m1.mean * 1e-12):.1f} TFLOPS')
- if dtype != torch.float8_e4m3fn:
+ if dtype != torch.float8_e4m3fn and headdim == headdim_v:
print(f'Fav3 bwd: {m1b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m1b.mean * 1e-12):.1f} TFLOPS')
# benchmark_forward(torch.square, k)
# print(f'cuBLAS: {m5.mean * 1e3:.3f}ms, {(nFLOPS_matmul / m5.mean * 1e-12):.1f} TFLOPS')
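Note on the updated `flops` helper: with separate head dimensions, the `Q @ K^T` matmul costs `2 * seqlen_q * avg_seqlen * headdim` FLOPs per head and the `P @ V` matmul costs `2 * seqlen_q * avg_seqlen * headdim_v`, which is where the `(headdim + headdim_v)` factor comes from. The sketch below reproduces that arithmetic in plain Python. It is a simplified stand-in (the name `attention_fwd_flops` is hypothetical), it omits the sliding-window branch, and it assumes the non-causal case uses `avg_seqlen = seqlen_k`, which is not visible in this diff.

```python
def attention_fwd_flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v,
                        causal=False):
    """Forward attention FLOPs when Q/K and V may use different head dims."""
    if causal:
        # Same averaging as the benchmark: each query row attends to a prefix of keys.
        avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2
    else:
        # Assumption: without causal masking or windowing, every key is attended to.
        avg_seqlen = seqlen_k
    # 2 FLOPs per multiply-add; QK^T contributes headdim, PV contributes headdim_v.
    return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v)


if __name__ == "__main__":
    # Equal head dims reduce to the old formula (... * headdim * 2).
    print(attention_fwd_flops(1, 1, 4, 4, 128, 128))
    # The hdimQK=192, hdimV=128 case added in this PR.
    print(attention_fwd_flops(1, 1, 4, 4, 192, 128))
```

When `headdim == headdim_v` the result matches the previous `headdim * 2` formula, so existing benchmark numbers are unaffected.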