Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 8 additions & 21 deletions benchmarks/routines/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -1648,39 +1648,30 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args):
cumsum_s_qo = torch.sum(actual_seq_lens_q)
cumsum_s_kv = torch.sum(actual_seq_lens_kv)

# Front-padding for cute-dsl varlen kernel: the persistent varlen kernel
# applies a negative pointer offset (-max_s * H * D), so there must be
# valid GPU memory before the data start.
front_pad_q = s_qo if "cute-dsl" in backends else 0
front_pad_kv = s_kv if "cute-dsl" in backends else 0

q_full = torch.randn(
front_pad_q + cumsum_s_qo,
q = torch.randn(
cumsum_s_qo,
num_qo_heads,
head_dim_qk,
device=device,
dtype=q_init_dtype,
)
q = q_full[front_pad_q:]
if args.verbose >= 2:
print(f"[VVERBOSE] {q.shape = }")

k_full = torch.randn(
front_pad_kv + cumsum_s_kv,
k = torch.randn(
cumsum_s_kv,
num_kv_heads,
head_dim_qk,
device=device,
dtype=kv_init_dtype,
)
k = k_full[front_pad_kv:]
v_full = torch.randn(
front_pad_kv + cumsum_s_kv,
v = torch.randn(
cumsum_s_kv,
num_kv_heads,
head_dim_vo,
device=device,
dtype=kv_init_dtype,
)
v = v_full[front_pad_kv:]

block_tables = None

Expand Down Expand Up @@ -1839,17 +1830,13 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args):

trtllm_out = None
if "trtllm-native" in backends or "cute-dsl" in backends:
# cute-dsl varlen kernel uses negative pointer offsets on output,
# so front-pad like Q/K/V.
out_pad = front_pad_q if "cute-dsl" in backends else 0
trtllm_out_full = torch.empty(
out_pad + q.shape[0],
trtllm_out = torch.empty(
q.shape[0],
q.shape[1],
v.shape[2],
device=q.device,
dtype=out_dtype,
)
trtllm_out = trtllm_out_full[out_pad:]

def run_backend_wrapper(
backend,
Expand Down
14 changes: 7 additions & 7 deletions flashinfer/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ class ArtifactPath:
CUDNN_SDPA: str = "a72d85b019dc125b9f711300cb989430f762f5a6/fmha/cudnn/"
# For DEEPGEMM, we also need to update KernelMap.KERNEL_MAP_HASH in flashinfer/deep_gemm.py
DEEPGEMM: str = "a72d85b019dc125b9f711300cb989430f762f5a6/deep-gemm/"
DSL_FMHA: str = "c770c91cb0d991b7828fc85d2253a62f0d356b6c/fmha/cute-dsl/"
DSL_FMHA: str = "801e770219613fbf088bc074c414732b26cc550d/fmha/cute-dsl/"
DSL_FMHA_ARCHS: tuple[str, ...] = ("sm_100a", "sm_103a", "sm_110a")


Expand All @@ -170,14 +170,14 @@ class CheckSumHash:
# NOT hashes of individual kernel .so files.
DSL_FMHA_CHECKSUMS: dict[str, dict[str, str]] = {
"x86_64": {
"sm_100a": "9533536698cdc256d897fffb3114de317076654ff8630ff283d850cc3dc96d86",
"sm_103a": "927e1954f1d45b0ee876f139084e4facdfcc87e86f4d30cb92d5c33698d4c2d6",
"sm_110a": "277b1dceaab2081e3def37cf997280a3f2c3ac515d22b80be141253c0278b8b5",
"sm_100a": "778738c3aa89872248fcfddd134b57ae516021471df992d4ba9b058ead546d56",
"sm_103a": "f57abef4c65968c99e93faa051d9b98cf789c82c805bd3a177fb3f2a426dac4f",
"sm_110a": "f2450d136221d7c355876140af860999fd5f5cdd16ffa4b06ff8b799c2106c29",
},
"aarch64": {
"sm_100a": "b48ed0bcc9bad4afd33e0784c8c9eb9e13e782afe197816b1d0747b11759493e",
"sm_103a": "bace619a560f3ce52ad6ba105fffb8ea8629fe57885a90892c9e15a7122467e1",
"sm_110a": "d8369bcfa443bfd791cd014e3b030d378f00a975db8278eebd5b2fb529e3257d",
"sm_100a": "10af42097962a92cbc8942a65dedf87259fdb8684d26c4f8326dbfbe4e8ff566",
"sm_103a": "2418ee60ced8eec216af5a44682151173c1ed63d5296c92c185bc3bef92f91cd",
"sm_110a": "6807c536800fba3c9ff516f4cc0a7b12bd5570dd94ab04704c9bc7daf9d1e821",
},
}
map_checksums: dict[str, str] = {
Expand Down
142 changes: 74 additions & 68 deletions flashinfer/attention/cute_dsl/fmha.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ def _get_variant_name(
varlen: bool = False,
with_lse: bool = False,
enable_skip_softmax: bool = False,
enable_sink: bool = False,
use_pdl: bool = False,
enable_tvm_ffi: bool = False,
) -> str:
"""Generate the variant name matching compile_cute_dsl_fmha.py naming convention."""
Expand All @@ -139,8 +141,10 @@ def _get_variant_name(
varlen_str = "_varlen" if varlen else ""
lse_str = "_lse" if with_lse else ""
skip_str = "_skipsm" if enable_skip_softmax else ""
sink_str = "_sink" if enable_sink else ""
pdl_str = "_pdl" if use_pdl else ""
ffi_str = "_tvmffi" if enable_tvm_ffi else ""
return f"cute_dsl_fmha_{dtype_str}_h{head_dim}_{causal_str}_{persist_str}{varlen_str}{lse_str}{skip_str}{ffi_str}"
return f"cute_dsl_fmha_{dtype_str}_h{head_dim}_{causal_str}_{persist_str}{varlen_str}{lse_str}{skip_str}{sink_str}{pdl_str}{ffi_str}"


# =============================================================================
Expand Down Expand Up @@ -234,6 +238,8 @@ def get_cute_dsl_fmha_kernel(
varlen: bool = False,
with_lse: bool = False,
enable_skip_softmax: bool = False,
enable_sink: bool = False,
use_pdl: bool = False,
):
"""Get a compiled DSL FMHA kernel function.

Expand Down Expand Up @@ -277,6 +283,8 @@ def get_cute_dsl_fmha_kernel(
varlen,
with_lse,
enable_skip_softmax,
enable_sink,
use_pdl,
enable_tvm_ffi,
)

Expand Down Expand Up @@ -312,6 +320,7 @@ def cute_dsl_fmha_ragged_prefill(
window_left: int = -1,
window_right: int = -1,
lse: Optional[torch.Tensor] = None,
attention_sinks: Optional[torch.Tensor] = None,
scale_q: float = 1.0,
scale_k: float = 1.0,
scale_v: float = 1.0,
Expand All @@ -321,37 +330,24 @@ def cute_dsl_fmha_ragged_prefill(
max_kv_len: Optional[int] = None,
kernel_fn=None,
skip_softmax_threshold_scale_factor: Optional[float] = None,
enable_pdl: bool = False,
) -> None:
"""Run DSL FMHA prefill kernel on ragged (variable-length) tensors.

Note: The DSL FMHA kernel only supports per-tensor scalar scales, not
per-head scale tensors.

**Front-padding requirement** (TODO: will be removed in the next MR):
The DSL kernel applies a negative pointer offset
(``-max_seq_len * H * D`` elements) internally. Callers must
allocate ``max_seq_len + total_tokens`` rows and pass the slice starting
at ``[max_seq_len:]`` as q/k/v/o so that the preceding memory is valid
GPU memory. For example::

q_full = torch.empty(max_s_q + total_q, H_q, D, ...)
q = q_full[max_s_q:] # pass this to the kernel
# (same for k, v, o with max_s_k / max_s_q respectively)

Parameters
----------
q : torch.Tensor
Query tensor, shape (total_q_tokens, H_q, D).
Must have ``max_qo_len`` rows of valid GPU memory before index 0.
k : torch.Tensor
Key tensor, shape (total_kv_tokens, H_k, D).
Must have ``max_kv_len`` rows of valid GPU memory before index 0.
v : torch.Tensor
Value tensor, shape (total_kv_tokens, H_k, D_v).
Must have ``max_kv_len`` rows of valid GPU memory before index 0.
o : torch.Tensor
Output tensor, shape (total_q_tokens, H_q, D_v). Modified in-place.
Must have ``max_qo_len`` rows of valid GPU memory before index 0.
Must be 32-byte aligned (kernel uses 256-bit store instructions).
qo_indptr : torch.Tensor
Cumulative sequence lengths for Q/O, shape (batch_size + 1,).
Same as cum_seqlen_q in DSL FMHA kernel.
Expand All @@ -368,6 +364,8 @@ def cute_dsl_fmha_ragged_prefill(
Right sliding window size. -1 means no window. 0 for causal.
lse : torch.Tensor, optional
Log-sum-exp output tensor. None to skip.
attention_sinks : torch.Tensor, optional
Attention sink tensor of shape (H_q,) and dtype Float32. None to disable.
scale_q : float
Per-tensor scale for query (FP8 calibration). Default 1.0.
scale_k : float
Expand Down Expand Up @@ -413,6 +411,8 @@ def cute_dsl_fmha_ragged_prefill(
enable_tvm_ffi=enable_tvm_ffi,
with_lse=lse is not None,
enable_skip_softmax=use_skip_softmax,
enable_sink=attention_sinks is not None,
use_pdl=enable_pdl,
)

# Compute scale factors
Expand Down Expand Up @@ -452,24 +452,32 @@ def cute_dsl_fmha_ragged_prefill(
if is_causal and ws_right is None:
ws_right = Int32(0)

if enable_tvm_ffi:
# TVM-FFI: Pointer args accept data_ptr(), Tensor args accept torch.Tensor,
# no explicit stream (env stream).
# Kernel expects 4D pointers; unsqueeze to (1, total, H, D).
q_4d = q.unsqueeze(0)
k_4d = k.unsqueeze(0)
v_4d = v.unsqueeze(0)
o_4d = o.unsqueeze(0)
# Reshape to 5D matching kernel docstring:
# q/o: (b=1, total, h_k, h_r, d/dv)
# k/v: (b=1, total, h_k, 1, d/dv)
h_r = H_q // H_k
q_5d = q.view(1, total_q, H_k, h_r, D)
k_5d = k.view(1, total_kv, H_k, 1, D)
v_5d = v.view(1, total_kv, H_k, 1, D_v)
assert o.data_ptr() % 32 == 0, (
"o must be 32-byte aligned (kernel uses 256-bit store instructions)"
)
o_5d = o.view(1, total_q, H_k, h_r, D_v)
# LSE: (1, total_q, h_k, h_r) — 4D row-major.
lse_4d = lse.view(1, total_q, H_k, h_r) if lse is not None else None

if enable_tvm_ffi:
# TVM-FFI: pass torch.Tensor directly, no explicit stream (env stream).
kernel_fn(
q_4d.data_ptr(),
k_4d.data_ptr(),
v_4d.data_ptr(),
o_4d.data_ptr(),
q_5d,
k_5d,
v_5d,
o_5d,
problem_size,
qo_indptr.to(torch.int32), # cum_seqlen_q: Tensor arg
kv_indptr.to(torch.int32), # cum_seqlen_k: Tensor arg
lse.data_ptr() if lse is not None else None,
qo_indptr.to(torch.int32), # cum_seqlen_q
kv_indptr.to(torch.int32), # cum_seqlen_k
lse_4d,
attention_sinks,
Float32(scale_softmax_log2),
Float32(scale_softmax),
Float32(scale_output),
Expand All @@ -478,50 +486,43 @@ def cute_dsl_fmha_ragged_prefill(
ws_right,
None, # skip_softmax_count
None, # total_softmax_count
q_4d, # q_tensor for env stream device detection
enable_pdl,
)
else:
# CuTe native ABI: convert to cute tensors, pass iterators + explicit stream.

# DSL FMHA kernel expects 4D tensor (B, S, H, D).
q_4d = q.unsqueeze(0)
k_4d = k.unsqueeze(0)
v_4d = v.unsqueeze(0)
o_4d = o.unsqueeze(0)

# CuTe native ABI: convert to cute tensors and pass with explicit stream.
is_fp8_in = q.dtype == torch.float8_e4m3fn
is_fp8_out = o.dtype == torch.float8_e4m3fn
if is_fp8_in:
q_cute = from_dlpack(
q_4d.view(torch.int8), assumed_align=16
).mark_layout_dynamic(leading_dim=3)
q_5d.view(torch.int8), assumed_align=16
).mark_layout_dynamic(leading_dim=4)
q_cute.element_type = cutlass.Float8E4M3FN
k_cute = from_dlpack(
k_4d.view(torch.int8), assumed_align=16
).mark_layout_dynamic(leading_dim=3)
k_5d.view(torch.int8), assumed_align=16
).mark_layout_dynamic(leading_dim=4)
k_cute.element_type = cutlass.Float8E4M3FN
v_cute = from_dlpack(
v_4d.view(torch.int8), assumed_align=16
).mark_layout_dynamic(leading_dim=3)
v_5d.view(torch.int8), assumed_align=16
).mark_layout_dynamic(leading_dim=4)
v_cute.element_type = cutlass.Float8E4M3FN
else:
q_cute = from_dlpack(q_4d, assumed_align=16).mark_layout_dynamic(
leading_dim=3
q_cute = from_dlpack(q_5d, assumed_align=16).mark_layout_dynamic(
leading_dim=4
)
k_cute = from_dlpack(k_4d, assumed_align=16).mark_layout_dynamic(
leading_dim=3
k_cute = from_dlpack(k_5d, assumed_align=16).mark_layout_dynamic(
leading_dim=4
)
v_cute = from_dlpack(v_4d, assumed_align=16).mark_layout_dynamic(
leading_dim=3
v_cute = from_dlpack(v_5d, assumed_align=16).mark_layout_dynamic(
leading_dim=4
)
if is_fp8_out:
o_cute = from_dlpack(
o_4d.view(torch.int8), assumed_align=16
).mark_layout_dynamic(leading_dim=3)
o_5d.view(torch.int8), assumed_align=32
).mark_layout_dynamic(leading_dim=4)
o_cute.element_type = cutlass.Float8E4M3FN
else:
o_cute = from_dlpack(o_4d, assumed_align=16).mark_layout_dynamic(
leading_dim=3
o_cute = from_dlpack(o_5d, assumed_align=32).mark_layout_dynamic(
leading_dim=4
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.

cum_seqlen_q_cute = from_dlpack(
Expand All @@ -531,25 +532,30 @@ def cute_dsl_fmha_ragged_prefill(
kv_indptr.to(torch.int32), assumed_align=16
).mark_layout_dynamic(leading_dim=0)

lse_iter = None
if lse is not None:
# TODO: lse's shape?
lse_cute = from_dlpack(lse, assumed_align=16).mark_layout_dynamic(
leading_dim=2
lse_cute = None
if lse_4d is not None:
lse_cute = from_dlpack(lse_4d, assumed_align=16).mark_layout_dynamic(
leading_dim=3
)
lse_iter = lse_cute.iterator

sink_cute = None
if attention_sinks is not None:
sink_cute = from_dlpack(
attention_sinks, assumed_align=16
).mark_layout_dynamic(leading_dim=0)

stream = cuda_driver.CUstream(torch.cuda.current_stream().cuda_stream)

kernel_fn(
q_cute.iterator,
k_cute.iterator,
v_cute.iterator,
o_cute.iterator,
q_cute,
k_cute,
v_cute,
o_cute,
problem_size,
cum_seqlen_q_cute,
cum_seqlen_k_cute,
lse_iter,
lse_cute,
sink_cute,
Float32(scale_softmax_log2),
Float32(scale_softmax),
Float32(scale_output),
Expand All @@ -558,6 +564,6 @@ def cute_dsl_fmha_ragged_prefill(
ws_right,
None, # skip_softmax_count
None, # total_softmax_count
None, # q_tensor (unused, for TVM-FFI env stream)
stream,
enable_pdl,
)
Loading
Loading