diff --git a/benchmarks/routines/attention.py b/benchmarks/routines/attention.py index bd5ac56c3d..39859c5394 100644 --- a/benchmarks/routines/attention.py +++ b/benchmarks/routines/attention.py @@ -1648,39 +1648,30 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args): cumsum_s_qo = torch.sum(actual_seq_lens_q) cumsum_s_kv = torch.sum(actual_seq_lens_kv) - # Front-padding for cute-dsl varlen kernel: the persistent varlen kernel - # applies a negative pointer offset (-max_s * H * D), so there must be - # valid GPU memory before the data start. - front_pad_q = s_qo if "cute-dsl" in backends else 0 - front_pad_kv = s_kv if "cute-dsl" in backends else 0 - - q_full = torch.randn( - front_pad_q + cumsum_s_qo, + q = torch.randn( + cumsum_s_qo, num_qo_heads, head_dim_qk, device=device, dtype=q_init_dtype, ) - q = q_full[front_pad_q:] if args.verbose >= 2: print(f"[VVERBOSE] {q.shape = }") - k_full = torch.randn( - front_pad_kv + cumsum_s_kv, + k = torch.randn( + cumsum_s_kv, num_kv_heads, head_dim_qk, device=device, dtype=kv_init_dtype, ) - k = k_full[front_pad_kv:] - v_full = torch.randn( - front_pad_kv + cumsum_s_kv, + v = torch.randn( + cumsum_s_kv, num_kv_heads, head_dim_vo, device=device, dtype=kv_init_dtype, ) - v = v_full[front_pad_kv:] block_tables = None @@ -1839,17 +1830,13 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args): trtllm_out = None if "trtllm-native" in backends or "cute-dsl" in backends: - # cute-dsl varlen kernel uses negative pointer offsets on output, - # so front-pad like Q/K/V. - out_pad = front_pad_q if "cute-dsl" in backends else 0 - trtllm_out_full = torch.empty( - out_pad + q.shape[0], + trtllm_out = torch.empty( + q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=out_dtype, ) - trtllm_out = trtllm_out_full[out_pad:] def run_backend_wrapper( backend, diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index e805734ab6..35c9ff6830 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -145,7 +145,7 @@ class ArtifactPath: CUDNN_SDPA: str = "a72d85b019dc125b9f711300cb989430f762f5a6/fmha/cudnn/" # For DEEPGEMM, we also need to update KernelMap.KERNEL_MAP_HASH in flashinfer/deep_gemm.py DEEPGEMM: str = "a72d85b019dc125b9f711300cb989430f762f5a6/deep-gemm/" - DSL_FMHA: str = "c770c91cb0d991b7828fc85d2253a62f0d356b6c/fmha/cute-dsl/" + DSL_FMHA: str = "801e770219613fbf088bc074c414732b26cc550d/fmha/cute-dsl/" DSL_FMHA_ARCHS: tuple[str, ...] = ("sm_100a", "sm_103a", "sm_110a") @@ -170,14 +170,14 @@ class CheckSumHash: # NOT hashes of individual kernel .so files. DSL_FMHA_CHECKSUMS: dict[str, dict[str, str]] = { "x86_64": { - "sm_100a": "9533536698cdc256d897fffb3114de317076654ff8630ff283d850cc3dc96d86", - "sm_103a": "927e1954f1d45b0ee876f139084e4facdfcc87e86f4d30cb92d5c33698d4c2d6", - "sm_110a": "277b1dceaab2081e3def37cf997280a3f2c3ac515d22b80be141253c0278b8b5", + "sm_100a": "778738c3aa89872248fcfddd134b57ae516021471df992d4ba9b058ead546d56", + "sm_103a": "f57abef4c65968c99e93faa051d9b98cf789c82c805bd3a177fb3f2a426dac4f", + "sm_110a": "f2450d136221d7c355876140af860999fd5f5cdd16ffa4b06ff8b799c2106c29", }, "aarch64": { - "sm_100a": "b48ed0bcc9bad4afd33e0784c8c9eb9e13e782afe197816b1d0747b11759493e", - "sm_103a": "bace619a560f3ce52ad6ba105fffb8ea8629fe57885a90892c9e15a7122467e1", - "sm_110a": "d8369bcfa443bfd791cd014e3b030d378f00a975db8278eebd5b2fb529e3257d", + "sm_100a": "10af42097962a92cbc8942a65dedf87259fdb8684d26c4f8326dbfbe4e8ff566", + "sm_103a": "2418ee60ced8eec216af5a44682151173c1ed63d5296c92c185bc3bef92f91cd", + "sm_110a": "6807c536800fba3c9ff516f4cc0a7b12bd5570dd94ab04704c9bc7daf9d1e821", }, } map_checksums: dict[str, str] = { diff --git a/flashinfer/attention/cute_dsl/fmha.py b/flashinfer/attention/cute_dsl/fmha.py index 6da175b95a..fabca080e0 100644 --- a/flashinfer/attention/cute_dsl/fmha.py +++ b/flashinfer/attention/cute_dsl/fmha.py @@ -127,6 +127,8 @@ def _get_variant_name( varlen: bool = False, with_lse: bool = False, enable_skip_softmax: bool = False, + enable_sink: bool = False, + use_pdl: bool = False, enable_tvm_ffi: bool = False, ) -> str: """Generate the variant name matching compile_cute_dsl_fmha.py naming convention.""" @@ -139,8 +141,10 @@ def _get_variant_name( varlen_str = "_varlen" if varlen else "" lse_str = "_lse" if with_lse else "" skip_str = "_skipsm" if enable_skip_softmax else "" + sink_str = "_sink" if enable_sink else "" + pdl_str = "_pdl" if use_pdl else "" ffi_str = "_tvmffi" if enable_tvm_ffi else "" - return f"cute_dsl_fmha_{dtype_str}_h{head_dim}_{causal_str}_{persist_str}{varlen_str}{lse_str}{skip_str}{ffi_str}" + return f"cute_dsl_fmha_{dtype_str}_h{head_dim}_{causal_str}_{persist_str}{varlen_str}{lse_str}{skip_str}{sink_str}{pdl_str}{ffi_str}" # ============================================================================= @@ -234,6 +238,8 @@ def get_cute_dsl_fmha_kernel( varlen: bool = False, with_lse: bool = False, enable_skip_softmax: bool = False, + enable_sink: bool = False, + use_pdl: bool = False, ): """Get a compiled DSL FMHA kernel function. @@ -277,6 +283,8 @@ def get_cute_dsl_fmha_kernel( varlen, with_lse, enable_skip_softmax, + enable_sink, + use_pdl, enable_tvm_ffi, ) @@ -312,6 +320,7 @@ def cute_dsl_fmha_ragged_prefill( window_left: int = -1, window_right: int = -1, lse: Optional[torch.Tensor] = None, + attention_sinks: Optional[torch.Tensor] = None, scale_q: float = 1.0, scale_k: float = 1.0, scale_v: float = 1.0, @@ -321,37 +330,24 @@ def cute_dsl_fmha_ragged_prefill( max_kv_len: Optional[int] = None, kernel_fn=None, skip_softmax_threshold_scale_factor: Optional[float] = None, + enable_pdl: bool = False, ) -> None: """Run DSL FMHA prefill kernel on ragged (variable-length) tensors. Note: The DSL FMHA kernel only supports per-tensor scalar scales, not per-head scale tensors. - **Front-padding requirement** (TODO: will be removed in the next MR): - The DSL kernel applies a negative pointer offset - (``-max_seq_len * H * D`` elements) internally. Callers must - allocate ``max_seq_len + total_tokens`` rows and pass the slice starting - at ``[max_seq_len:]`` as q/k/v/o so that the preceding memory is valid - GPU memory. For example:: - - q_full = torch.empty(max_s_q + total_q, H_q, D, ...) - q = q_full[max_s_q:] # pass this to the kernel - # (same for k, v, o with max_s_k / max_s_q respectively) - Parameters ---------- q : torch.Tensor Query tensor, shape (total_q_tokens, H_q, D). - Must have ``max_qo_len`` rows of valid GPU memory before index 0. k : torch.Tensor Key tensor, shape (total_kv_tokens, H_k, D). - Must have ``max_kv_len`` rows of valid GPU memory before index 0. v : torch.Tensor Value tensor, shape (total_kv_tokens, H_k, D_v). - Must have ``max_kv_len`` rows of valid GPU memory before index 0. o : torch.Tensor Output tensor, shape (total_q_tokens, H_q, D_v). Modified in-place. - Must have ``max_qo_len`` rows of valid GPU memory before index 0. + Must be 32-byte aligned (kernel uses 256-bit store instructions). qo_indptr : torch.Tensor Cumulative sequence lengths for Q/O, shape (batch_size + 1,). Same as cum_seqlen_q in DSL FMHA kernel. @@ -368,6 +364,8 @@ def cute_dsl_fmha_ragged_prefill( Right sliding window size. -1 means no window. 0 for causal. lse : torch.Tensor, optional Log-sum-exp output tensor. None to skip. + attention_sinks : torch.Tensor, optional + Attention sink tensor, shape (H_q,) Float32. None to disable. scale_q : float Per-tensor scale for query (FP8 calibration). Default 1.0. scale_k : float @@ -413,6 +411,8 @@ def cute_dsl_fmha_ragged_prefill( enable_tvm_ffi=enable_tvm_ffi, with_lse=lse is not None, enable_skip_softmax=use_skip_softmax, + enable_sink=attention_sinks is not None, + use_pdl=enable_pdl, ) # Compute scale factors @@ -452,24 +452,32 @@ def cute_dsl_fmha_ragged_prefill( if is_causal and ws_right is None: ws_right = Int32(0) - if enable_tvm_ffi: - # TVM-FFI: Pointer args accept data_ptr(), Tensor args accept torch.Tensor, - # no explicit stream (env stream). - # Kernel expects 4D pointers; unsqueeze to (1, total, H, D). - q_4d = q.unsqueeze(0) - k_4d = k.unsqueeze(0) - v_4d = v.unsqueeze(0) - o_4d = o.unsqueeze(0) + # Reshape to 5D matching kernel docstring: + # q/o: (b=1, total, h_k, h_r, d/dv) + # k/v: (b=1, total, h_k, 1, d/dv) + h_r = H_q // H_k + q_5d = q.view(1, total_q, H_k, h_r, D) + k_5d = k.view(1, total_kv, H_k, 1, D) + v_5d = v.view(1, total_kv, H_k, 1, D_v) + assert o.data_ptr() % 32 == 0, ( + "o must be 32-byte aligned (kernel uses 256-bit store instructions)" + ) + o_5d = o.view(1, total_q, H_k, h_r, D_v) + # LSE: (1, total_q, h_k, h_r) — 4D row-major. + lse_4d = lse.view(1, total_q, H_k, h_r) if lse is not None else None + if enable_tvm_ffi: + # TVM-FFI: pass torch.Tensor directly, no explicit stream (env stream). kernel_fn( - q_4d.data_ptr(), - k_4d.data_ptr(), - v_4d.data_ptr(), - o_4d.data_ptr(), + q_5d, + k_5d, + v_5d, + o_5d, problem_size, - qo_indptr.to(torch.int32), # cum_seqlen_q: Tensor arg - kv_indptr.to(torch.int32), # cum_seqlen_k: Tensor arg - lse.data_ptr() if lse is not None else None, + qo_indptr.to(torch.int32), # cum_seqlen_q + kv_indptr.to(torch.int32), # cum_seqlen_k + lse_4d, + attention_sinks, Float32(scale_softmax_log2), Float32(scale_softmax), Float32(scale_output), @@ -478,50 +486,43 @@ def cute_dsl_fmha_ragged_prefill( ws_right, None, # skip_softmax_count None, # total_softmax_count - q_4d, # q_tensor for env stream device detection + enable_pdl, ) else: - # CuTe native ABI: convert to cute tensors, pass iterators + explicit stream. - - # DSL FMHA kernel expects 4D tensor (B, S, H, D). - q_4d = q.unsqueeze(0) - k_4d = k.unsqueeze(0) - v_4d = v.unsqueeze(0) - o_4d = o.unsqueeze(0) - + # CuTe native ABI: convert to cute tensors and pass with explicit stream. is_fp8_in = q.dtype == torch.float8_e4m3fn is_fp8_out = o.dtype == torch.float8_e4m3fn if is_fp8_in: q_cute = from_dlpack( - q_4d.view(torch.int8), assumed_align=16 - ).mark_layout_dynamic(leading_dim=3) + q_5d.view(torch.int8), assumed_align=16 + ).mark_layout_dynamic(leading_dim=4) q_cute.element_type = cutlass.Float8E4M3FN k_cute = from_dlpack( - k_4d.view(torch.int8), assumed_align=16 - ).mark_layout_dynamic(leading_dim=3) + k_5d.view(torch.int8), assumed_align=16 + ).mark_layout_dynamic(leading_dim=4) k_cute.element_type = cutlass.Float8E4M3FN v_cute = from_dlpack( - v_4d.view(torch.int8), assumed_align=16 - ).mark_layout_dynamic(leading_dim=3) + v_5d.view(torch.int8), assumed_align=16 + ).mark_layout_dynamic(leading_dim=4) v_cute.element_type = cutlass.Float8E4M3FN else: - q_cute = from_dlpack(q_4d, assumed_align=16).mark_layout_dynamic( - leading_dim=3 + q_cute = from_dlpack(q_5d, assumed_align=16).mark_layout_dynamic( + leading_dim=4 ) - k_cute = from_dlpack(k_4d, assumed_align=16).mark_layout_dynamic( - leading_dim=3 + k_cute = from_dlpack(k_5d, assumed_align=16).mark_layout_dynamic( + leading_dim=4 ) - v_cute = from_dlpack(v_4d, assumed_align=16).mark_layout_dynamic( - leading_dim=3 + v_cute = from_dlpack(v_5d, assumed_align=16).mark_layout_dynamic( + leading_dim=4 ) if is_fp8_out: o_cute = from_dlpack( - o_4d.view(torch.int8), assumed_align=16 - ).mark_layout_dynamic(leading_dim=3) + o_5d.view(torch.int8), assumed_align=32 + ).mark_layout_dynamic(leading_dim=4) o_cute.element_type = cutlass.Float8E4M3FN else: - o_cute = from_dlpack(o_4d, assumed_align=16).mark_layout_dynamic( - leading_dim=3 + o_cute = from_dlpack(o_5d, assumed_align=32).mark_layout_dynamic( + leading_dim=4 ) cum_seqlen_q_cute = from_dlpack( @@ -531,25 +532,30 @@ def cute_dsl_fmha_ragged_prefill( kv_indptr.to(torch.int32), assumed_align=16 ).mark_layout_dynamic(leading_dim=0) - lse_iter = None - if lse is not None: - # TODO: lse's shape? - lse_cute = from_dlpack(lse, assumed_align=16).mark_layout_dynamic( - leading_dim=2 + lse_cute = None + if lse_4d is not None: + lse_cute = from_dlpack(lse_4d, assumed_align=16).mark_layout_dynamic( + leading_dim=3 ) - lse_iter = lse_cute.iterator + + sink_cute = None + if attention_sinks is not None: + sink_cute = from_dlpack( + attention_sinks, assumed_align=16 + ).mark_layout_dynamic(leading_dim=0) stream = cuda_driver.CUstream(torch.cuda.current_stream().cuda_stream) kernel_fn( - q_cute.iterator, - k_cute.iterator, - v_cute.iterator, - o_cute.iterator, + q_cute, + k_cute, + v_cute, + o_cute, problem_size, cum_seqlen_q_cute, cum_seqlen_k_cute, - lse_iter, + lse_cute, + sink_cute, Float32(scale_softmax_log2), Float32(scale_softmax), Float32(scale_output), @@ -558,6 +564,6 @@ def cute_dsl_fmha_ragged_prefill( ws_right, None, # skip_softmax_count None, # total_softmax_count - None, # q_tensor (unused, for TVM-FFI env stream) stream, + enable_pdl, ) diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 8d5d726c14..d4b75f1e7a 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -3861,10 +3861,6 @@ def trtllm_ragged_attention_deepseek( lse tensor, if not provided, will be allocated with shape [query.shape[0], query.shape[1]] backend : str Attention backend to use. "trtllm-gen" (default) or "cute-dsl". - When backend="cute-dsl", query/key/value/out tensors must be - front-padded with max_seq_len rows of valid GPU memory before - index 0 (see ``cute_dsl_fmha_ragged_prefill`` for details). - This requirement will be removed in the next MR. Returns ------- @@ -3926,20 +3922,6 @@ def trtllm_ragged_attention_deepseek( if backend == "cute-dsl": from .attention.cute_dsl.fmha import cute_dsl_fmha_ragged_prefill - import warnings - - # TODO: remove this warning when PDL support added - # TODO: support PDL for cute-dsl backend - if enable_pdl: - warnings.warn( - "cute-dsl backend does not support PDL yet (enable_pdl ignored)", - stacklevel=2, - ) - if attention_sinks is not None: - warnings.warn( - "cute-dsl backend does not support attention_sinks (ignored)", - stacklevel=2, - ) _SUPPORTED_DTYPES = (torch.float16, torch.bfloat16, torch.float8_e4m3fn) assert query.dtype in _SUPPORTED_DTYPES, ( f"cute-dsl backend only supports {_SUPPORTED_DTYPES}, got {query.dtype}" @@ -3971,6 +3953,7 @@ def trtllm_ragged_attention_deepseek( sm_scale=_bmm1, window_left=window_left, lse=lse if return_lse else None, + attention_sinks=attention_sinks, scale_q=1.0, scale_k=1.0, scale_v=_bmm2, @@ -3978,6 +3961,7 @@ def trtllm_ragged_attention_deepseek( max_qo_len=max_q_len, max_kv_len=max_kv_len, skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor, + enable_pdl=enable_pdl, ) else: # --- trtllm-gen backend --- diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py index c2bedad1aa..a33491d123 100755 --- a/tests/attention/test_trtllm_gen_attention.py +++ b/tests/attention/test_trtllm_gen_attention.py @@ -1853,6 +1853,7 @@ def test_trtllm_batch_decode_head_dim_512( @pytest.mark.parametrize("head_grp_size", [1, 5, 8]) @pytest.mark.parametrize("causal", [True, False]) @pytest.mark.parametrize("skips_softmax", [False, True]) +@pytest.mark.parametrize("enable_sink", [False, True]) def test_trtllm_gen_prefill( backend: str, mla_dimensions: MLAHeadDimensions, @@ -1863,6 +1864,7 @@ def test_trtllm_gen_prefill( head_grp_size: int, causal: bool, skips_softmax: bool, + enable_sink: bool, ) -> None: compute_capability = get_compute_capability(torch.device(device="cuda")) if compute_capability[0] != 10: @@ -1893,47 +1895,19 @@ def test_trtllm_gen_prefill( cumsum_s_qo = int(torch.sum(actual_seq_lens_q).item()) cumsum_s_kv = int(torch.sum(actual_seq_lens_kv).item()) - # DSL FMHA varlen kernel uses negative pointer offsets, so tensors need - # front-padding of max_s elements to ensure valid GPU memory before data. - if backend == "cute-dsl": - q_full = torch.randn( - s_qo + cumsum_s_qo, - num_qo_heads, - head_dim_qk, - device=device, - dtype=torch.bfloat16, - ) - q = q_full[s_qo:] - k_full = torch.randn( - s_kv + cumsum_s_kv, - num_kv_heads, - head_dim_qk, - device=device, - dtype=torch.bfloat16, - ) - k_cache = k_full[s_kv:] - v_full = torch.randn( - s_kv + cumsum_s_kv, - num_kv_heads, - head_dim_vo, - device=device, - dtype=torch.bfloat16, - ) - v_cache = v_full[s_kv:] - else: - q = torch.randn( - cumsum_s_qo, num_qo_heads, head_dim_qk, device=device, dtype=torch.bfloat16 - ) - k_cache = torch.randn( - (cumsum_s_kv, num_kv_heads, head_dim_qk), - device=device, - dtype=torch.bfloat16, - ) - v_cache = torch.randn( - (cumsum_s_kv, num_kv_heads, head_dim_vo), - device=device, - dtype=torch.bfloat16, - ) + q = torch.randn( + cumsum_s_qo, num_qo_heads, head_dim_qk, device=device, dtype=torch.bfloat16 + ) + k_cache = torch.randn( + (cumsum_s_kv, num_kv_heads, head_dim_qk), + device=device, + dtype=torch.bfloat16, + ) + v_cache = torch.randn( + (cumsum_s_kv, num_kv_heads, head_dim_vo), + device=device, + dtype=torch.bfloat16, + ) # Initialize scale scale = float(1.0 / (head_dim_qk**0.5)) @@ -1960,35 +1934,46 @@ def test_trtllm_gen_prefill( ] ).int() - wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( - workspace_buffer_ref, - kv_layout="NHD", - backend="cutlass", - ) - wrapper.plan( - qo_indptr, - kv_indptr, - num_qo_heads, - num_kv_heads, - head_dim_qk, - head_dim_vo=head_dim_vo, - causal=causal, - sm_scale=scale, - q_data_type=torch.bfloat16, - kv_data_type=torch.bfloat16, + sink = ( + torch.rand(num_qo_heads, device=device, dtype=torch.float32) * 5 + if enable_sink + else None ) - output_ref, lse_ref = wrapper.run(q, k_cache, v_cache, return_lse=True) - if backend == "cute-dsl": - output_full = torch.empty( - s_qo + cumsum_s_qo, + lse_ref = None + if enable_sink: + output_ref = sink_attention_unified( + q, + k_cache, + v_cache, + sink, + window_left=-1, + causal=causal, + sm_scale=scale, + mode="varlen", + batch_size=batch_size, + qo_indptr=qo_indptr, + kv_indptr=kv_indptr, + ).to(torch.bfloat16) + else: + wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( + workspace_buffer_ref, + kv_layout="NHD", + backend="cutlass", + ) + wrapper.plan( + qo_indptr, + kv_indptr, num_qo_heads, - head_dim_vo, - device=device, - dtype=output_ref.dtype, + num_kv_heads, + head_dim_qk, + head_dim_vo=head_dim_vo, + causal=causal, + sm_scale=scale, + q_data_type=torch.bfloat16, + kv_data_type=torch.bfloat16, ) - output = output_full[s_qo:] - else: - output = torch.empty_like(output_ref) + output_ref, lse_ref = wrapper.run(q, k_cache, v_cache, return_lse=True) + output = torch.empty_like(output_ref) bmm1_scale = scale bmm2_scale = 1.0 @@ -2014,6 +1999,7 @@ def test_trtllm_gen_prefill( False, causal, True, + attention_sinks=sink, skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor, out=output, backend=backend, @@ -2024,12 +2010,13 @@ def test_trtllm_gen_prefill( atol=1e-2, rtol=1e-2, ) - torch.testing.assert_close( - lse_trtllm, - lse_ref, - atol=1e-3, - rtol=1e-3, - ) + if lse_ref is not None: + torch.testing.assert_close( + lse_trtllm, + lse_ref, + atol=1e-3, + rtol=1e-3, + ) # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter workspace, size might change in the future if backend == "trtllm-native": @@ -2083,10 +2070,10 @@ def test_trtllm_gen_prefill_fp8( # FP8 scales scale_q, scale_k, scale_v = 0.05, 0.04, 0.06 - # Generate in float32, quantize to FP8 with front-padding + # Generate in float32, quantize to FP8 q_f32 = ( torch.randn( - s_qo + cumsum_s_qo, + cumsum_s_qo, num_qo_heads, head_dim_qk, dtype=torch.float32, @@ -2096,7 +2083,7 @@ def test_trtllm_gen_prefill_fp8( ) k_f32 = ( torch.randn( - s_kv + cumsum_s_kv, + cumsum_s_kv, num_kv_heads, head_dim_qk, dtype=torch.float32, @@ -2106,7 +2093,7 @@ def test_trtllm_gen_prefill_fp8( ) v_f32 = ( torch.randn( - s_kv + cumsum_s_kv, + cumsum_s_kv, num_kv_heads, head_dim_vo, dtype=torch.float32, @@ -2115,9 +2102,9 @@ def test_trtllm_gen_prefill_fp8( * 0.1 ) - q = (q_f32 / scale_q).to(torch.float8_e4m3fn)[s_qo:] - k_cache = (k_f32 / scale_k).to(torch.float8_e4m3fn)[s_kv:] - v_cache = (v_f32 / scale_v).to(torch.float8_e4m3fn)[s_kv:] + q = (q_f32 / scale_q).to(torch.float8_e4m3fn) + k_cache = (k_f32 / scale_k).to(torch.float8_e4m3fn) + v_cache = (v_f32 / scale_v).to(torch.float8_e4m3fn) # Reference: dequantize and run bf16 attention q_bf16 = (q.float() * scale_q).to(torch.bfloat16) @@ -2159,15 +2146,7 @@ def test_trtllm_gen_prefill_fp8( ) output_ref, _ = wrapper.run(q_bf16, k_bf16, v_bf16, return_lse=True) - # Output with front-padding - output_full = torch.empty( - s_qo + cumsum_s_qo, - num_qo_heads, - head_dim_vo, - device=device, - dtype=torch.bfloat16, - ) - output = output_full[s_qo:] + output = torch.empty_like(output_ref) scale = 1.0 / (head_dim_qk**0.5) bmm1_scale = scale_q * scale_k * scale @@ -2234,6 +2213,7 @@ def test_trtllm_gen_prefill_bs1( head_grp_size, causal, skips_softmax, + enable_sink=False, )