diff --git a/hopper/benchmark_attn.py b/hopper/benchmark_attn.py
index 5d1f5369214..36f0bf6d036 100644
--- a/hopper/benchmark_attn.py
+++ b/hopper/benchmark_attn.py
@@ -355,7 +355,7 @@ def run(*args, **kwargs):
             m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
             # pytorch_profiler(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa)
         else:
-            m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, None, None, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
+            m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
             # pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits)
         time_f[(causal, headdim, batch_size, seqlen), "Flash3"] = m1.mean
         if dtype != torch.float8_e4m3fn and headdim == headdim_v:
@@ -364,7 +364,7 @@ def run(*args, **kwargs):
                 _, m1b = benchmark_backward(flash_attn_func_v3, q, k, v, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, deterministic=deterministic,
                                             repeats=repeats, verbose=False, desc='Fav3')
             else:
-                _, m1b = benchmark_backward(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, None, None, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic,
+                _, m1b = benchmark_backward(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic,
                                             repeats=repeats, verbose=False, desc='Fav3')
             time_b[(causal, headdim, batch_size, seqlen), "Flash3"] = m1b.mean
         # time.sleep(1)
diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py
index adee1a0ff26..78cfe1cb906 100644
--- a/hopper/flash_attn_interface.py
+++ b/hopper/flash_attn_interface.py
@@ -563,10 +563,10 @@ def flash_attn_varlen_func(
     v,
     cu_seqlens_q,
     cu_seqlens_k,
-    seqused_q,
-    seqused_k,
     max_seqlen_q,
     max_seqlen_k,
+    seqused_q=None,
+    seqused_k=None,
     softmax_scale=None,
     causal=False,
     qv=None,
diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py
index e9cd8c9d6cb..ddd687f1fe8 100644
--- a/hopper/test_flash_attn.py
+++ b/hopper/test_flash_attn.py
@@ -450,9 +450,10 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
         v_unpad,
         cu_seqlens_q,
         cu_seqlens_k,
-        seqused_q, seqused_k,
         max_seqlen_q,
         max_seqlen_k,
+        seqused_q=seqused_q,
+        seqused_k=seqused_k,
         causal=causal,
         qv=qv_unpad,
         q_descale=q_descale,
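
Illustrative call sites for the reordered signature. This sketch is not part of the patch: it assumes the hopper build of FlashAttention-3 is importable as flash_attn_interface and a CUDA device is available; the shapes and the half-length seqused_k are made-up examples, and the exact return value (output only vs. output plus softmax LSE) varies across versions and is not the point here.

    import torch
    from flash_attn_interface import flash_attn_varlen_func

    batch, seqlen, nheads, headdim = 2, 512, 8, 128
    q = torch.randn(batch * seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(batch * seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(batch * seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
    # Cumulative sequence lengths for two equal-length packed sequences: [0, 512, 1024].
    cu_seqlens = torch.arange(0, (batch + 1) * seqlen, seqlen, device="cuda", dtype=torch.int32)

    # Before this patch, seqused_q/seqused_k were positional and sat between the
    # cu_seqlens and max_seqlen arguments, so callers not using them passed placeholders:
    #     flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, None, None, seqlen, seqlen, causal=True)

    # After this patch, max_seqlen_q/max_seqlen_k follow the cu_seqlens tensors directly
    # and seqused_q/seqused_k default to None:
    out = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, seqlen, seqlen, causal=True)

    # Callers that do cap the number of used tokens per sequence now pass them by keyword,
    # e.g. attending to only the first half of each K sequence:
    seqused_k = torch.full((batch,), seqlen // 2, device="cuda", dtype=torch.int32)
    out = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, seqlen, seqlen,
                                 seqused_k=seqused_k, causal=True)

This keeps the rarely-used seqused_q/seqused_k arguments optional while matching the argument order of the non-hopper flash_attn_varlen_func, which is why the benchmark and test call sites in this diff drop the None, None placeholders or switch to keywords.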