diff --git a/tensorrt_llm/_torch/autotuner.py b/tensorrt_llm/_torch/autotuner.py index ee1408b5bfd..54506678651 100644 --- a/tensorrt_llm/_torch/autotuner.py +++ b/tensorrt_llm/_torch/autotuner.py @@ -365,11 +365,14 @@ def search_cache( Returns: A tuple containing: [is_cache_hit, runner_id, tactic, stored_profile] + runner_id is the index in the current runners list """ - for r in runners: + for idx, r in enumerate(runners): if (cache_key := self.get_cache_key(custom_op, r, input_shapes, tuning_config)) in self.cache: - return True, *self.cache[cache_key] + # Return the current index in runners list, not the cached runner_id + cached_runner_id, tactic, min_time = self.cache[cache_key] + return True, idx, tactic, min_time return False, *self.fallback_entry() diff --git a/tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py b/tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py index afbaa0949df..bc42e761494 100644 --- a/tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py +++ b/tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py @@ -554,11 +554,24 @@ def target_scaled_mm_prologue_pattern( ) def register_nvfp4_gemm_prologue(custom_pass: PatternMatcherPass): + act_fp4_key = KeywordArg('act_fp4') + weight_key = KeywordArg('weight') + act_sf_key = KeywordArg('act_sf') + weight_scale_key = KeywordArg('weight_scale') + alpha_key = KeywordArg('alpha') + output_dtype_key = KeywordArg('output_dtype') + to_userbuffers_key = KeywordArg('to_userbuffers') + backend_key = KeywordArg('backend') trtllm_nvfp4_gemm_default = CallFunction( - torch.ops.trtllm.nvfp4_gemm.default, KeywordArg('act_fp4'), - KeywordArg('weight'), KeywordArg('act_sf'), - KeywordArg('weight_scale'), KeywordArg('alpha'), - KeywordArg('output_dtype')) + torch.ops.trtllm.nvfp4_gemm.default, + act_fp4_key, + weight_key, + act_sf_key, + weight_scale_key, + alpha_key, + output_dtype_key, + to_userbuffers=to_userbuffers_key, + backend=backend_key) ub_copy = CallFunction(torch.ops.trtllm.copy_to_userbuffers, trtllm_nvfp4_gemm_default) @@ -569,6 +582,8 @@ def empty_nvfp4_gemm_prologue_pattern( weight_scale: torch.Tensor, alpha: torch.Tensor, output_dtype: torch.dtype, + to_userbuffers: bool, + backend: str, ): return @@ -579,14 +594,28 @@ def target_nvfp4_gemm_prologue_pattern( weight_scale: torch.Tensor, alpha: torch.Tensor, output_dtype: torch.dtype, + to_userbuffers: bool, + backend: str, ): nvfp4_gemm_output = torch.ops.trtllm.nvfp4_gemm( act_fp4, weight, act_sf, weight_scale, alpha, output_dtype, - True) + True, backend) return nvfp4_gemm_output - # No extra check needed as the output dtype of nvfp4_gemm has been verified when - # ub_copy is inserted. 
+ def extra_check(match: Match) -> bool: + # Validate backend value + backend_value = match.kwargs.get('backend') + if backend_value is None: + # No backend specified, use default - OK + return True + + # backend should be a string literal + if not isinstance(backend_value, str): + return False + + valid_backends = {'auto', 'cutlass', 'cublaslt', 'cutedsl'} + return backend_value in valid_backends + register_replacement( empty_nvfp4_gemm_prologue_pattern, target_nvfp4_gemm_prologue_pattern, @@ -594,6 +623,7 @@ def target_nvfp4_gemm_prologue_pattern( fwd_only, custom_pass, search_fn_pattern=ub_copy, + extra_check=extra_check, ) def register_mm_prologue(custom_pass: PatternMatcherPass): diff --git a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py index 669c363e48b..e0c50eccc3d 100644 --- a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py @@ -3,6 +3,8 @@ import torch +from tensorrt_llm.logger import logger + from ..._utils import get_sm_version from ...math_utils import pad_up from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec, @@ -32,7 +34,7 @@ Sm100BlockScaledPersistentDenseGemmKernel from ..cute_dsl_kernels.blackwell.utils import make_ptr - class CuteDSLNVFP4BlackwellRunner(TunableRunner): + class CuteDSLNVFP4BlackwellLinear(TunableRunner): kernel_class = Sm100BlockScaledPersistentDenseGemmKernel kernel_cache = dict() tuning_config = TuningConfig( @@ -43,19 +45,28 @@ class CuteDSLNVFP4BlackwellRunner(TunableRunner): use_cold_l2_cache=True, ) - def __init__(self, alpha: float, output_dtype: torch.dtype): + def __init__(self, + output_dtype: torch.dtype, + to_userbuffers: bool = False): super().__init__() - self.alpha = alpha - self.output_dtype = output_dtype - assert output_dtype == torch.bfloat16 - if get_sm_version() not in [100, 103]: + if output_dtype != torch.bfloat16: raise ValueError( - f"SM version {get_sm_version()} is not supported for {self.__class__.__name__}, it only supports SM 100 and SM 103" + f"CuteDSL NVFP4 only supports bfloat16 output, got {output_dtype}" ) + self.output_dtype = output_dtype + self.to_userbuffers = to_userbuffers def unique_id(self): - return (self.output_dtype, ) + return (self.output_dtype, self.to_userbuffers) + + def __hash__(self): + return hash((self.output_dtype, self.to_userbuffers)) + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return False + return self.output_dtype == other.output_dtype and self.to_userbuffers == other.to_userbuffers def get_valid_tactics( self, @@ -63,6 +74,15 @@ def get_valid_tactics( profile: OptimizationProfile, **kwargs, ) -> List[Tuple[int, int]]: + # Early exit: Check SM version - CuteDSL NVFP4 only supports SM 100 and SM 103 + sm_version = get_sm_version() + if sm_version not in [100, 103]: + logger.debug( + f"CuteDSL: SM version {sm_version} is not supported. " + f"CuteDSL NVFP4 only supports SM 100 (B200) and SM 103 (B300). Skipping all tactics." + ) + return [] + assert inputs[0].dim() == 2 assert inputs[1].dim() == 2 @@ -73,11 +93,44 @@ def get_valid_tactics( real_k = k * 2 batch_size = 1 sf_vec_size = 16 - # m,k + + # Fixed layout for FP4: A and B are always K-major a_major = "k" - # n, k b_major = "k" + # Early exit: Check K dimension alignment + # For K-major layout (A and B tensors), K is the major mode (contiguous dimension). 
+ # 16-byte alignment requirement: K must be divisible by 32 for FP4 (128 bits / 4 bits = 32) + if real_k % 32 != 0: + logger.debug( + f"CuteDSL: K={real_k} does not meet 16-byte alignment requirement " + f"(K%32={real_k%32}, expected 0). Skipping all tactics.") + return [] + + # Optimize swap_ab candidates based on M and N alignment + # swap_ab=False → C is N-major → requires N%8==0 (BF16: 128 bits / 16 bits = 8) + # swap_ab=True → C is M-major → requires M%8==0 + m_aligned = (m % 8 == 0) + n_aligned = (n % 8 == 0) + + if not m_aligned and not n_aligned: + logger.debug( + f"CuteDSL: Neither M={m} nor N={n} meets 16-byte alignment " + f"(M%8={m%8}, N%8={n%8}). No valid C layout. Skipping all tactics." + ) + return [] + + # Only test swap_ab values that satisfy alignment + swap_ab_candidates = [] + if n_aligned: + swap_ab_candidates.append(False) # N-major layout + if m_aligned: + swap_ab_candidates.append(True) # M-major layout + + logger.debug( + f"CuteDSL: M={m}(aligned={m_aligned}), N={n}(aligned={n_aligned}), K={real_k}(aligned=True). " + f"Testing swap_ab={swap_ab_candidates}") + # full shamoo mma_tiler_mn_candidates = [ (128, 64), @@ -134,6 +187,9 @@ def get_valid_tactics( valid_tactics.append( (mma_tiler_mn, cluster_shape_mn, swap_ab, use_prefetch)) + logger.debug( + f"CuteDSL: Found {len(valid_tactics)} valid tactics for M={m}, N={n}, K={real_k}" + ) return valid_tactics def make_cute_dsl_global_pointer(self, tensor: torch.Tensor, dtype, @@ -149,6 +205,7 @@ def forward( self, inputs: List[torch.Tensor], tactic, + **kwargs, ) -> torch.Tensor: """ Performs fp8 blockwise gemm operation using CuTe DSL. @@ -160,8 +217,7 @@ def forward( inputs[2]: Input scale tensor of shape (k//16, m), dtype: fp8. inputs[3]: Weight scale tensor of shape (n, k//16), dtype: fp8. inputs[4]: Alpha scaling factor. dtype: float32. - inputs[5]: Output dtype, expected to be torch.bfloat16. - tactic: Tiling and cluster strategy, typically a tuple (mma_tiler_mn, cluster_shape_mn, swap_ab, use_prefetch). + tactic: Tiling and cluster strategy, typically a tuple (mma_tiler_mn, cluster_shape_mn). Returns: torch.Tensor: Output tensor of shape (m, n), dtype: bf16. @@ -179,11 +235,17 @@ def forward( False, ] - a_tensor, b_tensor, a_sf_tensor, b_sf_tensor = inputs + a_tensor, b_tensor, a_sf_tensor, b_sf_tensor, alpha_tensor = inputs m, k, n = a_tensor.shape[0], a_tensor.shape[1], b_tensor.shape[0] - c_tensor = torch.empty(*(m, n), - dtype=self.output_dtype, - device="cuda") + + # Allocate output tensor from UserBuffers or regular CUDA memory + if self.to_userbuffers: + c_tensor = torch.ops.trtllm.create_userbuffers_tensor( + [m, n], self.output_dtype) + else: + c_tensor = torch.empty(*(m, n), + dtype=self.output_dtype, + device="cuda") if swap_ab: c_tensor = c_tensor.permute(1, 0) @@ -193,9 +255,27 @@ def forward( sf_k = pad_up(real_k // sf_vec_size, 4) sf_n = pad_up(n, 128) - # the scaling tensor is 1D. we need to make sure it has been padded to the correct shape - assert a_sf_tensor.shape == (sf_m * sf_k, ) - assert b_sf_tensor.shape == (sf_n * sf_k, ) + # Reshape scale factors to CuteDSL's expected format + # Input format (from CUTLASS/cuBLASLt): (m*k//16,) and (n*k//16,) + # CuteDSL format: (sf_m*sf_k,) and (sf_n*sf_k,) + # Note: This is just a view change, no memory copy + expected_a_sf_size = sf_m * sf_k + expected_b_sf_size = sf_n * sf_k + + if a_sf_tensor.numel() != expected_a_sf_size: + raise ValueError( + f"CuteDSL: act scale factor size mismatch. 
" + f"Expected {expected_a_sf_size} (sf_m={sf_m} * sf_k={sf_k}), " + f"got {a_sf_tensor.numel()} for shape M={m}, K={real_k}") + if b_sf_tensor.numel() != expected_b_sf_size: + raise ValueError( + f"CuteDSL: weight scale factor size mismatch. " + f"Expected {expected_b_sf_size} (sf_n={sf_n} * sf_k={sf_k}), " + f"got {b_sf_tensor.numel()} for shape N={n}, K={real_k}") + + # Reshape to CuteDSL's expected format (just a view, no copy) + a_sf_tensor = a_sf_tensor.reshape(sf_m * sf_k) + b_sf_tensor = b_sf_tensor.reshape(sf_n * sf_k) a_ptr = self.make_cute_dsl_global_pointer(a_tensor, cutlass.Float4E2M1FN, 32) @@ -207,6 +287,9 @@ def forward( b_sf_tensor, cutlass.Float8E4M3FN, 16) c_ptr = self.make_cute_dsl_global_pointer(c_tensor, cutlass.BFloat16, 16) + # Create pointer to alpha on device + alpha_ptr = self.make_cute_dsl_global_pointer( + alpha_tensor, cutlass.Float32, 4) # get stream torch_stream = torch.cuda.current_stream() @@ -259,7 +342,7 @@ def forward( kernel_a_sf_ptr, kernel_b_sf_ptr, c_ptr, - self.alpha, + alpha_ptr, # Pass alpha as device pointer max_active_clusters, stream, swap_ab, @@ -283,7 +366,7 @@ def forward( kernel_a_sf_ptr, kernel_b_sf_ptr, c_ptr, - self.alpha, + alpha_ptr, # Pass alpha as device pointer stream, ) @@ -300,20 +383,45 @@ def cute_dsl_nvfp4_gemm_blackwell( weight: torch.Tensor, input_scale: torch.Tensor, weight_scale: torch.Tensor, - alpha: float, + alpha: torch.Tensor, output_dtype: torch.dtype, + to_userbuffers: bool = False, ) -> torch.Tensor: + """CuteDSL-based NVFP4 GEMM optimized for Blackwell. + + Args: + input: Activation tensor [m, k] in FP4 format (packed in uint8) + weight: Weight tensor [n, k] in FP4 format (packed in uint8) + input_scale: Activation scale factors + weight_scale: Weight scale factors + alpha: Scaling factor + output_dtype: Output data type (must be bfloat16) + to_userbuffers: Whether to allocate output from UserBuffers pool + + Note: + This function is primarily used internally by nvfp4_gemm. + Direct usage is discouraged. Consider using nvfp4_gemm instead + for automatic backend selection with better performance. + """ + # Validate SM version before attempting to use CuteDSL + sm_version = get_sm_version() + if sm_version not in [100, 103]: + raise ValueError( + f"CuteDSL NVFP4 backend requires SM 100 (B200) or SM 103 (B300), but got SM {sm_version}. " + f"Please use nvfp4_gemm with backend='auto' for automatic backend selection." 
+ ) tuner = AutoTuner.get() - runner = CuteDSLNVFP4BlackwellRunner(alpha, output_dtype) - inputs = [input, weight, input_scale, weight_scale] + runner = CuteDSLNVFP4BlackwellLinear(output_dtype, to_userbuffers) + inputs = [input, weight, input_scale, weight_scale, alpha] _, best_tactic = tuner.choose_one( "trtllm::cute_dsl_nvfp4_gemm_blackwell", [runner], runner.__class__.tuning_config, inputs, ) + output = runner(inputs, tactic=best_tactic) return output @@ -323,8 +431,9 @@ def _( mat_b: torch.Tensor, input_scale: torch.Tensor, weight_scale: torch.Tensor, - alpha: float, + alpha: torch.Tensor, # Match custom op signature output_dtype: torch.dtype, + to_userbuffers: bool = False, ): # [m, k] shape = list(mat_a.shape) diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py index f32a4aa27d2..d40c1fd5844 100644 --- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py @@ -1,5 +1,5 @@ from functools import lru_cache -from typing import List, Mapping, Optional, Tuple +from typing import List, Mapping, Optional, Tuple, Union import torch import triton # type: ignore[import] @@ -8,9 +8,12 @@ import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils from tensorrt_llm import deep_gemm from tensorrt_llm._utils import get_sm_version +from tensorrt_llm.logger import logger from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec, OptimizationProfile, TunableRunner, TuningConfig) +from ..cublaslt_utils import IS_CUBLASLT_AVAILABLE +from ..cute_dsl_utils import IS_CUTLASS_DSL_AVAILABLE from ..modules.multi_stream_utils import do_multi_stream from ..modules.swiglu import silu_and_mul_kernel from ..utils import (ActivationType, fp4_scale_infer_shape, @@ -515,6 +518,77 @@ def forward( return result +class CudaCoreNVFP4Runner(TunableRunner): + """ + CUDA Core-based NVFP4 GEMM runner. 
+ + This runner is available on: + - SM >= 100 (Blackwell) + - M <= 8 (small batch size limitation from kernel template) + """ + + # Shared tuning config (no tactics needed, single implementation) + tuning_config = TuningConfig() + + # Minimum supported architecture: SM100 (Blackwell) + MIN_SM_VERSION = 100 + # Maximum M dimension (from cudaCoreGemmTemplateMaxM in C++ kernel) + MAX_M_DIMENSION = 8 + + def __init__(self, to_userbuffers: bool, output_dtype: torch.dtype): + super().__init__() + self.to_userbuffers = to_userbuffers + self.output_dtype = output_dtype + + def get_valid_tactics(self, inputs: List[torch.Tensor], + profile: OptimizationProfile, **kwargs) -> List[int]: + """Return [0] if architecture and shape requirements are met, otherwise [].""" + # Check architecture support at runtime + if torch.cuda.is_available(): + capability = torch.cuda.get_device_capability( + torch.device('cuda:0')) + sm_version = capability[0] * 10 + capability[1] + if sm_version < self.MIN_SM_VERSION: + return [] + else: + return [] + + # Check M dimension limitation (kernel template constraint) + act_fp4, weight, act_sf, weight_scale, alpha = inputs + m = act_fp4.shape[0] + if m > self.MAX_M_DIMENSION: + return [] + + # Single tactic (no config variations) + return [0] + + def forward( + self, + inputs: List[torch.Tensor], + tactic: int = -1, + ) -> torch.Tensor: + act_fp4, weight, act_sf, weight_scale, alpha = inputs + + # Unswizzle the activation scale factors + # act_sf is swizzled, need to reverse it for cuda_core_nvfp4_gemm + m = act_fp4.shape[0] + act_sf_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse( + act_sf.view((m + 128 - 1) // 128 * 128, -1)) + + # Call CUDA Core NVFP4 GEMM + result = torch.ops.trtllm.cuda_core_nvfp4_gemm( + act_fp4, + weight, + scale_a=act_sf_unswizzled, + scale_b=weight_scale, + alpha=alpha, + bias=None, + out_dtype=self.output_dtype, + to_userbuffers=self.to_userbuffers, + ) + return result + + @torch.library.custom_op("trtllm::nvfp4_gemm_cublaslt", mutates_args=()) def nvfp4_gemm_cublaslt( act_fp4: torch.Tensor, @@ -525,7 +599,13 @@ def nvfp4_gemm_cublaslt( output_dtype: torch.dtype, to_userbuffers: bool = False, ) -> torch.Tensor: - """cuBLASLt-based NVFP4 GEMM with heuristic-based auto-tuning.""" + """cuBLASLt-based NVFP4 GEMM with heuristic-based auto-tuning. + + Note: + This function is primarily used internally by nvfp4_gemm. + Direct usage is discouraged. Consider using nvfp4_gemm instead + for automatic backend selection with better performance. + """ tuner = AutoTuner.get() # Use CublasLt runner with heuristic-based tuning @@ -562,8 +642,8 @@ def _( dtype=output_dtype) -@torch.library.custom_op("trtllm::nvfp4_gemm", mutates_args=()) -def nvfp4_gemm( +@torch.library.custom_op("trtllm::nvfp4_gemm_cutlass", mutates_args=()) +def nvfp4_gemm_cutlass( act_fp4: torch.Tensor, weight: torch.Tensor, act_sf: torch.Tensor, @@ -572,7 +652,13 @@ def nvfp4_gemm( output_dtype: torch.dtype, to_userbuffers: bool = False, ) -> torch.Tensor: - """CUTLASS-based NVFP4 GEMM with auto-tuning.""" + """CUTLASS-based NVFP4 GEMM with auto-tuning. + + Note: + This function is primarily used internally by nvfp4_gemm. + Direct usage is discouraged. Consider using nvfp4_gemm instead + for automatic backend selection with better performance. 
+ """ tuner = AutoTuner.get() # Use Cutlass runner with predefined configs @@ -592,6 +678,285 @@ def nvfp4_gemm( tactic=best_tactic) +@nvfp4_gemm_cutlass.register_fake +def _( + act_fp4: torch.Tensor, + weight: torch.Tensor, + act_sf: torch.Tensor, + weight_scale: torch.Tensor, + alpha: torch.Tensor, + output_dtype: torch.dtype, + to_userbuffers: bool = False, +) -> torch.Tensor: + return act_fp4.new_empty((act_fp4.size(0), weight.size(0)), + dtype=output_dtype) + + +class NVFP4GemmUnifiedRunner(TunableRunner): + runner_dict = dict() + + def __init__(self, + to_userbuffers: bool, + output_dtype: torch.dtype, + backend: str = "auto"): + super().__init__() + self.to_userbuffers = to_userbuffers + self.output_dtype = output_dtype + self.backend = backend + + def unique_id(self): + """Include backend in cache key to avoid sharing cache across backends.""" + return (self.to_userbuffers, self.output_dtype, self.backend) + + def get_valid_tactics(self, inputs: List[torch.Tensor], + profile: OptimizationProfile, + **kwargs) -> List[Tuple]: + # return valid nvfp4 gemm implementations + tactics = [] + act_fp4, weight, act_sf, weight_scale, alpha = inputs + backend = self.backend + + if backend in ["auto", "cuda_core"]: + is_cuda_core_supported = False + m = act_fp4.shape[0] + sm_version = None + + if torch.cuda.is_available(): + capability = torch.cuda.get_device_capability( + torch.device('cuda:0')) + sm_version = capability[0] * 10 + capability[1] + # Check both SM version and M dimension constraints + is_cuda_core_supported = ( + sm_version >= CudaCoreNVFP4Runner.MIN_SM_VERSION + and m <= CudaCoreNVFP4Runner.MAX_M_DIMENSION) + + if is_cuda_core_supported: + tactics.append("cuda_core") + elif backend == "cuda_core": + # Explicitly requested but conditions not met - raise error + error_msg = f"CUDA Core backend requires SM >= {CudaCoreNVFP4Runner.MIN_SM_VERSION} and M <= {CudaCoreNVFP4Runner.MAX_M_DIMENSION}. " + error_msg += f"Current: SM={sm_version if sm_version else 'N/A'}, M={m}. " + error_msg += "Please use backend='auto' or another backend." + raise ValueError(error_msg) + + # Add CUTLASS runner (always available) + if backend in ["auto", "cutlass"]: + tactics.append("cutlass") + + # Add cuBLASLt runner if available + if backend in ["auto", "cublaslt"]: + if IS_CUBLASLT_AVAILABLE: + tactics.append("cublaslt") + elif backend == "cublaslt": + raise ValueError( + "cuBLASLt backend is not available. " + "Please check cuBLASLt installation or use backend='auto'.") + + # Add CuteDSL runner if available + if backend in ["auto", "cutedsl"]: + if IS_CUTLASS_DSL_AVAILABLE: + # Check SM version first - CuteDSL NVFP4 only supports SM 100 (B200) + sm_version = get_sm_version() + if sm_version not in [100, 103]: + if backend == "cutedsl": + # Explicitly requested CuteDSL but SM version not supported + raise ValueError( + f"CuteDSL NVFP4 backend requires SM 100 (B200) or SM 103 (B300), but got SM {sm_version}. " + f"CuteDSL NVFP4 is not supported on this GPU architecture. " + f"Please use backend='auto' to automatically select a compatible backend." 
+ ) + # else: backend='auto' → silently skip CuteDSL + else: + # SM version OK, check if CuteDSL supports the current shape + from tensorrt_llm._torch.custom_ops.cute_dsl_custom_ops import \ + CuteDSLNVFP4BlackwellLinear + cutedsl_runner = CuteDSLNVFP4BlackwellLinear( + self.output_dtype) + cutedsl_tactics = cutedsl_runner.get_valid_tactics( + inputs, profile) + + if cutedsl_tactics: + # CuteDSL supports this shape + tactics.append("cutedsl") + elif backend == "cutedsl": + # Explicitly requested CuteDSL but it doesn't support this shape + m, n, k = inputs[0].shape[0], inputs[1].shape[ + 0], inputs[0].shape[1] * 2 + raise ValueError( + f"CuteDSL backend does not support the current shape:\n" + f" M={m}, N={n}, K={k}\n" + f"CuteDSL requires 16-byte alignment for major (contiguous) dimensions:\n" + f" - K must be divisible by 32 (FP4 K-major layout): K%32={'0✓' if k % 32 == 0 else str(k%32)+'✗'}\n" + f" - Or the combination of (M, N, K, tiling, cluster shape) is not supported\n" + f"Please use backend='auto' to automatically select a compatible backend." + ) + # else: backend='auto' and CuteDSL doesn't support shape → silently skip + elif backend == "cutedsl": + raise ValueError( + "CuteDSL backend is not available. " + "Please check CuteDSL installation or use backend='auto'.") + + return tactics + + def forward( + self, + inputs: List[torch.Tensor], + tactic: Union[ + str, int] = "cutlass", # str: backend name, or int: -1 for fallback + **kwargs, + ) -> torch.Tensor: + act_fp4, weight, act_sf, weight_scale, alpha = inputs + + requested_backend = self.backend + + # If a specific backend was requested (not 'auto') and we're using fallback tactic + # This can happen on cache miss, where AutoTuner uses tactic=-1 as default + if requested_backend != 'auto' and requested_backend != tactic and tactic == -1: + # User explicitly requested a backend, but we're falling back to default + # This might happen on cache miss. We should validate the requested backend supports this shape. + + # Get valid tactics for the requested backend + from tensorrt_llm._torch.autotuner import OptimizationProfile + valid_tactics = self.get_valid_tactics(inputs, + OptimizationProfile()) + + if not valid_tactics or requested_backend not in valid_tactics: + # Requested backend doesn't support this shape + m, n, k = inputs[0].shape[0], inputs[1].shape[ + 0], inputs[0].shape[1] * 2 + raise ValueError( + f"Backend '{requested_backend}' was explicitly requested but does not support the current shape:\n" + f" M={m}, N={n}, K={k}\n" + f"Please use backend='auto' to automatically select a compatible backend." 
+ ) + + # Backend supports it, use the requested backend instead of fallback + tactic = requested_backend + + if tactic == "cuda_core": + # Unswizzle the activation scale factors + # act_sf is swizzled, need to reverse it for cuda_core_nvfp4_gemm + m = act_fp4.shape[0] + act_sf_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse( + act_sf.view((m + 128 - 1) // 128 * 128, -1)) + + # Call CUDA Core NVFP4 GEMM + return torch.ops.trtllm.cuda_core_nvfp4_gemm( + act_fp4, + weight, + act_sf_unswizzled, + weight_scale, + alpha, + bias=None, + out_dtype=self.output_dtype, + to_userbuffers=self.to_userbuffers) + elif tactic == "cutlass": + return torch.ops.trtllm.nvfp4_gemm_cutlass(act_fp4, weight, act_sf, + weight_scale, alpha, + self.output_dtype, + self.to_userbuffers) + elif tactic == "cublaslt": + return torch.ops.trtllm.nvfp4_gemm_cublaslt(act_fp4, weight, act_sf, + weight_scale, alpha, + self.output_dtype, + self.to_userbuffers) + elif tactic == "cutedsl": + return torch.ops.trtllm.cute_dsl_nvfp4_gemm_blackwell( + act_fp4, weight, act_sf, weight_scale, alpha, self.output_dtype, + self.to_userbuffers) + elif tactic == -1: + return torch.ops.trtllm.nvfp4_gemm_cutlass(act_fp4, weight, act_sf, + weight_scale, alpha, + self.output_dtype, + self.to_userbuffers) + else: + raise ValueError(f"Invalid tactic: {tactic}") + + +@torch.library.custom_op("trtllm::nvfp4_gemm", mutates_args=()) +def nvfp4_gemm( + act_fp4: torch.Tensor, + weight: torch.Tensor, + act_sf: torch.Tensor, + weight_scale: torch.Tensor, + alpha: torch.Tensor, + output_dtype: torch.dtype, + to_userbuffers: bool = False, + backend: str = "auto", +) -> torch.Tensor: + """Unified NVFP4 GEMM with automatic or manual backend selection. + + This function can automatically choose the best backend or force a specific backend: + - CUTLASS: Predefined CUTLASS configurations with auto-tuning + - cuBLASLt: Heuristic-based algorithms from cuBLASLt library + - CuteDSL: Blackwell-optimized persistent kernels (when available and inputs are valid) + - CUDA Core: CUDA Core implementation (requires SM >= 100 and M <= 8) + + The AutoTuner profiles all available backends during the first run and caches + the best choice for each input shape. Subsequent calls use the cached selection + with zero overhead. In 'auto' mode, backends are only considered if their + requirements are met (e.g., CUDA Core only participates when SM >= 100 and M <= 8). + + Args: + act_fp4: Activation tensor [m, k] in FP4 format (packed in uint8) + weight: Weight tensor [n, k] in FP4 format (packed in uint8) + act_sf: Activation scale factors + weight_scale: Weight scale factors + alpha: Scaling factor (as torch.Tensor for CUTLASS/cuBLASLt compatibility) + output_dtype: Output data type + to_userbuffers: Whether to use user buffers (CUTLASS/cuBLASLt only) + backend: Backend selection, one of: + - 'auto': AutoTuner automatically selects best backend (default) + - 'cutlass': Force use CUTLASS (FP4GemmRunner) + - 'cublaslt': Force use cuBLASLt (CublasLtFP4GemmRunner) + - 'cutedsl': Force use CuteDSL (CuteDSLNVFP4Wrapper) + - 'cuda_core': Force use CUDA Core (CudaCoreNVFP4Runner, requires SM >= 100, M <= 8) + + Returns: + Output tensor [m, n] with dtype=output_dtype + + Raises: + ValueError: If backend is invalid/unavailable + """ + + # Validate backend parameter + valid_backends = ['auto', 'cutlass', 'cublaslt', 'cutedsl', 'cuda_core'] + if backend not in valid_backends: + raise ValueError( + f"Invalid backend '{backend}'. 
Must be one of {valid_backends}") + + # Build list of runners based on backend parameter + runner = NVFP4GemmUnifiedRunner(to_userbuffers, output_dtype, backend) + + # Use AutoTuner to select best runner and tactic + # - For 'auto' mode: compare across all backends, find global optimum + # - For forced backend: only one backend in list, but still find its best tactic + tuner = AutoTuner.get() + + try: + _, best_tactic = tuner.choose_one( + "trtllm::nvfp4_gemm::gemm", + [runner], + FP4GemmRunner. + tuning_config, # All runners use the same tuning_config + [act_fp4, weight, act_sf, weight_scale, alpha], + ) + except IndexError as e: + # Provide more helpful error message + logger.error( + f"shapes: M={act_fp4.shape[0]}, K={act_fp4.shape[1]*2}, N={weight.shape[0]}" + ) + raise RuntimeError( + f"AutoTuner failed to find a valid (runner, tactic) pair. " + f"Input shape: M={act_fp4.shape[0]}, K={act_fp4.shape[1]*2}, N={weight.shape[0]}" + ) from e + + return runner( + inputs=[act_fp4, weight, act_sf, weight_scale, alpha], + tactic=best_tactic, + ) + + @nvfp4_gemm.register_fake def _( act_fp4: torch.Tensor, @@ -601,7 +966,9 @@ def _( alpha: torch.Tensor, output_dtype: torch.dtype, to_userbuffers: bool = False, + backend: str = "auto", ) -> torch.Tensor: + """Fake implementation for torch.compile support.""" return act_fp4.new_empty((act_fp4.size(0), weight.size(0)), dtype=output_dtype) diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py index 11e8e74d7a1..6b6b427edca 100644 --- a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py +++ b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py @@ -321,7 +321,7 @@ def __call__( sfa_tensor: cute.Tensor, sfb_tensor: cute.Tensor, c_tensor: cute.Tensor, - alpha: cutlass.Float32, + alpha: cute.Tensor, # Single-element tensor containing alpha value max_active_clusters: cutlass.Constexpr, stream: cuda.CUstream, epilogue_op: cutlass.Constexpr = lambda x: x, @@ -607,11 +607,13 @@ def kernel( epi_tile: cute.Tile, tile_sched_params: utils.PersistentTileSchedulerParams, epilogue_op: cutlass.Constexpr, - alpha: cutlass.Float32, + alpha: cute.Tensor, ): """ GPU device kernel performing the Persistent batched GEMM computation. """ + alpha_value = alpha[0].to(self.c_dtype) + warp_idx = cute.arch.warp_idx() warp_idx = cute.arch.make_warp_uniform(warp_idx) @@ -1365,6 +1367,7 @@ def kernel( # subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt + for subtile_idx in cutlass.range(subtile_cnt): real_subtile_idx = subtile_idx if cutlass.const_expr(self.overlapping_accum): @@ -1392,8 +1395,8 @@ def kernel( # Convert to C type # acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() - acc_vec = epilogue_op( - alpha.to(self.c_dtype) * acc_vec.to(self.c_dtype)) + acc_vec = epilogue_op(alpha_value * + acc_vec.to(self.c_dtype)) tRS_rC.store(acc_vec) # @@ -2026,7 +2029,8 @@ def wrapper( a_sf_ptr: cute.Pointer, b_sf_ptr: cute.Pointer, c_ptr: cute.Pointer, - alpha: cutlass.Float32, + alpha: cute. + Pointer, # Device pointer to alpha, will be converted to Tensor max_active_clusters: cutlass.Constexpr, current_stream: cuda.CUstream, swap_ab: cutlass.Constexpr = False, @@ -2047,7 +2051,7 @@ def wrapper( a_sf_ptr (cute.Pointer): Pointer to the scale factor tensor for A. b_sf_ptr (cute.Pointer): Pointer to the scale factor tensor for B. 
c_ptr (cute.Pointer): Pointer to the C tensor. - alpha (cutlass.Float32): Scaling factor for the GEMM output. + alpha (cute.Pointer): Device pointer to alpha scaling factor (converted to Tensor internally). max_active_clusters (cutlass.Constexpr): Maximum number of active clusters. current_stream (cuda.CUstream): CUDA stream for the operation. @@ -2092,8 +2096,11 @@ def wrapper( (32, 4, sf_n, 4, sf_k, l), order=(2, 1, 4, 0, 3, 5), )) + alpha_tensor = cute.make_tensor(alpha, + layout=cute.make_ordered_layout( + (1, ), order=(0, ))) - self(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, alpha, + self(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, alpha_tensor, max_active_clusters, current_stream, epilogue_op) diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py index e24cf5f583c..1a1769bfde6 100644 --- a/tensorrt_llm/_torch/modules/linear.py +++ b/tensorrt_llm/_torch/modules/linear.py @@ -28,8 +28,6 @@ from ..._utils import get_sm_version, is_sm_100f from ...models.modeling_utils import QuantConfig -from ..cublaslt_utils import IS_CUBLASLT_AVAILABLE -from ..cute_dsl_utils import IS_CUTLASS_DSL_AVAILABLE from ..utils import Fp4QuantizedTensor, unswizzle_sf @@ -916,32 +914,15 @@ def apply(self, module: Linear, input: torch.Tensor, act_fp4, act_sf = torch.ops.trtllm.fp4_quantize( input, module.input_scale, module.scaling_vector_size, False) - if IS_CUTLASS_DSL_AVAILABLE and module.use_cute_dsl_nvfp4_blockscaling_mm: - output = torch.ops.trtllm.cute_dsl_nvfp4_gemm_blackwell( - act_fp4, module.weight, act_sf, module.weight_scale, - module.scalar_alpha, module.dtype) - elif IS_CUBLASLT_AVAILABLE and module.use_cublaslt_nvfp4_blockscaling_mm: - output = torch.ops.trtllm.nvfp4_gemm_cublaslt( - act_fp4, module.weight, act_sf, module.weight_scale, - module.alpha, module.dtype) - else: - if module.enable_cuda_core and act_fp4.shape[0] <= 8: - act_sf_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse( - act_sf.view((act_fp4.shape[0] + 128 - 1) // 128 * 128, -1)) - output = torch.ops.trtllm.cuda_core_nvfp4_gemm( - act_fp4, - module.weight, - scale_a=act_sf_unswizzled, - scale_b=module.weight_scale, - alpha=module.alpha, - bias=None, - out_dtype=module.dtype or input.dtype, - ) - else: - output = torch.ops.trtllm.nvfp4_gemm(act_fp4, module.weight, - act_sf, - module.weight_scale, - module.alpha, module.dtype) + # Use unified interface - supports CUTLASS, cuBLASLt, CuteDSL + output = torch.ops.trtllm.nvfp4_gemm(act_fp4, + module.weight, + act_sf, + module.weight_scale, + module.alpha, + module.dtype, + to_userbuffers=False, + backend=module.nvfp4_backend) # Take the dim of out_features if padded. Make sure the output is contiguous if output.shape[-1] > module.out_features: output = output[..., :module.out_features].contiguous() @@ -2012,11 +1993,17 @@ def __init__( allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO, force_dynamic_quantization: bool = False, use_cute_dsl_blockscaling_mm: bool = False, - use_cute_dsl_nvfp4_blockscaling_mm: bool = False, - use_cublaslt_nvfp4_blockscaling_mm: bool = False, disable_deep_gemm: bool = False, fused_weight_shard_indices_mapping: Optional[dict] = None, + nvfp4_backend: str = "auto", ): + """ + Args: + nvfp4_backend: Backend selection for NVFP4 GEMM operations. + Supported values: "auto", "cutlass", "cublaslt", "cutedsl". + Default is "auto" which automatically selects the best backend. + Can be overridden via TRTLLM_NVFP4_GEMM_BACKEND environment variable. 
+ """ from ..distributed import AllReduce super().__init__() @@ -2033,11 +2020,24 @@ def __init__( self.gather_output = gather_output self.force_dynamic_quantization = force_dynamic_quantization self.use_cute_dsl_blockscaling_mm = use_cute_dsl_blockscaling_mm - self.use_cute_dsl_nvfp4_blockscaling_mm = use_cute_dsl_nvfp4_blockscaling_mm - self.use_cublaslt_nvfp4_blockscaling_mm = use_cublaslt_nvfp4_blockscaling_mm self.disable_deep_gemm = disable_deep_gemm self.fused_weight_shard_indices_mapping = fused_weight_shard_indices_mapping + # Support environment variable override for nvfp4_backend + nvfp4_backend_value = os.environ.get('TRTLLM_NVFP4_GEMM_BACKEND', + nvfp4_backend) + + # Validate backend selection + valid_backends = {'auto', 'cutlass', 'cublaslt', 'cutedsl'} + if nvfp4_backend_value not in valid_backends: + raise ValueError( + f"Invalid nvfp4_backend: '{nvfp4_backend_value}'. " + f"Supported values are: {', '.join(sorted(valid_backends))}. " + f"Set via constructor argument or TRTLLM_NVFP4_GEMM_BACKEND environment variable." + ) + + self.nvfp4_backend = nvfp4_backend_value + local_in_features = in_features local_out_features = out_features diff --git a/tests/unittest/_torch/thop/parallel/test_fp4_linear.py b/tests/unittest/_torch/thop/parallel/test_fp4_linear.py index bc78185e54e..a549b52fa4e 100644 --- a/tests/unittest/_torch/thop/parallel/test_fp4_linear.py +++ b/tests/unittest/_torch/thop/parallel/test_fp4_linear.py @@ -39,7 +39,8 @@ def test_fp4_linear(dtype, mnk): out_features=OUTPUT_SIZE, bias=False, dtype=dtype, - quant_config=qc) + quant_config=qc, + nvfp4_backend='cutlass') # Force CUTLASS to match reference assert l_fp4.weight.dtype == fp4_utils.float4_e2m1x2 assert l_fp4.weight_scale.dtype == fp4_utils.float4_sf_dtype @@ -116,7 +117,7 @@ def test_fp4_linear_cute_dsl(dtype, mnk): bias=False, dtype=dtype, quant_config=qc, - use_cute_dsl_nvfp4_blockscaling_mm=True) + nvfp4_backend='cutedsl') assert l_fp4.weight.dtype == fp4_utils.float4_e2m1x2 assert l_fp4.weight_scale.dtype == fp4_utils.float4_sf_dtype @@ -179,7 +180,7 @@ def fp4_linear_perf_test(dtype, SEQ_LEN, OUTPUT_SIZE, HIDDEN_SIZE): bias=False, dtype=dtype, quant_config=qc, - use_cute_dsl_nvfp4_blockscaling_mm=True) + nvfp4_backend='cutedsl') assert l_fp4.weight.dtype == fp4_utils.float4_e2m1x2 assert l_fp4.weight_scale.dtype == fp4_utils.float4_sf_dtype @@ -214,7 +215,7 @@ def fp4_linear_perf_test(dtype, SEQ_LEN, OUTPUT_SIZE, HIDDEN_SIZE): bias=False, dtype=dtype, quant_config=qc, - use_cute_dsl_nvfp4_blockscaling_mm=False) + nvfp4_backend='cutlass') # Use CUTLASS as reference assert l_fp4_ref.weight.dtype == fp4_utils.float4_e2m1x2 assert l_fp4_ref.weight_scale.dtype == fp4_utils.float4_sf_dtype @@ -324,7 +325,7 @@ def nvfp4_gemm_perf_test( f"ref tune, m={SEQ_LEN}, k={HIDDEN_SIZE}, n={OUTPUT_SIZE}", color="orange"): with torch.inference_mode(), autotune(): - output_ref = torch.ops.trtllm.nvfp4_gemm( + output_ref = torch.ops.trtllm.nvfp4_gemm_cutlass( x_fp4, w_fp4, x_sf_block, w_sf_block, alpha_tensor, dtype) torch.testing.assert_close(output, output_ref) print(f"PASSED") @@ -367,7 +368,7 @@ def nvfp4_gemm_perf_test( f"ref warmup, m={SEQ_LEN}, k={HIDDEN_SIZE}, n={OUTPUT_SIZE}", color="red"): for _ in range(warmup_iterations): - output_ref = torch.ops.trtllm.nvfp4_gemm( + output_ref = torch.ops.trtllm.nvfp4_gemm_cutlass( x_fp4_list[buffer_idx % workspace_count], w_fp4_list[buffer_idx % workspace_count], x_sf_block_list[buffer_idx % workspace_count], @@ -380,7 +381,7 @@ def nvfp4_gemm_perf_test( f"ref run, 
m={SEQ_LEN}, k={HIDDEN_SIZE}, n={OUTPUT_SIZE}", color="red"): for i in range(iterations): - output_ref = torch.ops.trtllm.nvfp4_gemm( + output_ref = torch.ops.trtllm.nvfp4_gemm_cutlass( x_fp4_list[buffer_idx % workspace_count], w_fp4_list[buffer_idx % workspace_count], x_sf_block_list[buffer_idx % workspace_count], @@ -391,6 +392,243 @@ def nvfp4_gemm_perf_test( buffer_idx = buffer_idx + 1 +@skip_pre_blackwell +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize( + "mnk", + [ + # Small batch sizes (M <= 16) - test small M handling + (1, 4096, 4096, "Batch=1, Square 4K"), + (4, 4096, 4096, "Batch=4, Square 4K"), + (16, 4096, 4096, "Batch=16, Square 4K"), + + # Odd M values + (3, 4096, 4096, "Odd M: M=3"), + (7, 4096, 4096, "Odd M: M=7"), + (9, 4096, 4096, "Odd M: M=9"), + + # Medium batch sizes - common inference scenarios + (128, 4096, 4096, "Batch=128, Square 4K"), + (128, 7168, 16384, "Batch=128, Large K/N"), + (128, 4096, 7168, "Batch=128, Asymmetric"), + + # Large batch sizes - training scenarios + (512, 4096, 4096, "Batch=512, Square 4K"), + (1024, 4096, 4096, "Batch=1024, Square 4K"), + + # Very large batch - maximum performance + (2048, 4096, 4096, "Batch=2048, Square 4K"), + (4096, 4096, 4096, "Batch=4096, Square 4K"), + + # Large K and N - test memory bandwidth + (128, 8192, 8192, "Batch=128, Square 8K"), + (256, 16384, 16384, "Batch=256, Square 16K"), + + # Size asymmetry tests + (1024, 128, 4096, "Wide M: M >> N"), + (128, 16384, 128, "Wide N: N >> K"), + ]) +def test_nvfp4_gemm_unified_all_tactics(dtype, mnk): + """Test nvfp4_gemm with auto backend selection, ensuring all tactics are tested.""" + from tensorrt_llm._torch.autotuner import AutoTuner, autotune + from tensorrt_llm._torch.cublaslt_utils import IS_CUBLASLT_AVAILABLE + + # Unpack mnk with optional description + if len(mnk) == 4: + SEQ_LEN, OUTPUT_SIZE, HIDDEN_SIZE, desc = mnk + else: + SEQ_LEN, OUTPUT_SIZE, HIDDEN_SIZE = mnk + desc = f"M={SEQ_LEN}, K={HIDDEN_SIZE}, N={OUTPUT_SIZE}" + torch.manual_seed(0) + + x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype).cuda() + x_sf_global = (448 * 6) / x.abs().max().float() + + w = torch.randn((OUTPUT_SIZE, HIDDEN_SIZE), dtype=dtype).cuda() + w_sf_global = (448 * 6) / w.abs().max().float() + w_fp4, w_sf_block = torch.ops.trtllm.fp4_quantize(w, w_sf_global, + scaling_vector_size, + False) + + # Prepare input + with torch.inference_mode(): + x_fp4, x_sf_block = torch.ops.trtllm.fp4_quantize( + x, x_sf_global, scaling_vector_size, False) + alpha_ref = 1.0 / (w_sf_global * x_sf_global) + alpha_tensor = torch.tensor(alpha_ref, dtype=torch.float32).cuda() + + # Reference: Use CUTLASS backend explicitly for reference output + with torch.inference_mode(): + output_ref = torch.ops.trtllm.nvfp4_gemm(act_fp4=x_fp4, + weight=w_fp4, + act_sf=x_sf_block, + weight_scale=w_sf_block, + alpha=alpha_tensor, + output_dtype=dtype, + to_userbuffers=False, + backend='cutlass') + + # Test auto backend selection with autotuning + with torch.inference_mode(), autotune(): + output_auto = torch.ops.trtllm.nvfp4_gemm(act_fp4=x_fp4, + weight=w_fp4, + act_sf=x_sf_block, + weight_scale=w_sf_block, + alpha=alpha_tensor, + output_dtype=dtype, + to_userbuffers=False, + backend='auto') + + AutoTuner.get().print_profiling_cache() + + # Verify auto mode result matches reference + torch.cuda.synchronize() + torch.testing.assert_close(output_auto, output_ref, rtol=1e-2, atol=0.15) + + # Test all combinations of outer layer (backend selection) and inner layer (backend tactics) + # Outer 
layer: nvfp4_gemm selects backend + # Inner layer: each backend has its own tactics + from collections import defaultdict + + print(f"\n{'='*80}") + print(f"Testing nvfp4_gemm (2-layer tactics): {desc}") + print(f"Shape: M={SEQ_LEN}, K={HIDDEN_SIZE}, N={OUTPUT_SIZE}") + print(f"{'='*80}") + + print(f"\n[Outer Layer] Capturing backend selection tactics...") + with AutoTuner.get().capture() as outer_capture, torch.inference_mode(): + output = torch.ops.trtllm.nvfp4_gemm(act_fp4=x_fp4, + weight=w_fp4, + act_sf=x_sf_block, + weight_scale=w_sf_block, + alpha=alpha_tensor, + output_dtype=dtype, + to_userbuffers=False, + backend='auto') + + outer_tactics_list = list(outer_capture) + print(f" Found {len(outer_tactics_list)} outer layer tactics (backends)") + + # Parse outer tactics to get backend names + backend_map = {} + for outer_tactic in outer_tactics_list: + outer_runner, backend_name = outer_tactic[0] + backend_map[backend_name] = outer_tactic + print(f" - Backend: {backend_name}") + + print(f"\n[Inner Layer] Testing tactics for each backend...") + + # All backends have independent APIs, but cuda_core needs special handling, because it requires unswizzled scale factors + backend_apis = {} + if IS_CUTLASS_DSL_AVAILABLE: + if 'cutlass' in backend_map: + backend_apis['cutlass'] = torch.ops.trtllm.nvfp4_gemm_cutlass + if IS_CUBLASLT_AVAILABLE: + if 'cublaslt' in backend_map: + backend_apis['cublaslt'] = torch.ops.trtllm.nvfp4_gemm_cublaslt + if IS_CUTLASS_DSL_AVAILABLE: + if 'cutedsl' in backend_map: + backend_apis[ + 'cutedsl'] = torch.ops.trtllm.cute_dsl_nvfp4_gemm_blackwell + + # cuda_core needs special handling (different parameters, single tactic) + test_cuda_core = 'cuda_core' in backend_map + + # Step 3: For each backend, capture and immediately test all tactics + # Must test immediately after capture to avoid _last_capture being overwritten + tactics_by_backend = defaultdict(list) + total_tactics_tested = 0 + + for backend_name, backend_api in backend_apis.items(): + print(f"\n Backend: {backend_name}") + + # Capture inner tactics for this backend + with AutoTuner.get().capture() as inner_capture, torch.inference_mode(): + output = backend_api( + x_fp4, # input/act_fp4 + w_fp4, # weight + x_sf_block, # input_scale/act_sf + w_sf_block, # weight_scale + alpha_tensor, # alpha + dtype # output_dtype + ) + + inner_tactics_list = list(inner_capture) + print(f" Found {len(inner_tactics_list)} inner tactics") + + # Verify tactics uniqueness (ensure we're testing different tactics, not repeating the same one) + tactic_values = [t[0][1] for t in inner_tactics_list] + unique_tactics = len(set(tactic_values)) + assert len(tactic_values) == unique_tactics, \ + f"Duplicate tactics detected! 
Total: {len(tactic_values)}, Unique: {unique_tactics}" + + # Test each tactic immediately (while _last_capture is still valid) + for tactic_idx, inner_tactic in enumerate(inner_tactics_list): + inner_runner, inner_tactic_value = inner_tactic[0] + runner_name = inner_runner.__class__.__name__ + + # Replay this tactic + with AutoTuner.get().replay(inner_tactic), torch.inference_mode(): + # Call backend API directly (using positional args) + output = backend_api( + x_fp4, # input/act_fp4 + w_fp4, # weight + x_sf_block, # input_scale/act_sf + w_sf_block, # weight_scale + alpha_tensor, # alpha + dtype # output_dtype + ) + + # Verify correctness + torch.testing.assert_close(output, + output_ref, + rtol=1e-2, + atol=0.15) + + total_tactics_tested += 1 + tactics_by_backend[runner_name].append(total_tactics_tested) + print(f" ✓ Tactic {tactic_idx+1}/{len(inner_tactics_list)}: " + f"{runner_name} tactic={inner_tactic_value} - PASSED") + + # Step 4: Test cuda_core if it's available (single tactic, no capture needed) + if test_cuda_core: + print(f"\n Backend: cuda_core") + print(f" Found 1 tactic (single implementation, no autotuning)") + + with torch.inference_mode(): + output_cuda_core = torch.ops.trtllm.nvfp4_gemm( + act_fp4=x_fp4, + weight=w_fp4, + act_sf=x_sf_block, + weight_scale=w_sf_block, + alpha=alpha_tensor, + output_dtype=dtype, + to_userbuffers=False, + backend='cuda_core') + + torch.testing.assert_close(output_cuda_core, + output_ref, + rtol=1e-2, + atol=0.15) + + total_tactics_tested += 1 + tactics_by_backend['CudaCoreNVFP4Runner'].append(total_tactics_tested) + print(f" ✓ Tactic 1/1: CudaCoreNVFP4Runner tactic=0 - PASSED") + + print(f"\n{'='*80}") + print(f"All {total_tactics_tested} tactics verified successfully!") + print(f"\nBreakdown by backend:") + for runner_name, indices in tactics_by_backend.items(): + print(f" - {runner_name}: {len(indices)} tactics") + if test_cuda_core: + print(f"\n Note: cuda_core has no autotuning (single tactic)") + print(f" Note: Tested all inner layer tactics for each backend") + print( + f" Outer layer (backend selection) was tested separately with backend='auto'" + ) + print(f"{'='*80}\n") + + @pytest.mark.skipif( get_sm_version() not in [100, 103], reason="This test is only supported in Blackwell architecture", @@ -400,7 +638,7 @@ def nvfp4_gemm_perf_test( (128, 2112, 7168), (128, 4096, 7168), (128, 7168, 2048), [127, 1024, 3200]]) def test_fp4_linear_cublaslt(dtype, mnk): - """Test cuBLASLt FP4 GEMM implementation and compare with nvfp4_gemm""" + """Test cuBLASLt FP4 GEMM implementation and compare with nvfp4_gemm_cutlass""" from tensorrt_llm._torch.cublaslt_utils import IS_CUBLASLT_AVAILABLE if not IS_CUBLASLT_AVAILABLE: pytest.skip("cuBLASLt FP4 GEMM not available in this build") @@ -434,17 +672,76 @@ def test_fp4_linear_cublaslt(dtype, mnk): alpha=alpha_tensor, output_dtype=dtype) - # Reference implementation: use torch.ops.trtllm.nvfp4_gemm (CUTLASS) + # Reference implementation: use torch.ops.trtllm.nvfp4_gemm_cutlass (CUTLASS) with torch.inference_mode(): - output_cutlass = torch.ops.trtllm.nvfp4_gemm(x_fp4, w_fp4, x_sf_block, - w_sf_block, alpha_ref, - dtype) + output_cutlass = torch.ops.trtllm.nvfp4_gemm_cutlass( + x_fp4, w_fp4, x_sf_block, w_sf_block, alpha_ref, dtype) # Compare results torch.cuda.synchronize() torch.testing.assert_close(output_cublaslt, output_cutlass) +@pytest.mark.skipif( + get_sm_version() < 100, + reason="CUDA Core backend requires SM >= 100 (Blackwell or newer)", +) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) 
+@pytest.mark.parametrize("mnk", [(1, 4096, 7168), (4, 7168, 16384), + (8, 2112, 7168)]) +def test_fp4_linear_cuda_core(dtype, mnk): + """Test CUDA Core NVFP4 GEMM implementation on SM >= 100 (M <= 8)""" + + SEQ_LEN, OUTPUT_SIZE, HIDDEN_SIZE = mnk + torch.manual_seed(0) + + x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype).cuda() + x_sf_global = (448 * 6) / x.abs().max().float() + + w = torch.randn((OUTPUT_SIZE, HIDDEN_SIZE), dtype=dtype).cuda() + w_sf_global = (448 * 6) / w.abs().max().float() + w_fp4, w_sf_block = torch.ops.trtllm.fp4_quantize(w, w_sf_global, + scaling_vector_size, + False) + + with torch.inference_mode(): + x_fp4, x_sf_block = torch.ops.trtllm.fp4_quantize( + x, x_sf_global, scaling_vector_size, False) + + alpha_ref = 1.0 / (w_sf_global * x_sf_global) + alpha_tensor = torch.tensor(alpha_ref, dtype=torch.float32).cuda() + + # Reference: Use CUTLASS backend + output_ref = torch.ops.trtllm.nvfp4_gemm(act_fp4=x_fp4, + weight=w_fp4, + act_sf=x_sf_block, + weight_scale=w_sf_block, + alpha=alpha_tensor, + output_dtype=dtype, + to_userbuffers=False, + backend='cutlass') + + # Test CUDA Core backend + output_cuda_core = torch.ops.trtllm.nvfp4_gemm(act_fp4=x_fp4, + weight=w_fp4, + act_sf=x_sf_block, + weight_scale=w_sf_block, + alpha=alpha_tensor, + output_dtype=dtype, + to_userbuffers=False, + backend='cuda_core') + + # Compare results + torch.cuda.synchronize() + torch.testing.assert_close(output_cuda_core, + output_ref, + rtol=1e-2, + atol=0.15) + print( + f"✓ CUDA Core test passed for M={SEQ_LEN}, N={OUTPUT_SIZE}, K={HIDDEN_SIZE}" + ) + + if __name__ == "__main__": # m, n, k fp4_linear_perf_test(torch.bfloat16, 128, 7168, 16384)
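Editor's note: for reviewers trying the new unified entry point locally, below is a minimal sketch assembled from the call patterns in the tests above. The shapes, the 448*6 global-scale heuristic, and scaling_vector_size=16 are illustrative assumptions taken from the test file, not requirements of the op, and the import used to register the trtllm custom ops is an assumption about how they get loaded.

import torch
import tensorrt_llm  # assumed sufficient to register the torch.ops.trtllm custom ops

m, n, k = 128, 4096, 4096          # illustrative shapes
scaling_vector_size = 16           # NVFP4 block size used throughout the tests
dtype = torch.bfloat16

x = torch.randn((m, k), dtype=dtype, device="cuda")
w = torch.randn((n, k), dtype=dtype, device="cuda")
x_sf_global = (448 * 6) / x.abs().max().float()
w_sf_global = (448 * 6) / w.abs().max().float()

x_fp4, x_sf_block = torch.ops.trtllm.fp4_quantize(x, x_sf_global,
                                                  scaling_vector_size, False)
w_fp4, w_sf_block = torch.ops.trtllm.fp4_quantize(w, w_sf_global,
                                                  scaling_vector_size, False)
alpha = (1.0 / (w_sf_global * x_sf_global)).to(torch.float32)

# backend='auto' lets the AutoTuner profile the eligible backends (cutlass,
# cublaslt, cutedsl, cuda_core) and cache the winner per shape; passing a
# specific backend forces that implementation and raises if it cannot handle
# the current shape or SM version.
out = torch.ops.trtllm.nvfp4_gemm(x_fp4, w_fp4, x_sf_block, w_sf_block,
                                  alpha, dtype, to_userbuffers=False,
                                  backend="auto")

Tuning is two-layered: the outer NVFP4GemmUnifiedRunner treats the backend name as its tactic, and the chosen backend op (e.g. nvfp4_gemm_cutlass or nvfp4_gemm_cublaslt) then runs its own tactic search; test_nvfp4_gemm_unified_all_tactics above exercises both layers.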
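Editor's note: the CuteDSL eligibility rules described by get_valid_tactics and by the unified runner's error messages can be summarized as a small predicate. This is a sketch of the alignment checks only, not the authoritative filter; a shape that passes may still be rejected for tiler/cluster-shape reasons.

def cutedsl_nvfp4_shape_ok(m: int, n: int, k: int) -> bool:
    # K-major FP4 operands need 16-byte (128-bit) alignment on the contiguous
    # dimension: 128 bits / 4 bits per element = 32 elements, so K % 32 == 0.
    if k % 32 != 0:
        return False
    # The BF16 output is either N-major (swap_ab=False, needs N % 8 == 0) or
    # M-major (swap_ab=True, needs M % 8 == 0); at least one layout must fit.
    return (m % 8 == 0) or (n % 8 == 0)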
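Editor's note: the Linear-level backend choice can also be steered without touching call sites. Per the __init__ change above, the environment variable takes precedence over the constructor's nvfp4_backend argument and is validated against the same set of names; the snippet below only restates that behavior.

import os

# Must be one of 'auto', 'cutlass', 'cublaslt', 'cutedsl'; anything else makes
# Linear.__init__ raise ValueError. When set, this overrides the nvfp4_backend
# value passed to the constructor.
os.environ["TRTLLM_NVFP4_GEMM_BACKEND"] = "cublaslt"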