Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions hopper/benchmark_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def run(*args, **kwargs):
m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
# pytorch_profiler(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa)
else:
m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, None, None, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
# pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits)
time_f[(causal, headdim, batch_size, seqlen), "Flash3"] = m1.mean
if dtype != torch.float8_e4m3fn and headdim == headdim_v:
Expand All @@ -364,7 +364,7 @@ def run(*args, **kwargs):
_, m1b = benchmark_backward(flash_attn_func_v3, q, k, v, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, deterministic=deterministic,
repeats=repeats, verbose=False, desc='Fav3')
else:
_, m1b = benchmark_backward(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, None, None, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic,
_, m1b = benchmark_backward(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic,
repeats=repeats, verbose=False, desc='Fav3')
time_b[(causal, headdim, batch_size, seqlen), "Flash3"] = m1b.mean
# time.sleep(1)
Expand Down
4 changes: 2 additions & 2 deletions hopper/flash_attn_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,10 +563,10 @@ def flash_attn_varlen_func(
v,
cu_seqlens_q,
cu_seqlens_k,
seqused_q,
seqused_k,
max_seqlen_q,
max_seqlen_k,
seqused_q=None,
seqused_k=None,
softmax_scale=None,
causal=False,
qv=None,
Expand Down
3 changes: 2 additions & 1 deletion hopper/test_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,9 +450,10 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
seqused_q, seqused_k,
max_seqlen_q,
max_seqlen_k,
seqused_q=seqused_q,
seqused_k=seqused_k,
causal=causal,
qv=qv_unpad,
q_descale=q_descale,
Expand Down