2 changes: 2 additions & 0 deletions requirements.txt
@@ -73,3 +73,5 @@ nvidia-cutlass-dsl==4.3.1; python_version >= "3.10"
plotly
numexpr<2.14.0 # WAR for attempted use of nonexistent numpy.typing
partial_json_parser
apache-tvm-ffi==0.1.4 # used to reduce nvidia-cutlass-dsl host overhead
torch-c-dlpack-ext==0.1.3 # used to reduce nvidia-cutlass-dsl host overhead; optional package that improves torch tensor calling performance
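
Both new packages are optional accelerators for the CuteDSL path. Below is a minimal sketch of how a caller might probe for them before enabling the TVM-FFI path; the module names are assumptions inferred from the pip package names, not taken from this PR:

import importlib.util

def _has_module(name: str) -> bool:
    # find_spec returns None when the module cannot be located.
    return importlib.util.find_spec(name) is not None

# Assumed import names for apache-tvm-ffi and torch-c-dlpack-ext.
tvm_ffi_available = _has_module("tvm_ffi")
dlpack_ext_available = _has_module("torch_c_dlpack_ext")
use_tvm_ffi = tvm_ffi_available  # fall back to the non-TVM-FFI path otherwise
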
167 changes: 119 additions & 48 deletions tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
@@ -240,7 +240,8 @@ class CuteDSLNVFP4BlackwellLinear(TunableRunner):

def __init__(self,
output_dtype: torch.dtype,
to_userbuffers: bool = False):
to_userbuffers: bool = False,
use_tvm_ffi: bool = True):
super().__init__()

if output_dtype != torch.bfloat16:
@@ -249,17 +250,19 @@ def __init__(self,
)
self.output_dtype = output_dtype
self.to_userbuffers = to_userbuffers
self.use_tvm_ffi = use_tvm_ffi

def unique_id(self):
return (self.output_dtype, self.to_userbuffers)
return (self.output_dtype, self.to_userbuffers, self.use_tvm_ffi)

def __hash__(self):
return hash((self.output_dtype, self.to_userbuffers))
return hash(
(self.output_dtype, self.to_userbuffers, self.use_tvm_ffi))

def __eq__(self, other):
if not isinstance(other, self.__class__):
return False
return self.output_dtype == other.output_dtype and self.to_userbuffers == other.to_userbuffers
return self.output_dtype == other.output_dtype and self.to_userbuffers == other.to_userbuffers and self.use_tvm_ffi == other.use_tvm_ffi

def get_valid_tactics(
self,
@@ -465,51 +468,94 @@ def forward(
f"CuteDSL: weight scale factor size mismatch. "
f"Expected {expected_b_sf_size} (sf_n={sf_n} * sf_k={sf_k}), "
f"got {b_sf_tensor.numel()} for shape N={n}, K={real_k}")
if alpha_tensor.numel() != 1:
raise ValueError(f"CuteDSL: alpha size mismatch. "
f"Expected 1, got {alpha_tensor.numel()}")

# Reshape to CuteDSL's expected format (just a view, no copy)
a_sf_tensor = a_sf_tensor.reshape(sf_m * sf_k)
b_sf_tensor = b_sf_tensor.reshape(sf_n * sf_k)

a_ptr = self.make_cute_dsl_global_pointer(a_tensor,
cutlass.Float4E2M1FN, 32)
b_ptr = self.make_cute_dsl_global_pointer(b_tensor,
cutlass.Float4E2M1FN, 32)
a_sf_ptr = self.make_cute_dsl_global_pointer(
a_sf_tensor, cutlass.Float8E4M3FN, 16)
b_sf_ptr = self.make_cute_dsl_global_pointer(
b_sf_tensor, cutlass.Float8E4M3FN, 16)
c_ptr = self.make_cute_dsl_global_pointer(c_tensor,
cutlass.BFloat16, 16)
# Create pointer to alpha on device
alpha_ptr = self.make_cute_dsl_global_pointer(
alpha_tensor, cutlass.Float32, 4)

# get stream
torch_stream = torch.cuda.current_stream()
stream = cuda.CUstream(torch_stream.cuda_stream)
if not self.use_tvm_ffi:
a_ptr = self.make_cute_dsl_global_pointer(
a_tensor, cutlass.Float4E2M1FN, 32)
b_ptr = self.make_cute_dsl_global_pointer(
b_tensor, cutlass.Float4E2M1FN, 32)
a_sf_ptr = self.make_cute_dsl_global_pointer(
a_sf_tensor, cutlass.Float8E4M3FN, 16)
b_sf_ptr = self.make_cute_dsl_global_pointer(
b_sf_tensor, cutlass.Float8E4M3FN, 16)
c_ptr = self.make_cute_dsl_global_pointer(
c_tensor, cutlass.BFloat16, 16)
alpha_cute_tensor = cute.runtime.from_dlpack(alpha_tensor)

# get stream
torch_stream = torch.cuda.current_stream()
stream = cuda.CUstream(torch_stream.cuda_stream)

cache_key = (sf_vec_size, mma_tiler_mn, cluster_shape_mn, swap_ab,
use_prefetch)
if swap_ab:
kernel_a_ptr = b_ptr
kernel_a_sf_ptr = b_sf_ptr
kernel_b_ptr = a_ptr
kernel_b_sf_ptr = a_sf_ptr
kernel_m = n
kernel_n = m
kernel_sf_m = sf_n
kernel_sf_n = sf_m

kernel_a_tensor = b_tensor
kernel_a_sf_tensor = b_sf_tensor
kernel_b_tensor = a_tensor
kernel_b_sf_tensor = a_sf_tensor

if not self.use_tvm_ffi:
kernel_a_ptr = b_ptr
kernel_a_sf_ptr = b_sf_ptr
kernel_b_ptr = a_ptr
kernel_b_sf_ptr = a_sf_ptr
else:
kernel_a_ptr = a_ptr
kernel_a_sf_ptr = a_sf_ptr
kernel_b_ptr = b_ptr
kernel_b_sf_ptr = b_sf_ptr
kernel_m = m
kernel_n = n
kernel_sf_m = sf_m
kernel_sf_n = sf_n

kernel_a_tensor = a_tensor
kernel_a_sf_tensor = a_sf_tensor
kernel_b_tensor = b_tensor
kernel_b_sf_tensor = b_sf_tensor

if not self.use_tvm_ffi:
kernel_a_ptr = a_ptr
kernel_a_sf_ptr = a_sf_ptr
kernel_b_ptr = b_ptr
kernel_b_sf_ptr = b_sf_ptr

if cache_key not in self.__class__.kernel_cache:
if self.use_tvm_ffi:
a_ptr = self.make_cute_dsl_global_pointer(
a_tensor, cutlass.Float4E2M1FN, 32)
b_ptr = self.make_cute_dsl_global_pointer(
b_tensor, cutlass.Float4E2M1FN, 32)
a_sf_ptr = self.make_cute_dsl_global_pointer(
a_sf_tensor, cutlass.Float8E4M3FN, 16)
b_sf_ptr = self.make_cute_dsl_global_pointer(
b_sf_tensor, cutlass.Float8E4M3FN, 16)
c_ptr = self.make_cute_dsl_global_pointer(
c_tensor, cutlass.BFloat16, 16)
alpha_cute_tensor = cute.runtime.from_dlpack(alpha_tensor)
# make a fake stream
stream = cute.runtime.make_fake_stream(
use_tvm_ffi_env_stream=True)

if swap_ab:
kernel_a_ptr = b_ptr
kernel_a_sf_ptr = b_sf_ptr
kernel_b_ptr = a_ptr
kernel_b_sf_ptr = a_sf_ptr
else:
kernel_a_ptr = a_ptr
kernel_a_sf_ptr = a_sf_ptr
kernel_b_ptr = b_ptr
kernel_b_sf_ptr = b_sf_ptr

gemm = self.__class__.kernel_class(
sf_vec_size,
mma_tiler_mn,
@@ -521,6 +567,8 @@ def forward(
max_active_clusters = hardware_info.get_max_active_clusters(
cluster_shape_mn[0] * cluster_shape_mn[1])

# Note: when the tvm_ffi fake stream is used, at least one parameter must be a tensor type,
# so alpha is passed as a cute.Tensor in the jit function.
compiled_gemm = cute.compile(
gemm.wrapper,
kernel_m,
@@ -529,39 +577,58 @@ def forward(
kernel_sf_m // 128,
kernel_sf_n // 128,
sf_k // 4,
1,
1, # batch
kernel_a_ptr,
kernel_b_ptr,
kernel_a_sf_ptr,
kernel_b_sf_ptr,
c_ptr,
alpha_ptr, # Pass alpha as device pointer
alpha_cute_tensor,
max_active_clusters,
stream,
swap_ab,
options=f"--opt-level 2",
options=f"--opt-level 2 --enable-tvm-ffi"
if self.use_tvm_ffi else "--opt-level 2",
)

self.__class__.kernel_cache[cache_key] = compiled_gemm
else:
compiled_gemm = self.__class__.kernel_cache[cache_key]

# launch gemm kernel
compiled_gemm(
kernel_m,
kernel_n,
real_k,
kernel_sf_m // 128,
kernel_sf_n // 128,
sf_k // 4,
kernel_a_ptr,
kernel_b_ptr,
kernel_a_sf_ptr,
kernel_b_sf_ptr,
c_ptr,
alpha_ptr, # Pass alpha as device pointer
stream,
)
if self.use_tvm_ffi:
# Call with torch data pointers; the stream does not need to be passed.
compiled_gemm(
kernel_m,
kernel_n,
real_k,
kernel_sf_m // 128,
kernel_sf_n // 128,
sf_k // 4,
kernel_a_tensor.data_ptr(),
kernel_b_tensor.data_ptr(),
kernel_a_sf_tensor.data_ptr(),
kernel_b_sf_tensor.data_ptr(),
c_tensor.data_ptr(),
alpha_tensor,
)
else:
# Call with cute pointer types; the torch stream must be passed explicitly.
compiled_gemm(
kernel_m,
kernel_n,
real_k,
kernel_sf_m // 128,
kernel_sf_n // 128,
sf_k // 4,
kernel_a_ptr,
kernel_b_ptr,
kernel_a_sf_ptr,
kernel_b_sf_ptr,
c_ptr,
alpha_cute_tensor,
stream,
)

if swap_ab:
c_tensor = c_tensor.permute(1, 0)
@@ -579,6 +646,7 @@ def cute_dsl_nvfp4_gemm_blackwell(
alpha: torch.Tensor,
output_dtype: torch.dtype,
to_userbuffers: bool = False,
use_tvm_ffi: bool = True,
) -> torch.Tensor:
"""CuteDSL-based NVFP4 GEMM optimized for Blackwell.

@@ -590,6 +658,7 @@ def cute_dsl_nvfp4_gemm_blackwell(
alpha: Scaling factor
output_dtype: Output data type (must be bfloat16)
to_userbuffers: Whether to allocate output from UserBuffers pool
use_tvm_ffi: Whether to use TVM-FFI to call the kernel. Enabling this option can help reduce the kernel host launch overhead.

Note:
This function is primarily used internally by nvfp4_gemm.
@@ -606,7 +675,8 @@

tuner = AutoTuner.get()

runner = CuteDSLNVFP4BlackwellLinear(output_dtype, to_userbuffers)
runner = CuteDSLNVFP4BlackwellLinear(output_dtype, to_userbuffers,
use_tvm_ffi)
inputs = [input, weight, input_scale, weight_scale, alpha]
_, best_tactic = tuner.choose_one(
"trtllm::cute_dsl_nvfp4_gemm_blackwell",
@@ -627,6 +697,7 @@ def _(
alpha: torch.Tensor, # Match custom op signature
output_dtype: torch.dtype,
to_userbuffers: bool = False,
use_tvm_ffi: bool = True,
):
# [m, k]
shape = list(mat_a.shape)
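
The runner identity changes above (unique_id, __hash__, __eq__) matter because tuning results and compiled kernels are cached per runner configuration. The following standalone sketch uses a hypothetical Runner class, not the real TunableRunner, to show why use_tvm_ffi must be part of the identity:

import torch

class Runner:
    """Hypothetical stand-in for CuteDSLNVFP4BlackwellLinear's identity handling."""

    def __init__(self, output_dtype: torch.dtype, to_userbuffers: bool,
                 use_tvm_ffi: bool):
        self.output_dtype = output_dtype
        self.to_userbuffers = to_userbuffers
        self.use_tvm_ffi = use_tvm_ffi

    def unique_id(self):
        return (self.output_dtype, self.to_userbuffers, self.use_tvm_ffi)

    def __hash__(self):
        return hash(self.unique_id())

    def __eq__(self, other):
        return isinstance(other, Runner) and self.unique_id() == other.unique_id()

# Tactics tuned for the TVM-FFI calling convention are not reused for the
# non-TVM-FFI convention, because the two runners compare unequal.
cache = {Runner(torch.bfloat16, False, True): "best_tactic_tvm_ffi"}
assert Runner(torch.bfloat16, False, False) not in cache
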
@@ -2017,20 +2017,19 @@ def can_implement(
@cute.jit
def wrapper(
self,
m,
n,
k,
sf_m,
sf_n,
sf_k,
m: cutlass.Int32,
n: cutlass.Int32,
k: cutlass.Int32,
sf_m: cutlass.Int32,
sf_n: cutlass.Int32,
sf_k: cutlass.Int32,
l: cutlass.Constexpr,
a_ptr: cute.Pointer,
b_ptr: cute.Pointer,
a_sf_ptr: cute.Pointer,
b_sf_ptr: cute.Pointer,
c_ptr: cute.Pointer,
alpha: cute.Pointer,  # Device pointer to alpha, will be converted to Tensor
alpha_tensor: cute.Tensor,
max_active_clusters: cutlass.Constexpr,
current_stream: cuda.CUstream,
swap_ab: cutlass.Constexpr = False,
@@ -2051,7 +2050,7 @@ def wrapper(
a_sf_ptr (cute.Pointer): Pointer to the scale factor tensor for A.
b_sf_ptr (cute.Pointer): Pointer to the scale factor tensor for B.
c_ptr (cute.Pointer): Pointer to the C tensor.
alpha (cute.Pointer): Device pointer to alpha scaling factor (converted to Tensor internally).
alpha_tensor (cute.Tensor): Device tensor holding the alpha scaling factor.
max_active_clusters (cutlass.Constexpr): Maximum number of active
clusters.
current_stream (cuda.CUstream): CUDA stream for the operation.
@@ -2096,9 +2095,6 @@ def wrapper(
(32, 4, sf_n, 4, sf_k, l),
order=(2, 1, 4, 0, 3, 5),
))
alpha_tensor = cute.make_tensor(alpha,
layout=cute.make_ordered_layout(
(1, ), order=(0, )))

self(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, alpha_tensor,
max_active_clusters, current_stream, epilogue_op)
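
For context on this signature change: the note in forward() above explains that the TVM-FFI fake stream requires at least one tensor-typed argument, which is why alpha now reaches wrapper() as a cute.Tensor instead of a cute.Pointer. The sketch below mirrors the host-side preparation used in this diff; it assumes import cutlass.cute as cute and a CUDA-capable torch build, and is an illustration rather than a standalone test:

import torch
import cutlass.cute as cute

# 1-element device tensor holding the GEMM scaling factor.
alpha = torch.tensor([1.0], dtype=torch.float32, device="cuda")

# Wrapped as a cute.Tensor, as done in forward(); this is the tensor-typed
# argument the TVM-FFI fake stream relies on.
alpha_cute_tensor = cute.runtime.from_dlpack(alpha)

# Fake stream bound to the TVM-FFI environment stream, as used at compile time
# in this PR when use_tvm_ffi is enabled.
stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True)
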