77 changes: 46 additions & 31 deletions benchmarks/benchmark_flash_attention.py
@@ -54,7 +54,7 @@ def attention_pytorch(qkv, dropout_p=0.0, causal=True):
# "triu_tril_cuda_template" not implemented for 'BFloat16'
# So we have to construct the mask in float
causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
# Adding is faster than masked_fill_
scores = scores + causal_mask.to(dtype=scores.dtype)
attention = torch.softmax(scores, dim=-1)
attention_drop = F.dropout(attention, dropout_p)
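Note on the masking trick in this hunk: `torch.triu` hits the unimplemented `triu_tril_cuda_template` kernel for BFloat16, so the causal mask is built in float32 and added to the scores instead of using `masked_fill_`, which the comment reports as slower. A minimal standalone sketch of the two formulations (shapes and dtype are assumptions, not taken from the benchmark):

```python
import torch

seqlen = 512
scores = torch.randn(2, 8, seqlen, seqlen, device="cuda", dtype=torch.bfloat16)

# Additive mask, built in float32 to sidestep the bf16 triu limitation.
causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
out_add = torch.softmax(scores + causal_mask.to(dtype=scores.dtype), dim=-1)

# masked_fill variant; it yields effectively the same attention pattern.
bool_mask = torch.triu(torch.ones(seqlen, seqlen, device=scores.device), 1).bool()
out_fill = torch.softmax(scores.masked_fill(bool_mask, -10000.0), dim=-1)
```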
@@ -88,53 +88,65 @@ def time_fwd_bwd(func, *args, **kwargs):
speed_f = {}
speed_b = {}
speed_f_b = {}

for causal in causal_vals:
for headdim in headdim_vals:
for batch_size, seqlen in bs_seqlen_vals:
config = (causal, headdim, batch_size, seqlen)
nheads = dim // headdim
qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
requires_grad=True)
f, b = time_fwd_bwd(
flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
)
time_f[config, "Flash2"] = f
time_b[config, "Flash2"] = b

try:
qkv = qkv.detach().requires_grad_(True)

# FlashAttention 2
if "Flash2" in methods:
qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim,
device=device, dtype=dtype, requires_grad=True)
f, b = time_fwd_bwd(
attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal,
repeats=repeats, verbose=False
)
except: # Skip if OOM
f, b = float('nan'), float('nan')
time_f[config, "Pytorch"] = f
time_b[config, "Pytorch"] = b

if attention_triton is not None:
q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
requires_grad=True) for _ in range(3)]
# Try both values of sequence_parallel and pick the faster one
time_f[config, "Flash2"] = f
time_b[config, "Flash2"] = b

# PyTorch baseline
if "Pytorch" in methods:
try:
# fresh tensor avoids grad-history reuse issues
qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim,
device=device, dtype=dtype, requires_grad=True)
f, b = time_fwd_bwd(
attention_pytorch, qkv, dropout_p, causal=causal,
repeats=repeats, verbose=False
)
except Exception:
f, b = float('nan'), float('nan')
time_f[config, "Pytorch"] = f
time_b[config, "Pytorch"] = b

# Triton
if "Triton" in methods and attention_triton is not None:
q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim,
device=device, dtype=dtype, requires_grad=True) for _ in range(3)]
# Try both values of sequence_parallel and pick the faster backward
try:
f, b = time_fwd_bwd(
attention_triton, q, k, v, causal, headdim**(-0.5),
False, repeats=repeats, verbose=False
)
except:
except Exception:
f, b = float('nan'), float('inf')
try:
_, b0 = time_fwd_bwd(
attention_triton, q, k, v, causal, headdim**(-0.5),
True, repeats=repeats, verbose=False
)
except:
except Exception:
b0 = float('inf')
time_f[config, "Triton"] = f
time_b[config, "Triton"] = min(b, b0) if min(b, b0) < float('inf') else float('nan')

if xops is not None:
q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
requires_grad=True) for _ in range(3)]
# xFormers CUTLASS
if "xformers.c" in methods and xops is not None:
q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim,
device=device, dtype=dtype, requires_grad=True) for _ in range(3)]
f, b = time_fwd_bwd(
xops.memory_efficient_attention, q, k, v,
attn_bias=xops.LowerTriangularMask() if causal else None,
@@ -143,9 +155,10 @@ def time_fwd_bwd(func, *args, **kwargs):
time_f[config, "xformers.c"] = f
time_b[config, "xformers.c"] = b

if xops is not None:
q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
requires_grad=True) for _ in range(3)]
# xFormers Flash
if "xformers.f" in methods and xops is not None:
q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim,
device=device, dtype=dtype, requires_grad=True) for _ in range(3)]
f, b = time_fwd_bwd(
xops.memory_efficient_attention, q, k, v,
attn_bias=xops.LowerTriangularMask() if causal else None,
@@ -154,8 +167,11 @@ def time_fwd_bwd(func, *args, **kwargs):
time_f[config, "xformers.f"] = f
time_b[config, "xformers.f"] = b

# Report
print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
for method in methods:
if (config, method) not in time_f or (config, method) not in time_b:
continue
time_f_b[config, method] = time_f[config, method] + time_b[config, method]
speed_f[config, method] = efficiency(
flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
@@ -175,6 +191,5 @@ def time_fwd_bwd(func, *args, **kwargs):
f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s"
)


# with open('flash2_attn_time.plk', 'wb') as fp:
# pickle.dump((speed_f, speed_b, speed_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL)
# pickle.dump((speed_f, speed_b, speed_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL)
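For reference, the reported numbers come from `efficiency(flops(...), time)`, i.e. the attention FLOP count for the config divided by the measured time, scaled to TFLOPs/s. Those helpers are defined earlier in the file and are not part of this diff; the sketch below reconstructs their likely shape from the call sites (an assumption, using the standard 4·b·s²·h·d accounting, halved for causal, with backward taken as 2.5× forward):

```python
import math

def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
    # Matmuls QK^T and PV: 4 * b * s^2 * h * d, halved under a causal mask.
    f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)

def efficiency(flop, time_s):
    # TFLOPs/s; configs that were skipped (NaN time) report as 0.
    return flop / time_s / 1e12 if not math.isnan(time_s) else 0.0
```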