2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -1439,7 +1439,7 @@ repos:
additional_dependencies:
- tomli
# add ignore words list
args: ["-L", "Mor,ans,thirdparty", "--skip", "ATTRIBUTIONS-*.md,*.svg", "--skip", "security_scanning/*"]
args: ["-L", "Mor,ans,thirdparty,subtiles", "--skip", "ATTRIBUTIONS-*.md,*.svg", "--skip", "security_scanning/*"]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.4
hooks:
516 changes: 514 additions & 2 deletions tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py

Large diffs are not rendered by default.

154 changes: 152 additions & 2 deletions tensorrt_llm/_torch/cute_dsl_kernels/blackwell/custom_pipeline.py
@@ -48,8 +48,8 @@

import cutlass.cute as cute
from cutlass.cutlass_dsl import Boolean, if_generate
from cutlass.pipeline import (CooperativeGroup, PipelineAsync, PipelineOp,
PipelineState)
from cutlass.pipeline import (Agent, CooperativeGroup, PipelineAsync,
PipelineOp, PipelineState, agent_sync)


def pipeline_init_wait(cta_layout_vmnk: Optional[cute.Layout] = None):
Expand Down Expand Up @@ -374,3 +374,153 @@ def then_body():
self.producer_acquire(state)

if_generate(is_leader_cta, then_body)


@dataclass(frozen=True)
class PipelineCpAsyncUmma(PipelineAsync):
"""
PipelineCpAsyncUmma is used for LDGSTS (CpAsync) producers and UMMA consumers.

This pipeline is specifically designed for scenarios where:
- Producers use LDGSTS instructions (cp.async) to load data from global to shared memory
- Consumers are UMMA warps that perform MMA operations using the loaded data

Key differences from PipelineAsyncUmma:
- Suitable for gather/permutation operations during load
- Used in this kernel for A and SFA matrices with token-based gather addressing
"""

cta_group: cute.nvgpu.tcgen05.CtaGroup

@staticmethod
def _compute_leading_cta_rank(cta_v_size):
"""
Computes the leading CTA rank.
"""
cta_rank_in_cluster = cute.arch.make_warp_uniform(
cute.arch.block_idx_in_cluster())
return cta_rank_in_cluster // cta_v_size * cta_v_size

@staticmethod
def _compute_is_leader_cta(cta_layout_vmnk: cute.Layout):
"""
Computes leader threadblocks for 2CTA kernels. For 1CTA, all threadblocks are leaders.
"""
bidx, bidy, _ = cute.arch.block_idx()
mma_coord_vmnk = (
bidx % cute.size(cta_layout_vmnk, mode=[0]),
bidx // cute.size(cta_layout_vmnk, mode=[0]),
bidy,
None,
)
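# Only the V coordinate (position within the CTA pair) is needed: V == 0 marks the leader.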
return mma_coord_vmnk[0] == 0

@staticmethod
def _compute_peer_cta_mask(cta_layout_vmnk: cute.Layout):
"""
Computes a mask for signaling arrivals to multicasting threadblocks.
"""
cta_rank_in_cluster = cute.arch.make_warp_uniform(
cute.arch.block_idx_in_cluster())
cta_in_cluster_coord_vmnk = cta_layout_vmnk.get_flat_coord(
cta_rank_in_cluster)
mask_self = cute.nvgpu.cpasync.create_tma_multicast_mask(
cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=0)
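# The peer CTA differs only in the V coordinate (XOR with 1), i.e. the other
# CTA of the 2-CTA pair; the returned mask covers both self and peer.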
block_in_cluster_coord_vmnk_peer = (
cta_in_cluster_coord_vmnk[0] ^ 1,
*cta_in_cluster_coord_vmnk[1:],
)
mask_peer = cute.nvgpu.cpasync.create_tma_multicast_mask(
cta_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=0)
return mask_self | mask_peer

@staticmethod
def create(
*,
num_stages: int,
producer_group: CooperativeGroup,
consumer_group: CooperativeGroup,
barrier_storage: cute.Pointer = None,
cta_layout_vmnk: Optional[cute.Layout] = None,
defer_sync: bool = False,
enable_cp_async: bool = False,
):
"""Creates and initializes a new PipelineCpAsyncUmma instance.

:param num_stages: Number of buffer stages for this pipeline
:type num_stages: int
:param producer_group: CooperativeGroup for the producer agent
:type producer_group: CooperativeGroup
:param consumer_group: CooperativeGroup for the consumer agent
:type consumer_group: CooperativeGroup
:param barrier_storage: Pointer to the shared memory address for this pipeline's mbarriers
:type barrier_storage: cute.Pointer, optional
:param cta_layout_vmnk: Layout of the cluster shape
:type cta_layout_vmnk: cute.Layout, optional
:param defer_sync: Whether to defer the sync
:type defer_sync: bool, optional
:param enable_cp_async: Whether to enable cp.async instructions
:type enable_cp_async: bool, optional
:raises ValueError: If barrier_storage is not a cute.Pointer instance
:return: A new PipelineCpAsyncUmma instance configured with the provided parameters
:rtype: PipelineCpAsyncUmma
"""
if not isinstance(barrier_storage, cute.Pointer):
raise ValueError(
f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
)

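# The producer's arrive type depends on whether the loads are issued with cp.async (LDGSTS).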
producer_type = PipelineOp.AsyncLoad if enable_cp_async else PipelineOp.AsyncThread
consumer_type = PipelineOp.TCGen05Mma

producer = (producer_type, producer_group)
consumer = (consumer_type, consumer_group)

sync_object_full = PipelineAsync._make_sync_object(
barrier_storage.align(min_align=8),
num_stages,
producer,
)
sync_object_empty = PipelineAsync._make_sync_object(
barrier_storage.align(min_align=8) + num_stages, num_stages,
consumer)

cta_v_size = cute.size(cta_layout_vmnk,
mode=[0]) if cta_layout_vmnk is not None else 1
cta_group = (cute.nvgpu.tcgen05.CtaGroup.ONE if cta_layout_vmnk is None
or cute.size(cta_layout_vmnk, mode=[0]) == 1 else
cute.nvgpu.tcgen05.CtaGroup.TWO)
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1:
# No mcast mask if we're not using 2CTA tcgen05 MMA
producer_mask = None
consumer_mask = None
else:
# If we're using 2CTA UMMAs, the producer arrives on the leading CTA's mbarrier,
# so we need the rank of that target CTA
producer_mask = PipelineCpAsyncUmma._compute_leading_cta_rank(
cta_v_size)
# The consumer needs a multicast mask to signal both CTAs of the pair
consumer_mask = PipelineCpAsyncUmma._compute_peer_cta_mask(
cta_layout_vmnk)

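# Unless the caller defers it, synchronize the block (or the cluster for
# multi-CTA layouts) so the initialized mbarriers are visible before first use.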
if not defer_sync:
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
agent_sync(Agent.ThreadBlock)
else:
agent_sync(Agent.ThreadBlockCluster, is_relaxed=True)

return PipelineCpAsyncUmma(
sync_object_full,
sync_object_empty,
num_stages,
producer_mask,
consumer_mask,
cta_group,
)

def consumer_release(self, state: PipelineState):
"""
UMMA consumer releases the buffer (marks it empty); the cta_group must be provided to the arrive.
"""
self.sync_object_empty.arrive(state.index, self.consumer_mask,
self.cta_group)
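
Usage sketch for context, assuming the producer_acquire / producer_commit / consumer_wait methods inherited from PipelineAsync, plus make_pipeline_state, PipelineUserType, Agent.Thread, the CooperativeGroup constructor arguments and PipelineState.advance() from cutlass.pipeline; smem_barrier_ptr, num_k_tiles, cta_layout_vmnk and the stage count are placeholders, so treat this as a schematic kernel-body fragment rather than the kernel added in this PR:

num_stages = 4  # placeholder
load_pipeline = PipelineCpAsyncUmma.create(
    num_stages=num_stages,
    producer_group=CooperativeGroup(Agent.Thread, 128),  # e.g. a 128-thread load group (assumed ctor args)
    consumer_group=CooperativeGroup(Agent.Thread, 1),    # e.g. a single UMMA leader thread
    barrier_storage=smem_barrier_ptr,                    # shared-memory storage for the mbarriers
    cta_layout_vmnk=cta_layout_vmnk,                     # None for 1-CTA kernels
    enable_cp_async=True,
)

# Producer (LDGSTS) side of the handshake
producer_state = make_pipeline_state(PipelineUserType.Producer, num_stages)
for _ in range(num_k_tiles):
    load_pipeline.producer_acquire(producer_state)  # wait until the stage is empty
    # ... issue cp.async gather loads into stage producer_state.index ...
    load_pipeline.producer_commit(producer_state)   # mark the stage full
    producer_state.advance()

# Consumer (UMMA) side of the handshake
consumer_state = make_pipeline_state(PipelineUserType.Consumer, num_stages)
for _ in range(num_k_tiles):
    load_pipeline.consumer_wait(consumer_state)     # wait until the stage is full
    # ... issue tcgen05 MMAs reading from stage consumer_state.index ...
    load_pipeline.consumer_release(consumer_state)  # mark the stage empty again (multicasts for 2-CTA)
    consumer_state.advance()
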
14 changes: 4 additions & 10 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py
@@ -273,22 +273,16 @@ def run_moe_nvfp4(
local_num_experts=self.expert_size_per_partition,
tile_tokens_dim=tile_size,
)
x, x_sf = torch.ops.trtllm.moe_permute(
input=x.view(torch.float4_e2m1fn_x2),
input_sf=x_sf,
tile_idx_to_mn_limit=tile_idx_to_mn_limit,
permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
num_non_exiting_tiles=num_non_exiting_tiles,
tile_tokens_dim=tile_size,
top_k=self.routing_method.experts_per_token,
)
x, x_sf = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_swiglu_blackwell(

x, x_sf = torch.ops.trtllm.cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell(
input=x.view(torch.float4_e2m1fn_x2),
weight=self.w3_w1_weight.view(torch.float4_e2m1fn_x2),
input_scale=x_sf.view(torch.uint8),
weight_scale=self.quant_scales.fc1_weight_block.view(torch.uint8),
alpha=self.quant_scales.fc1_global,
tile_idx_to_group_idx=tile_idx_to_expert_idx,
tile_idx_to_mn_limit=tile_idx_to_mn_limit,
permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
num_non_exiting_tiles=num_non_exiting_tiles,
global_sf=self.fc2_input_scale,
num_experts=self.num_slots,
9 changes: 9 additions & 0 deletions tensorrt_llm/_torch/utils.py
@@ -291,6 +291,15 @@ def fp4_scale_infer_shape(input_shapes: List[List[int]]):
return scale_shape * 2


def fp4_unswizzled_scale_infer_shape(input_shapes: List[List[int]]):
"""Calculate the dimensions of the fp4 scale tensor.
"""
out_shape, scale_shape = fp4_utils.get_fp4_shape(input_shapes[0],
sf_vec_size=16,
is_swizzled_layout=False)
return scale_shape * 2


_enable_piecewise_cuda_graph = True

