7 changes: 5 additions & 2 deletions tensorrt_llm/_torch/autotuner.py
@@ -365,11 +365,14 @@ def search_cache(
Returns:
A tuple containing:
[is_cache_hit, runner_id, tactic, stored_profile]
runner_id is the index in the current runners list
"""
for r in runners:
for idx, r in enumerate(runners):
if (cache_key := self.get_cache_key(custom_op, r, input_shapes,
tuning_config)) in self.cache:
return True, *self.cache[cache_key]
# Return the current index in runners list, not the cached runner_id
cached_runner_id, tactic, min_time = self.cache[cache_key]
return True, idx, tactic, min_time

return False, *self.fallback_entry()

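A minimal standalone sketch of the bug this hunk fixes (hypothetical runner classes and cache layout, not the actual AutoTuner): the cache stores the runner_id that was valid at tuning time, but callers index into whatever runners list they pass now, so a cache hit must report the position in the current list.

# Hypothetical illustration; the real cache lives in AutoTuner and stores richer entries.
cache = {("nvfp4_gemm", "CuteDSLRunner", (128, 7168)): (1, 3, 0.021)}  # (runner_id at tuning time, tactic, min_time)

class CutlassRunner: ...
class CuteDSLRunner: ...

def search_cache(runners, shapes):
    # Return the index of the matching runner in the *current* list,
    # not the runner_id that happened to be stored when the entry was written.
    for idx, r in enumerate(runners):
        key = ("nvfp4_gemm", type(r).__name__, shapes)
        if key in cache:
            _, tactic, min_time = cache[key]
            return True, idx, tactic, min_time
    return False, None, None, float("inf")

# At tuning time CuteDSLRunner sat at index 1; in this call it is at index 0,
# so returning the cached runner_id (1) would select the wrong runner.
hit, idx, tactic, _ = search_cache([CuteDSLRunner(), CutlassRunner()], (128, 7168))
assert hit and idx == 0 and tactic == 3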
44 changes: 37 additions & 7 deletions tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py
@@ -554,11 +554,24 @@ def target_scaled_mm_prologue_pattern(
)

def register_nvfp4_gemm_prologue(custom_pass: PatternMatcherPass):
act_fp4_key = KeywordArg('act_fp4')
weight_key = KeywordArg('weight')
act_sf_key = KeywordArg('act_sf')
weight_scale_key = KeywordArg('weight_scale')
alpha_key = KeywordArg('alpha')
output_dtype_key = KeywordArg('output_dtype')
to_userbuffers_key = KeywordArg('to_userbuffers')
backend_key = KeywordArg('backend')
trtllm_nvfp4_gemm_default = CallFunction(
torch.ops.trtllm.nvfp4_gemm.default, KeywordArg('act_fp4'),
KeywordArg('weight'), KeywordArg('act_sf'),
KeywordArg('weight_scale'), KeywordArg('alpha'),
KeywordArg('output_dtype'))
torch.ops.trtllm.nvfp4_gemm.default,
act_fp4_key,
weight_key,
act_sf_key,
weight_scale_key,
alpha_key,
output_dtype_key,
to_userbuffers=to_userbuffers_key,
backend=backend_key)
ub_copy = CallFunction(torch.ops.trtllm.copy_to_userbuffers,
trtllm_nvfp4_gemm_default)

@@ -569,6 +582,8 @@ def empty_nvfp4_gemm_prologue_pattern(
weight_scale: torch.Tensor,
alpha: torch.Tensor,
output_dtype: torch.dtype,
to_userbuffers: bool,
backend: str,
):
return

@@ -579,21 +594,36 @@ def target_nvfp4_gemm_prologue_pattern(
weight_scale: torch.Tensor,
alpha: torch.Tensor,
output_dtype: torch.dtype,
to_userbuffers: bool,
backend: str,
):
nvfp4_gemm_output = torch.ops.trtllm.nvfp4_gemm(
act_fp4, weight, act_sf, weight_scale, alpha, output_dtype,
True)
True, backend)
return nvfp4_gemm_output

# No output-dtype check is needed here: the output dtype of nvfp4_gemm has already
# been verified when ub_copy is inserted. extra_check below only validates the backend kwarg.
def extra_check(match: Match) -> bool:
# Validate backend value
backend_value = match.kwargs.get('backend')
if backend_value is None:
# No backend specified, use default - OK
return True

# backend should be a string literal
if not isinstance(backend_value, str):
return False

valid_backends = {'auto', 'cutlass', 'cublaslt', 'cutedsl'}
return backend_value in valid_backends

register_replacement(
empty_nvfp4_gemm_prologue_pattern,
target_nvfp4_gemm_prologue_pattern,
[],
fwd_only,
custom_pass,
search_fn_pattern=ub_copy,
extra_check=extra_check,
)

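For orientation, a plain-Python sketch of the rewrite registered above (hypothetical Node helper; the real pass operates on torch.fx graphs via register_replacement): a copy_to_userbuffers wrapped around nvfp4_gemm is folded into the GEMM by setting to_userbuffers=True, and extra_check accepts the match only when the captured backend keyword is absent or one of the known backend strings.

from dataclasses import dataclass, field

VALID_BACKENDS = {"auto", "cutlass", "cublaslt", "cutedsl"}

@dataclass
class Node:
    op: str
    args: tuple = ()
    kwargs: dict = field(default_factory=dict)

def backend_ok(gemm: Node) -> bool:
    # Mirrors extra_check: a missing backend kwarg means "use the default";
    # otherwise it must be one of the known string literals.
    backend = gemm.kwargs.get("backend")
    return backend is None or backend in VALID_BACKENDS

def fold_userbuffers_copy(node: Node) -> Node:
    # copy_to_userbuffers(nvfp4_gemm(...)) -> nvfp4_gemm(..., to_userbuffers=True)
    if node.op == "copy_to_userbuffers":
        gemm = node.args[0]
        if gemm.op == "nvfp4_gemm" and backend_ok(gemm):
            return Node(gemm.op, gemm.args, {**gemm.kwargs, "to_userbuffers": True})
    return node

gemm = Node("nvfp4_gemm", ("act_fp4", "weight"), {"backend": "cutedsl"})
folded = fold_userbuffers_copy(Node("copy_to_userbuffers", (gemm,)))
assert folded.kwargs == {"backend": "cutedsl", "to_userbuffers": True}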
def register_mm_prologue(custom_pass: PatternMatcherPass):
159 changes: 134 additions & 25 deletions tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
@@ -3,6 +3,8 @@

import torch

from tensorrt_llm.logger import logger

from ..._utils import get_sm_version
from ...math_utils import pad_up
from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec,
@@ -32,7 +34,7 @@
Sm100BlockScaledPersistentDenseGemmKernel
from ..cute_dsl_kernels.blackwell.utils import make_ptr

class CuteDSLNVFP4BlackwellRunner(TunableRunner):
class CuteDSLNVFP4BlackwellLinear(TunableRunner):
kernel_class = Sm100BlockScaledPersistentDenseGemmKernel
kernel_cache = dict()
tuning_config = TuningConfig(
@@ -43,26 +45,44 @@ class CuteDSLNVFP4BlackwellRunner(TunableRunner):
use_cold_l2_cache=True,
)

def __init__(self, alpha: float, output_dtype: torch.dtype):
def __init__(self,
output_dtype: torch.dtype,
to_userbuffers: bool = False):
super().__init__()
self.alpha = alpha
self.output_dtype = output_dtype
assert output_dtype == torch.bfloat16

if get_sm_version() not in [100, 103]:
if output_dtype != torch.bfloat16:
raise ValueError(
f"SM version {get_sm_version()} is not supported for {self.__class__.__name__}, it only supports SM 100 and SM 103"
f"CuteDSL NVFP4 only supports bfloat16 output, got {output_dtype}"
)
self.output_dtype = output_dtype
self.to_userbuffers = to_userbuffers

def unique_id(self):
return (self.output_dtype, )
return (self.output_dtype, self.to_userbuffers)

def __hash__(self):
return hash((self.output_dtype, self.to_userbuffers))

def __eq__(self, other):
if not isinstance(other, self.__class__):
return False
return self.output_dtype == other.output_dtype and self.to_userbuffers == other.to_userbuffers

def get_valid_tactics(
self,
inputs: List[torch.Tensor],
profile: OptimizationProfile,
**kwargs,
) -> List[Tuple[int, int]]:
# Early exit: Check SM version - CuteDSL NVFP4 only supports SM 100 and SM 103
sm_version = get_sm_version()
if sm_version not in [100, 103]:
logger.debug(
f"CuteDSL: SM version {sm_version} is not supported. "
f"CuteDSL NVFP4 only supports SM 100 (B200) and SM 103 (B300). Skipping all tactics."
)
return []

assert inputs[0].dim() == 2
assert inputs[1].dim() == 2

@@ -73,11 +93,44 @@ def get_valid_tactics(
real_k = k * 2
batch_size = 1
sf_vec_size = 16
# m,k

# Fixed layout for FP4: A and B are always K-major
a_major = "k"
# n, k
b_major = "k"

# Early exit: Check K dimension alignment
# For K-major layout (A and B tensors), K is the major mode (contiguous dimension).
# 16-byte alignment requirement: K must be divisible by 32 for FP4 (128 bits / 4 bits = 32)
if real_k % 32 != 0:
logger.debug(
f"CuteDSL: K={real_k} does not meet 16-byte alignment requirement "
f"(K%32={real_k%32}, expected 0). Skipping all tactics.")
return []

# Optimize swap_ab candidates based on M and N alignment
# swap_ab=False → C is N-major → requires N%8==0 (BF16: 128 bits / 16 bits = 8)
# swap_ab=True → C is M-major → requires M%8==0
m_aligned = (m % 8 == 0)
n_aligned = (n % 8 == 0)

if not m_aligned and not n_aligned:
logger.debug(
f"CuteDSL: Neither M={m} nor N={n} meets 16-byte alignment "
f"(M%8={m%8}, N%8={n%8}). No valid C layout. Skipping all tactics."
)
return []

# Only test swap_ab values that satisfy alignment
swap_ab_candidates = []
if n_aligned:
swap_ab_candidates.append(False) # N-major layout
if m_aligned:
swap_ab_candidates.append(True) # M-major layout

logger.debug(
f"CuteDSL: M={m}(aligned={m_aligned}), N={n}(aligned={n_aligned}), K={real_k}(aligned=True). "
f"Testing swap_ab={swap_ab_candidates}")

# full shamoo
mma_tiler_mn_candidates = [
(128, 64),
@@ -134,6 +187,9 @@ def get_valid_tactics(
valid_tactics.append(
(mma_tiler_mn, cluster_shape_mn, swap_ab, use_prefetch))

logger.debug(
f"CuteDSL: Found {len(valid_tactics)} valid tactics for M={m}, N={n}, K={real_k}"
)
return valid_tactics

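A standalone sketch (hypothetical helper name, assuming the alignment rules stated in the comments above) of the early-exit filtering added in this hunk: K must be a multiple of 32 for packed FP4, and swap_ab candidates are kept only for whichever of M and N satisfies the 8-element BF16 alignment of the output's contiguous dimension.

from typing import List

def candidate_swap_ab(m: int, n: int, real_k: int) -> List[bool]:
    # K-major FP4 operands: 16-byte alignment requires K % 32 == 0 (32 FP4 values per 16 bytes).
    if real_k % 32 != 0:
        return []
    candidates = []
    # swap_ab=False keeps C N-major, so N must be a multiple of 8 (16 bytes / 2-byte BF16).
    if n % 8 == 0:
        candidates.append(False)
    # swap_ab=True makes C M-major, so M must be a multiple of 8 instead.
    if m % 8 == 0:
        candidates.append(True)
    return candidates

assert candidate_swap_ab(120, 64, 7168) == [False, True]  # both layouts valid
assert candidate_swap_ab(100, 64, 7168) == [False]        # M unaligned: no swap_ab
assert candidate_swap_ab(128, 64, 7000) == []             # K not a multiple of 32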
def make_cute_dsl_global_pointer(self, tensor: torch.Tensor, dtype,
@@ -149,6 +205,7 @@ def forward(
self,
inputs: List[torch.Tensor],
tactic,
**kwargs,
) -> torch.Tensor:
"""
Performs NVFP4 block-scaled GEMM using CuTe DSL.
@@ -160,8 +217,7 @@ def forward(
inputs[2]: Input scale tensor of shape (k//16, m), dtype: fp8.
inputs[3]: Weight scale tensor of shape (n, k//16), dtype: fp8.
inputs[4]: Alpha scaling factor. dtype: float32.
inputs[5]: Output dtype, expected to be torch.bfloat16.
tactic: Tiling and cluster strategy, typically a tuple (mma_tiler_mn, cluster_shape_mn, swap_ab, use_prefetch).
tactic: Tiling and cluster strategy, typically a tuple (mma_tiler_mn, cluster_shape_mn).

Returns:
torch.Tensor: Output tensor of shape (m, n), dtype: bf16.
@@ -179,11 +235,17 @@ def forward(
False,
]

a_tensor, b_tensor, a_sf_tensor, b_sf_tensor = inputs
a_tensor, b_tensor, a_sf_tensor, b_sf_tensor, alpha_tensor = inputs
m, k, n = a_tensor.shape[0], a_tensor.shape[1], b_tensor.shape[0]
c_tensor = torch.empty(*(m, n),
dtype=self.output_dtype,
device="cuda")

# Allocate output tensor from UserBuffers or regular CUDA memory
if self.to_userbuffers:
c_tensor = torch.ops.trtllm.create_userbuffers_tensor(
[m, n], self.output_dtype)
else:
c_tensor = torch.empty(*(m, n),
dtype=self.output_dtype,
device="cuda")

if swap_ab:
c_tensor = c_tensor.permute(1, 0)
@@ -193,9 +255,27 @@ def forward(
sf_k = pad_up(real_k // sf_vec_size, 4)
sf_n = pad_up(n, 128)

# the scaling tensor is 1D. we need to make sure it has been padded to the correct shape
assert a_sf_tensor.shape == (sf_m * sf_k, )
assert b_sf_tensor.shape == (sf_n * sf_k, )
# Reshape scale factors to CuteDSL's expected format
# Input format (from CUTLASS/cuBLASLt): (m*k//16,) and (n*k//16,)
# CuteDSL format: (sf_m*sf_k,) and (sf_n*sf_k,)
# Note: This is just a view change, no memory copy
expected_a_sf_size = sf_m * sf_k
expected_b_sf_size = sf_n * sf_k

if a_sf_tensor.numel() != expected_a_sf_size:
raise ValueError(
f"CuteDSL: act scale factor size mismatch. "
f"Expected {expected_a_sf_size} (sf_m={sf_m} * sf_k={sf_k}), "
f"got {a_sf_tensor.numel()} for shape M={m}, K={real_k}")
if b_sf_tensor.numel() != expected_b_sf_size:
raise ValueError(
f"CuteDSL: weight scale factor size mismatch. "
f"Expected {expected_b_sf_size} (sf_n={sf_n} * sf_k={sf_k}), "
f"got {b_sf_tensor.numel()} for shape N={n}, K={real_k}")

# Reshape to CuteDSL's expected format (just a view, no copy)
a_sf_tensor = a_sf_tensor.reshape(sf_m * sf_k)
b_sf_tensor = b_sf_tensor.reshape(sf_n * sf_k)

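As a quick arithmetic cross-check of the size validation above, a sketch of the expected scale-factor element counts (pad_up as used in this file; the 128-row padding for sf_m is inferred from the error messages and should be treated as an assumption):

def pad_up(x: int, align: int) -> int:
    # Round x up to the next multiple of align.
    return (x + align - 1) // align * align

def expected_sf_sizes(m: int, n: int, real_k: int, sf_vec_size: int = 16):
    sf_m = pad_up(m, 128)                    # assumed: rows padded to 128, like sf_n
    sf_n = pad_up(n, 128)
    sf_k = pad_up(real_k // sf_vec_size, 4)  # one FP8 scale per 16 values, padded to 4
    return sf_m * sf_k, sf_n * sf_k

# M=120, N=256, K=7168 -> sf_m=128, sf_n=256, sf_k=448
assert expected_sf_sizes(120, 256, 7168) == (128 * 448, 256 * 448)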
a_ptr = self.make_cute_dsl_global_pointer(a_tensor,
cutlass.Float4E2M1FN, 32)
@@ -207,6 +287,9 @@ def forward(
b_sf_tensor, cutlass.Float8E4M3FN, 16)
c_ptr = self.make_cute_dsl_global_pointer(c_tensor,
cutlass.BFloat16, 16)
# Create pointer to alpha on device
alpha_ptr = self.make_cute_dsl_global_pointer(
alpha_tensor, cutlass.Float32, 4)

# get stream
torch_stream = torch.cuda.current_stream()
@@ -259,7 +342,7 @@ def forward(
kernel_a_sf_ptr,
kernel_b_sf_ptr,
c_ptr,
self.alpha,
alpha_ptr, # Pass alpha as device pointer
max_active_clusters,
stream,
swap_ab,
@@ -283,7 +366,7 @@ def forward(
kernel_a_sf_ptr,
kernel_b_sf_ptr,
c_ptr,
self.alpha,
alpha_ptr, # Pass alpha as device pointer
stream,
)

@@ -300,20 +383,45 @@ def cute_dsl_nvfp4_gemm_blackwell(
weight: torch.Tensor,
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
alpha: float,
alpha: torch.Tensor,
output_dtype: torch.dtype,
to_userbuffers: bool = False,
) -> torch.Tensor:
"""CuteDSL-based NVFP4 GEMM optimized for Blackwell.

Args:
input: Activation tensor [m, k] in FP4 format (packed in uint8)
weight: Weight tensor [n, k] in FP4 format (packed in uint8)
input_scale: Activation scale factors
weight_scale: Weight scale factors
alpha: Scaling factor
output_dtype: Output data type (must be bfloat16)
to_userbuffers: Whether to allocate output from UserBuffers pool

Note:
This function is primarily used internally by nvfp4_gemm.
Direct usage is discouraged. Consider using nvfp4_gemm instead
for automatic backend selection with better performance.
"""
# Validate SM version before attempting to use CuteDSL
sm_version = get_sm_version()
if sm_version not in [100, 103]:
raise ValueError(
f"CuteDSL NVFP4 backend requires SM 100 (B200) or SM 103 (B300), but got SM {sm_version}. "
f"Please use nvfp4_gemm with backend='auto' for automatic backend selection."
)

tuner = AutoTuner.get()

runner = CuteDSLNVFP4BlackwellRunner(alpha, output_dtype)
inputs = [input, weight, input_scale, weight_scale]
runner = CuteDSLNVFP4BlackwellLinear(output_dtype, to_userbuffers)
inputs = [input, weight, input_scale, weight_scale, alpha]
_, best_tactic = tuner.choose_one(
"trtllm::cute_dsl_nvfp4_gemm_blackwell",
[runner],
runner.__class__.tuning_config,
inputs,
)

output = runner(inputs, tactic=best_tactic)
return output

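For callers, a hedged usage sketch of going through the dispatching nvfp4_gemm op (recommended by the docstring above) rather than this backend-specific entry point. The argument order and keyword names follow the pattern registered in ar_residual_norm.py; the keyword-argument form for to_userbuffers and backend is an assumption, and the input tensors are placeholders assumed to already be quantized into the packed FP4 / FP8-scale layout the op expects.

import torch

def nvfp4_gemm_auto(act_fp4: torch.Tensor, weight: torch.Tensor,
                    act_sf: torch.Tensor, weight_scale: torch.Tensor,
                    alpha: torch.Tensor) -> torch.Tensor:
    # backend="auto" lets TRT-LLM pick among cutlass / cublaslt / cutedsl.
    # NOTE: kwarg form assumed from the fx pattern; pass positionally if the op schema requires it.
    return torch.ops.trtllm.nvfp4_gemm(
        act_fp4, weight, act_sf, weight_scale, alpha,
        torch.bfloat16,          # output_dtype
        to_userbuffers=False,
        backend="auto",
    )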
Expand All @@ -323,8 +431,9 @@ def _(
mat_b: torch.Tensor,
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
alpha: float,
alpha: torch.Tensor, # Match custom op signature
output_dtype: torch.dtype,
to_userbuffers: bool = False,
):
# [m, k]
shape = list(mat_a.shape)
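The fake (meta) registration whose body is truncated above only needs to produce an output with the right shape and dtype so torch.compile can trace the op without running the kernel; a minimal sketch of that convention (hypothetical standalone function, not the exact TRT-LLM registration):

import torch

def cute_dsl_nvfp4_gemm_blackwell_fake(
    mat_a: torch.Tensor,         # [m, k_packed], FP4 packed in uint8
    mat_b: torch.Tensor,         # [n, k_packed], FP4 packed in uint8
    input_scale: torch.Tensor,
    weight_scale: torch.Tensor,
    alpha: torch.Tensor,
    output_dtype: torch.dtype,
    to_userbuffers: bool = False,
) -> torch.Tensor:
    # Output is [m, n] in the requested dtype; no computation is performed.
    m, n = mat_a.shape[0], mat_b.shape[0]
    return mat_a.new_empty((m, n), dtype=output_dtype)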