diff --git a/requirements.txt b/requirements.txt index aaf2884f3d4..e123aafcdee 100644 --- a/requirements.txt +++ b/requirements.txt @@ -73,3 +73,5 @@ nvidia-cutlass-dsl==4.3.1; python_version >= "3.10" plotly numexpr<2.14.0 # WAR for attempted use of nonexistent numpy.typing partial_json_parser +apache-tvm-ffi==0.1.4 # used for reduce nvidia-cutlass-dsl host overhead +torch-c-dlpack-ext==0.1.3 # used for reduce nvidia-cutlass-dsl host overhead, optional package for improved torch tensor calling perf diff --git a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py index 842c48725f8..c1881a22e6a 100644 --- a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py @@ -240,7 +240,8 @@ class CuteDSLNVFP4BlackwellLinear(TunableRunner): def __init__(self, output_dtype: torch.dtype, - to_userbuffers: bool = False): + to_userbuffers: bool = False, + use_tvm_ffi: bool = True): super().__init__() if output_dtype != torch.bfloat16: @@ -249,17 +250,19 @@ def __init__(self, ) self.output_dtype = output_dtype self.to_userbuffers = to_userbuffers + self.use_tvm_ffi = use_tvm_ffi def unique_id(self): - return (self.output_dtype, self.to_userbuffers) + return (self.output_dtype, self.to_userbuffers, self.use_tvm_ffi) def __hash__(self): - return hash((self.output_dtype, self.to_userbuffers)) + return hash( + (self.output_dtype, self.to_userbuffers, self.use_tvm_ffi)) def __eq__(self, other): if not isinstance(other, self.__class__): return False - return self.output_dtype == other.output_dtype and self.to_userbuffers == other.to_userbuffers + return self.output_dtype == other.output_dtype and self.to_userbuffers == other.to_userbuffers and self.use_tvm_ffi == other.use_tvm_ffi def get_valid_tactics( self, @@ -465,51 +468,94 @@ def forward( f"CuteDSL: weight scale factor size mismatch. " f"Expected {expected_b_sf_size} (sf_n={sf_n} * sf_k={sf_k}), " f"got {b_sf_tensor.numel()} for shape N={n}, K={real_k}") + if alpha_tensor.numel() != 1: + raise ValueError(f"CuteDSL: alpha size mismatch. " + f"Expected 1, got {alpha_tensor.numel()}") # Reshape to CuteDSL's expected format (just a view, no copy) a_sf_tensor = a_sf_tensor.reshape(sf_m * sf_k) b_sf_tensor = b_sf_tensor.reshape(sf_n * sf_k) - a_ptr = self.make_cute_dsl_global_pointer(a_tensor, - cutlass.Float4E2M1FN, 32) - b_ptr = self.make_cute_dsl_global_pointer(b_tensor, - cutlass.Float4E2M1FN, 32) - a_sf_ptr = self.make_cute_dsl_global_pointer( - a_sf_tensor, cutlass.Float8E4M3FN, 16) - b_sf_ptr = self.make_cute_dsl_global_pointer( - b_sf_tensor, cutlass.Float8E4M3FN, 16) - c_ptr = self.make_cute_dsl_global_pointer(c_tensor, - cutlass.BFloat16, 16) - # Create pointer to alpha on device - alpha_ptr = self.make_cute_dsl_global_pointer( - alpha_tensor, cutlass.Float32, 4) - - # get stream - torch_stream = torch.cuda.current_stream() - stream = cuda.CUstream(torch_stream.cuda_stream) + if not self.use_tvm_ffi: + a_ptr = self.make_cute_dsl_global_pointer( + a_tensor, cutlass.Float4E2M1FN, 32) + b_ptr = self.make_cute_dsl_global_pointer( + b_tensor, cutlass.Float4E2M1FN, 32) + a_sf_ptr = self.make_cute_dsl_global_pointer( + a_sf_tensor, cutlass.Float8E4M3FN, 16) + b_sf_ptr = self.make_cute_dsl_global_pointer( + b_sf_tensor, cutlass.Float8E4M3FN, 16) + c_ptr = self.make_cute_dsl_global_pointer( + c_tensor, cutlass.BFloat16, 16) + alpha_cute_tensor = cute.runtime.from_dlpack(alpha_tensor) + + # get stream + torch_stream = torch.cuda.current_stream() + stream = cuda.CUstream(torch_stream.cuda_stream) cache_key = (sf_vec_size, mma_tiler_mn, cluster_shape_mn, swap_ab, use_prefetch) if swap_ab: - kernel_a_ptr = b_ptr - kernel_a_sf_ptr = b_sf_ptr - kernel_b_ptr = a_ptr - kernel_b_sf_ptr = a_sf_ptr kernel_m = n kernel_n = m kernel_sf_m = sf_n kernel_sf_n = sf_m + + kernel_a_tensor = b_tensor + kernel_a_sf_tensor = b_sf_tensor + kernel_b_tensor = a_tensor + kernel_b_sf_tensor = a_sf_tensor + + if not self.use_tvm_ffi: + kernel_a_ptr = b_ptr + kernel_a_sf_ptr = b_sf_ptr + kernel_b_ptr = a_ptr + kernel_b_sf_ptr = a_sf_ptr else: - kernel_a_ptr = a_ptr - kernel_a_sf_ptr = a_sf_ptr - kernel_b_ptr = b_ptr - kernel_b_sf_ptr = b_sf_ptr kernel_m = m kernel_n = n kernel_sf_m = sf_m kernel_sf_n = sf_n + kernel_a_tensor = a_tensor + kernel_a_sf_tensor = a_sf_tensor + kernel_b_tensor = b_tensor + kernel_b_sf_tensor = b_sf_tensor + + if not self.use_tvm_ffi: + kernel_a_ptr = a_ptr + kernel_a_sf_ptr = a_sf_ptr + kernel_b_ptr = b_ptr + kernel_b_sf_ptr = b_sf_ptr + if cache_key not in self.__class__.kernel_cache: + if self.use_tvm_ffi: + a_ptr = self.make_cute_dsl_global_pointer( + a_tensor, cutlass.Float4E2M1FN, 32) + b_ptr = self.make_cute_dsl_global_pointer( + b_tensor, cutlass.Float4E2M1FN, 32) + a_sf_ptr = self.make_cute_dsl_global_pointer( + a_sf_tensor, cutlass.Float8E4M3FN, 16) + b_sf_ptr = self.make_cute_dsl_global_pointer( + b_sf_tensor, cutlass.Float8E4M3FN, 16) + c_ptr = self.make_cute_dsl_global_pointer( + c_tensor, cutlass.BFloat16, 16) + alpha_cute_tensor = cute.runtime.from_dlpack(alpha_tensor) + # make faked stream + stream = cute.runtime.make_fake_stream( + use_tvm_ffi_env_stream=True) + + if swap_ab: + kernel_a_ptr = b_ptr + kernel_a_sf_ptr = b_sf_ptr + kernel_b_ptr = a_ptr + kernel_b_sf_ptr = a_sf_ptr + else: + kernel_a_ptr = a_ptr + kernel_a_sf_ptr = a_sf_ptr + kernel_b_ptr = b_ptr + kernel_b_sf_ptr = b_sf_ptr + gemm = self.__class__.kernel_class( sf_vec_size, mma_tiler_mn, @@ -521,6 +567,8 @@ def forward( max_active_clusters = hardware_info.get_max_active_clusters( cluster_shape_mn[0] * cluster_shape_mn[1]) + # Note: when tvm_ffi fake stream is used, at least one parameter shoube be tensor type, + # so we make alpha as the cute.Tensor type in the jit func. compiled_gemm = cute.compile( gemm.wrapper, kernel_m, @@ -529,17 +577,18 @@ def forward( kernel_sf_m // 128, kernel_sf_n // 128, sf_k // 4, - 1, + 1, # batch kernel_a_ptr, kernel_b_ptr, kernel_a_sf_ptr, kernel_b_sf_ptr, c_ptr, - alpha_ptr, # Pass alpha as device pointer + alpha_cute_tensor, max_active_clusters, stream, swap_ab, - options=f"--opt-level 2", + options=f"--opt-level 2 --enable-tvm-ffi" + if self.use_tvm_ffi else "--opt-level 2", ) self.__class__.kernel_cache[cache_key] = compiled_gemm @@ -547,21 +596,39 @@ def forward( compiled_gemm = self.__class__.kernel_cache[cache_key] # launch gemm kernel - compiled_gemm( - kernel_m, - kernel_n, - real_k, - kernel_sf_m // 128, - kernel_sf_n // 128, - sf_k // 4, - kernel_a_ptr, - kernel_b_ptr, - kernel_a_sf_ptr, - kernel_b_sf_ptr, - c_ptr, - alpha_ptr, # Pass alpha as device pointer - stream, - ) + if self.use_tvm_ffi: + # call with torch pointer types and no need to pass stream. + compiled_gemm( + kernel_m, + kernel_n, + real_k, + kernel_sf_m // 128, + kernel_sf_n // 128, + sf_k // 4, + kernel_a_tensor.data_ptr(), + kernel_b_tensor.data_ptr(), + kernel_a_sf_tensor.data_ptr(), + kernel_b_sf_tensor.data_ptr(), + c_tensor.data_ptr(), + alpha_tensor, + ) + else: + # call with cute types and need to pass torch stream. + compiled_gemm( + kernel_m, + kernel_n, + real_k, + kernel_sf_m // 128, + kernel_sf_n // 128, + sf_k // 4, + kernel_a_ptr, + kernel_b_ptr, + kernel_a_sf_ptr, + kernel_b_sf_ptr, + c_ptr, + alpha_cute_tensor, + stream, + ) if swap_ab: c_tensor = c_tensor.permute(1, 0) @@ -579,6 +646,7 @@ def cute_dsl_nvfp4_gemm_blackwell( alpha: torch.Tensor, output_dtype: torch.dtype, to_userbuffers: bool = False, + use_tvm_ffi: bool = True, ) -> torch.Tensor: """CuteDSL-based NVFP4 GEMM optimized for Blackwell. @@ -590,6 +658,7 @@ def cute_dsl_nvfp4_gemm_blackwell( alpha: Scaling factor output_dtype: Output data type (must be bfloat16) to_userbuffers: Whether to allocate output from UserBuffers pool + use_tvm_ffi: Whether to use TVM-FFI to call the kernel. Enable this option could help reduce the kernel host launch overhead. Note: This function is primarily used internally by nvfp4_gemm. @@ -606,7 +675,8 @@ def cute_dsl_nvfp4_gemm_blackwell( tuner = AutoTuner.get() - runner = CuteDSLNVFP4BlackwellLinear(output_dtype, to_userbuffers) + runner = CuteDSLNVFP4BlackwellLinear(output_dtype, to_userbuffers, + use_tvm_ffi) inputs = [input, weight, input_scale, weight_scale, alpha] _, best_tactic = tuner.choose_one( "trtllm::cute_dsl_nvfp4_gemm_blackwell", @@ -627,6 +697,7 @@ def _( alpha: torch.Tensor, # Match custom op signature output_dtype: torch.dtype, to_userbuffers: bool = False, + use_tvm_ffi: bool = True, ): # [m, k] shape = list(mat_a.shape) diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py index 6b6b427edca..44edab9b3f2 100644 --- a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py +++ b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py @@ -2017,20 +2017,19 @@ def can_implement( @cute.jit def wrapper( self, - m, - n, - k, - sf_m, - sf_n, - sf_k, + m: cutlass.Int32, + n: cutlass.Int32, + k: cutlass.Int32, + sf_m: cutlass.Int32, + sf_n: cutlass.Int32, + sf_k: cutlass.Int32, l: cutlass.Constexpr, a_ptr: cute.Pointer, b_ptr: cute.Pointer, a_sf_ptr: cute.Pointer, b_sf_ptr: cute.Pointer, c_ptr: cute.Pointer, - alpha: cute. - Pointer, # Device pointer to alpha, will be converted to Tensor + alpha_tensor: cute.Tensor, max_active_clusters: cutlass.Constexpr, current_stream: cuda.CUstream, swap_ab: cutlass.Constexpr = False, @@ -2051,7 +2050,7 @@ def wrapper( a_sf_ptr (cute.Pointer): Pointer to the scale factor tensor for A. b_sf_ptr (cute.Pointer): Pointer to the scale factor tensor for B. c_ptr (cute.Pointer): Pointer to the C tensor. - alpha (cute.Pointer): Device pointer to alpha scaling factor (converted to Tensor internally). + alpha_tensor (cute.Tensor): Device tensor to alpha scaling factor. max_active_clusters (cutlass.Constexpr): Maximum number of active clusters. current_stream (cuda.CUstream): CUDA stream for the operation. @@ -2096,9 +2095,6 @@ def wrapper( (32, 4, sf_n, 4, sf_k, l), order=(2, 1, 4, 0, 3, 5), )) - alpha_tensor = cute.make_tensor(alpha, - layout=cute.make_ordered_layout( - (1, ), order=(0, ))) self(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, alpha_tensor, max_active_clusters, current_stream, epilogue_op) diff --git a/tests/unittest/_torch/thop/parallel/test_fp4_linear.py b/tests/unittest/_torch/thop/parallel/test_fp4_linear.py index a549b52fa4e..21963c679a1 100644 --- a/tests/unittest/_torch/thop/parallel/test_fp4_linear.py +++ b/tests/unittest/_torch/thop/parallel/test_fp4_linear.py @@ -311,15 +311,17 @@ def nvfp4_gemm_perf_test( x_sf_block_list = [x_sf_block] w_sf_block_list = [w_sf_block] + alpha_tensor = torch.tensor([1.0]).cuda() with torch.inference_mode(), autotune(): with nvtx.annotate( f"cute_dsl tune, m={SEQ_LEN}, k={HIDDEN_SIZE}, n={OUTPUT_SIZE}", color="orange", ): output = torch.ops.trtllm.cute_dsl_nvfp4_gemm_blackwell( - x_fp4, w_fp4, x_sf_block, w_sf_block, 1.0, dtype) + x_fp4, w_fp4, x_sf_block, w_sf_block, alpha_tensor, dtype) + from tensorrt_llm._torch.autotuner import AutoTuner + AutoTuner.get().print_statistics() - alpha_tensor = torch.tensor(1.0).cuda() if test_ref: with nvtx.annotate( f"ref tune, m={SEQ_LEN}, k={HIDDEN_SIZE}, n={OUTPUT_SIZE}", @@ -340,7 +342,7 @@ def nvfp4_gemm_perf_test( w_fp4_list[buffer_idx % workspace_count], x_sf_block_list[buffer_idx % workspace_count], w_sf_block_list[buffer_idx % workspace_count], - 1.0, + alpha_tensor, dtype, ) buffer_idx = buffer_idx + 1 @@ -354,7 +356,7 @@ def nvfp4_gemm_perf_test( w_fp4_list[buffer_idx % workspace_count], x_sf_block_list[buffer_idx % workspace_count], w_sf_block_list[buffer_idx % workspace_count], - 1.0, + alpha_tensor, dtype, ) buffer_idx = buffer_idx + 1 @@ -455,7 +457,7 @@ def test_nvfp4_gemm_unified_all_tactics(dtype, mnk): x_fp4, x_sf_block = torch.ops.trtllm.fp4_quantize( x, x_sf_global, scaling_vector_size, False) alpha_ref = 1.0 / (w_sf_global * x_sf_global) - alpha_tensor = torch.tensor(alpha_ref, dtype=torch.float32).cuda() + alpha_tensor = torch.tensor([alpha_ref], dtype=torch.float32).cuda() # Reference: Use CUTLASS backend explicitly for reference output with torch.inference_mode(): @@ -744,23 +746,19 @@ def test_fp4_linear_cuda_core(dtype, mnk): if __name__ == "__main__": # m, n, k - fp4_linear_perf_test(torch.bfloat16, 128, 7168, 16384) - fp4_linear_perf_test(torch.bfloat16, 128, 24576, 1536) - fp4_linear_perf_test(torch.bfloat16, 128, 2112, 7168) - fp4_linear_perf_test(torch.bfloat16, 128, 4096, 7168) - fp4_linear_perf_test(torch.bfloat16, 128, 7168, 2048) - - # group-1 test cases - for tokens in [128, 8192]: - nvfp4_gemm_perf_test(torch.bfloat16, tokens, 7168, 16384) - nvfp4_gemm_perf_test(torch.bfloat16, tokens, 24576, 1536) - nvfp4_gemm_perf_test(torch.bfloat16, tokens, 2112, 7168) - nvfp4_gemm_perf_test(torch.bfloat16, tokens, 4096, 7168) - nvfp4_gemm_perf_test(torch.bfloat16, tokens, 7168, 2048) - - # group-2 test cases - for m in [128, 256, 512]: - nvfp4_gemm_perf_test(torch.bfloat16, m, 131584, 7168) - nvfp4_gemm_perf_test(torch.bfloat16, m, 7168, 65792) - nvfp4_gemm_perf_test(torch.bfloat16, m, 227368, 2560, test_ref=False) - nvfp4_gemm_perf_test(torch.bfloat16, m, 2560, 113664) + nvfp4_gemm_perf_test(torch.bfloat16, 128, 7168, 16384) + + # # group-1 test cases + # for tokens in [128, 8192]: + # nvfp4_gemm_perf_test(torch.bfloat16, tokens, 7168, 16384) + # nvfp4_gemm_perf_test(torch.bfloat16, tokens, 24576, 1536) + # nvfp4_gemm_perf_test(torch.bfloat16, tokens, 2112, 7168) + # nvfp4_gemm_perf_test(torch.bfloat16, tokens, 4096, 7168) + # nvfp4_gemm_perf_test(torch.bfloat16, tokens, 7168, 2048) + + # # group-2 test cases + # for m in [128, 256, 512]: + # nvfp4_gemm_perf_test(torch.bfloat16, m, 131584, 7168) + # nvfp4_gemm_perf_test(torch.bfloat16, m, 7168, 65792) + # nvfp4_gemm_perf_test(torch.bfloat16, m, 227368, 2560, test_ref=False) + # nvfp4_gemm_perf_test(torch.bfloat16, m, 2560, 113664)