79 changes: 43 additions & 36 deletions tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
@@ -81,12 +81,14 @@ def get_valid_tactics(

         # full shamoo
         mma_tiler_mn_candidates = [
-            (256, 128),
+            (128, 64),
+            (256, 64),
             (128, 128),
+            (256, 128),
+            (128, 192),
+            (256, 192),
             (128, 256),
             (256, 256),
-            (256, 64),
-            (128, 64),
         ]
         cluster_shape_mn_candidates = [
             (1, 1),
@@ -100,37 +102,38 @@ def get_valid_tactics(
             (4, 4),
         ]
         swap_ab_candidates = [True, False]
+        use_prefetch_candidates = [True, False]

         valid_tactics = []
-        for swap_ab in swap_ab_candidates:
-            for mma_tiler_mn in mma_tiler_mn_candidates:
-                for cluster_shape_mn in cluster_shape_mn_candidates:
-                    if swap_ab:
-                        c_major = "m"
-                        kernel_m = n
-                        kernel_n = m
-                    else:
-                        c_major = "n"
-                        kernel_m = m
-                        kernel_n = n
-
-                    if self.__class__.kernel_class.can_implement(
-                            cutlass.Float4E2M1FN,  # ab_dtype,
-                            cutlass.Float8E4M3FN,  # sf_dtype
-                            sf_vec_size,  # sf_vec_size,
-                            cutlass.BFloat16,  # c_dtype,
-                            mma_tiler_mn,
-                            cluster_shape_mn,
-                            kernel_m,
-                            kernel_n,
-                            real_k,
-                            batch_size,
-                            a_major,
-                            b_major,
-                            c_major,
-                    ):
-                        valid_tactics.append(
-                            (mma_tiler_mn, cluster_shape_mn, swap_ab))
+        for mma_tiler_mn, cluster_shape_mn, swap_ab, use_prefetch in itertools.product(
+                mma_tiler_mn_candidates, cluster_shape_mn_candidates,
+                swap_ab_candidates, use_prefetch_candidates):
+            if swap_ab:
+                c_major = "m"
+                kernel_m = n
+                kernel_n = m
+            else:
+                c_major = "n"
+                kernel_m = m
+                kernel_n = n
+
+            if self.__class__.kernel_class.can_implement(
+                    cutlass.Float4E2M1FN,  # ab_dtype,
+                    cutlass.Float8E4M3FN,  # sf_dtype
+                    sf_vec_size,  # sf_vec_size,
+                    cutlass.BFloat16,  # c_dtype,
+                    mma_tiler_mn,
+                    cluster_shape_mn,
+                    kernel_m,
+                    kernel_n,
+                    real_k,
+                    batch_size,
+                    a_major,
+                    b_major,
+                    c_major,
+            ):
+                valid_tactics.append(
+                    (mma_tiler_mn, cluster_shape_mn, swap_ab, use_prefetch))

         return valid_tactics

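Reviewer note: for reference, here is a minimal standalone sketch of the enumeration pattern the rewritten loop uses. itertools.product flattens the four nested candidate sweeps into a single loop, and a feasibility check filters the combinations. The candidate values and can_implement_stub below are illustrative stand-ins, not the real kernel API.

import itertools

mma_tiler_mn_candidates = [(128, 64), (256, 64), (128, 128), (256, 128)]
cluster_shape_mn_candidates = [(1, 1), (2, 1), (4, 4)]
swap_ab_candidates = [True, False]
use_prefetch_candidates = [True, False]

def can_implement_stub(mma_tiler_mn, cluster_shape_mn, swap_ab, use_prefetch):
    # Stand-in for kernel_class.can_implement(...); accepts everything.
    return True

valid_tactics = [
    (mma_tiler_mn, cluster_shape_mn, swap_ab, use_prefetch)
    for mma_tiler_mn, cluster_shape_mn, swap_ab, use_prefetch in
    itertools.product(mma_tiler_mn_candidates, cluster_shape_mn_candidates,
                      swap_ab_candidates, use_prefetch_candidates)
    if can_implement_stub(mma_tiler_mn, cluster_shape_mn, swap_ab, use_prefetch)
]
print(len(valid_tactics))  # 4 * 3 * 2 * 2 = 48 when the stub accepts all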
@@ -159,21 +162,22 @@ def forward(
             inputs[3]: Weight scale tensor of shape (n, k//16), dtype: fp8.
             inputs[4]: Alpha scaling factor. dtype: float32.
             inputs[5]: Output dtype, expected to be torch.bfloat16.
-            tactic: Tiling and cluster strategy, typically a tuple (mma_tiler_mn, cluster_shape_mn).
+            tactic: Tiling and cluster strategy, typically a tuple (mma_tiler_mn, cluster_shape_mn, swap_ab, use_prefetch).

         Returns:
             torch.Tensor: Output tensor of shape (m, n), dtype: bf16.
         """
         sf_vec_size = 16

         if isinstance(tactic, tuple):
-            mma_tiler_mn, cluster_shape_mn, swap_ab = tactic
+            mma_tiler_mn, cluster_shape_mn, swap_ab, use_prefetch = tactic
         else:
             # fallback to default tactic
-            mma_tiler_mn, cluster_shape_mn, swap_ab = [
+            mma_tiler_mn, cluster_shape_mn, swap_ab, use_prefetch = [
                 (128, 128),
                 (1, 1),
                 False,
+                False,
             ]

         a_tensor, b_tensor, a_sf_tensor, b_sf_tensor = inputs
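Reviewer note: the tactic contract after this change is that the autotuner passes a 4-tuple, and anything that is not a tuple falls back to the defaults above. A small sketch under that assumption (the -1 "no tactic chosen" sentinel is hypothetical, for illustration only):

def unpack_tactic(tactic):
    if isinstance(tactic, tuple):
        mma_tiler_mn, cluster_shape_mn, swap_ab, use_prefetch = tactic
    else:
        # fallback to default tactic, matching the diff above
        mma_tiler_mn, cluster_shape_mn, swap_ab, use_prefetch = (
            (128, 128), (1, 1), False, False)
    return mma_tiler_mn, cluster_shape_mn, swap_ab, use_prefetch

print(unpack_tactic(((256, 128), (2, 1), True, True)))
print(unpack_tactic(-1))  # hypothetical sentinel: falls back to defaults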
@@ -209,7 +213,8 @@ def forward(
         torch_stream = torch.cuda.current_stream()
         stream = cuda.CUstream(torch_stream.cuda_stream)

-        cache_key = (sf_vec_size, mma_tiler_mn, cluster_shape_mn, swap_ab)
+        cache_key = (sf_vec_size, mma_tiler_mn, cluster_shape_mn, swap_ab,
+                     use_prefetch)
         if swap_ab:
             kernel_a_ptr = b_ptr
             kernel_a_sf_ptr = b_sf_ptr
@@ -234,6 +239,7 @@ def forward(
                 sf_vec_size,
                 mma_tiler_mn,
                 cluster_shape_mn,
+                use_prefetch,
             )
             # Compute max active clusters on current device
             hardware_info = cutlass.utils.HardwareInfo()
@@ -258,6 +264,7 @@ def forward(
                 max_active_clusters,
                 stream,
                 swap_ab,
+                options=f"--opt-level 2",
             )

             self.__class__.kernel_cache[cache_key] = compiled_gemm
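Reviewer note: taken together, the forward path compiles a kernel once per configuration and memoizes it, and extending cache_key with use_prefetch keeps prefetch and non-prefetch builds from colliding in the cache. A minimal sketch of that memoization pattern, with compile_kernel as a stand-in for the real compile call:

kernel_cache = {}

def compile_kernel(cache_key):
    # Placeholder for the expensive cute-DSL compilation step.
    return f"compiled<{cache_key}>"

def get_compiled(sf_vec_size, mma_tiler_mn, cluster_shape_mn, swap_ab,
                 use_prefetch):
    cache_key = (sf_vec_size, mma_tiler_mn, cluster_shape_mn, swap_ab,
                 use_prefetch)
    if cache_key not in kernel_cache:
        kernel_cache[cache_key] = compile_kernel(cache_key)
    return kernel_cache[cache_key]

print(get_compiled(16, (128, 128), (1, 1), False, False))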