From af779ccfa870019c4f1527e89d07b9d6372a22e9 Mon Sep 17 00:00:00 2001 From: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> Date: Tue, 20 Jan 2026 14:42:32 +0800 Subject: [PATCH 01/20] draft: use all reduce for rms norm in attention module Signed-off-by: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> --- .../_torch/models/modeling_minimaxm2.py | 71 ++++++++++++++++--- 1 file changed, 60 insertions(+), 11 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_minimaxm2.py b/tensorrt_llm/_torch/models/modeling_minimaxm2.py index 8e6b498e9d4f..db0233251795 100644 --- a/tensorrt_llm/_torch/models/modeling_minimaxm2.py +++ b/tensorrt_llm/_torch/models/modeling_minimaxm2.py @@ -20,9 +20,11 @@ from transformers import PretrainedConfig from tensorrt_llm.functional import PositionEmbeddingType +from tensorrt_llm.mapping import Mapping from ..attention_backend import AttentionMetadata from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams +from ..distributed import AllReduce from ..models.modeling_utils import ModelConfig from ..modules.attention import Attention from ..modules.decoder_layer import DecoderLayer @@ -114,6 +116,40 @@ def forward( return final_hidden_states +# We use all_reduce across all tp gpus to get the rms norm variance sum +class MiniMaxRMSNorm(nn.Module): + def __init__( + self, *, hidden_size: int, eps: float, mapping: Mapping, dtype: torch.dtype = torch.bfloat16 + ): + super().__init__() + self.mapping = mapping + # for attention input, tp_size * hidden_size = head_num * head_size + self.weight = nn.Parameter(torch.ones(hidden_size, dtype=dtype, requires_grad=False)) + self.hidden_size = hidden_size + self.eps = eps + self.dtype = dtype + self.all_reduce = AllReduce(mapping=self.mapping) + + # TODO: add load weights method + def load_weights(self, weights: Dict): + assert len(weights) == 1 + slice_width = self.hidden_size + slice_start = self.mapping.tp_rank * slice_width + slice_end = slice_start + slice_width + self.weight.copy_(weights["weight"][slice_start:slice_end].to(self.weight.dtype)) + + def forward(self, hidden_states: torch.Tensor): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + + variance = hidden_states.pow(2).mean(-1, keepdim=True) + variance = self.all_reduce(variance) / self.mapping.tp_size + + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + hidden_states = self.weight * hidden_states.to(input_dtype) + return hidden_states + + # It's a little bit tricky to implement special qk norm # because rms dim is hidden_size * num_heads, not hidden_size, after qkv linear, # the result size is hidden_size * num_heads / tp_size. @@ -148,17 +184,30 @@ def __init__( dtype=config.torch_dtype, config=model_config, ) - - self.q_norm = RMSNorm( - hidden_size=self.q_size * self.tp_size, - eps=config.rms_norm_eps, - dtype=config.torch_dtype, - ) - self.k_norm = RMSNorm( - hidden_size=self.kv_size * self.tp_size, - eps=config.rms_norm_eps, - dtype=config.torch_dtype, - ) + if self.qkv_proj.mapping.tp_size > 1: + self.q_norm = MiniMaxRMSNorm( + hidden_size=self.q_size, + eps=config.rms_norm_eps, + mapping=self.qkv_proj.mapping, + dtype=config.torch_dtype, + ) + self.k_norm = MiniMaxRMSNorm( + hidden_size=self.kv_size, + eps=config.rms_norm_eps, + mapping=self.qkv_proj.mapping, + dtype=config.torch_dtype, + ) + else: + self.q_norm = RMSNorm( + hidden_size=self.q_size * self.tp_size, + eps=config.rms_norm_eps, + dtype=config.torch_dtype, + ) + self.k_norm = RMSNorm( + hidden_size=self.kv_size * self.tp_size, + eps=config.rms_norm_eps, + dtype=config.torch_dtype, + ) def apply_qk_norm(self, q, k): if self.qkv_proj.mapping.tp_size > 1: From 24f495af3a23c03fed93d5b5b6a793b203200ef8 Mon Sep 17 00:00:00 2001 From: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> Date: Fri, 30 Jan 2026 13:44:10 +0800 Subject: [PATCH 02/20] chore: use nccl as all reduce backend Signed-off-by: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> --- .../_torch/models/modeling_minimaxm2.py | 48 ++++++++++--------- 1 file changed, 25 insertions(+), 23 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_minimaxm2.py b/tensorrt_llm/_torch/models/modeling_minimaxm2.py index db0233251795..de0dc21c02e6 100644 --- a/tensorrt_llm/_torch/models/modeling_minimaxm2.py +++ b/tensorrt_llm/_torch/models/modeling_minimaxm2.py @@ -19,7 +19,7 @@ from torch import nn from transformers import PretrainedConfig -from tensorrt_llm.functional import PositionEmbeddingType +from tensorrt_llm.functional import AllReduceStrategy, PositionEmbeddingType from tensorrt_llm.mapping import Mapping from ..attention_backend import AttentionMetadata @@ -124,11 +124,11 @@ def __init__( super().__init__() self.mapping = mapping # for attention input, tp_size * hidden_size = head_num * head_size - self.weight = nn.Parameter(torch.ones(hidden_size, dtype=dtype, requires_grad=False)) + self.weight = nn.Parameter(torch.empty(hidden_size, dtype=dtype), requires_grad=False) self.hidden_size = hidden_size self.eps = eps self.dtype = dtype - self.all_reduce = AllReduce(mapping=self.mapping) + self.all_reduce = AllReduce(mapping=self.mapping, strategy=AllReduceStrategy.NCCL) # TODO: add load weights method def load_weights(self, weights: Dict): @@ -136,14 +136,14 @@ def load_weights(self, weights: Dict): slice_width = self.hidden_size slice_start = self.mapping.tp_rank * slice_width slice_end = slice_start + slice_width - self.weight.copy_(weights["weight"][slice_start:slice_end].to(self.weight.dtype)) + self.weight.copy_(weights[0]["weight"][slice_start:slice_end].to(self.weight.dtype)) def forward(self, hidden_states: torch.Tensor): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - variance = self.all_reduce(variance) / self.mapping.tp_size + variance = hidden_states.pow(2).mean(-1, keepdim=True) / self.mapping.tp_size + variance = self.all_reduce(variance) hidden_states = hidden_states * torch.rsqrt(variance + self.eps) hidden_states = self.weight * hidden_states.to(input_dtype) @@ -210,23 +210,25 @@ def __init__( ) def apply_qk_norm(self, q, k): - if self.qkv_proj.mapping.tp_size > 1: - # collect q and k from all gpus - from ..distributed import allgather - - temp_q = allgather(q, self.qkv_proj.mapping) - temp_k = allgather(k, self.qkv_proj.mapping) - temp_q = self.q_norm(temp_q) - temp_k = self.k_norm(temp_k) - q = temp_q.reshape(-1, self.tp_size, self.q_size)[:, self.tp_rank, :].reshape( - -1, self.q_size - ) - k = temp_k.reshape(-1, self.tp_size, self.kv_size)[:, self.tp_rank, :].reshape( - -1, self.kv_size - ) - else: - q = self.q_norm(q) - k = self.k_norm(k) + q = self.q_norm(q) + k = self.k_norm(k) + # if self.qkv_proj.mapping.tp_size > 1: + # # collect q and k from all gpus + # from ..distributed import allgather + + # temp_q = allgather(q, self.qkv_proj.mapping) + # temp_k = allgather(k, self.qkv_proj.mapping) + # temp_q = self.q_norm(temp_q) + # temp_k = self.k_norm(temp_k) + # q = temp_q.reshape(-1, self.tp_size, self.q_size)[:, self.tp_rank, :].reshape( + # -1, self.q_size + # ) + # k = temp_k.reshape(-1, self.tp_size, self.kv_size)[:, self.tp_rank, :].reshape( + # -1, self.kv_size + # ) + # else: + # q = self.q_norm(q) + # k = self.k_norm(k) return q, k From f9ef53687bc65835ec8c52edb267ca9423f556dd Mon Sep 17 00:00:00 2001 From: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> Date: Fri, 30 Jan 2026 13:45:15 +0800 Subject: [PATCH 03/20] draft: add new reduce kernel file Signed-off-by: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> --- .../MiniMaxReduceRMSKernel.cu | 239 ++++++++++++++++++ .../MiniMaxReduceRMSKernel.h | 61 +++++ 2 files changed, 300 insertions(+) create mode 100644 cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu create mode 100644 cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.h diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu new file mode 100644 index 000000000000..c2c7fd9dfc4b --- /dev/null +++ b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu @@ -0,0 +1,239 @@ +#include "tensorrt_llm/common/config.h" +#include "tensorrt_llm/common/envUtils.h" +#include "tensorrt_llm/common/reduceKernelUtils.cuh" +#include "tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.h" +#include "tensorrt_llm/kernels/quantization.cuh" +#include + +TRTLLM_NAMESPACE_BEGIN + +namespace kernels::minimax_ar +{ +template +struct LamportComm +{ + __device__ __forceinline__ LamportComm(void** workspace, int rank) + { + counter_ptr = &reinterpret_cast(workspace[NRanks * 3])[0]; + flag_ptr = &reinterpret_cast(workspace[NRanks * 3])[2]; + clear_ptr = &reinterpret_cast(workspace[NRanks * 3 + 1])[0]; + flag_value = *flag_ptr; + auto comm_size = reinterpret_cast(workspace[NRanks * 3 + 1])[1]; + clear_size = *clear_ptr; + int data_offset = flag_value % 3; + int clear_offset = (flag_value + 2) % 3; + for (int r = 0; r < NRanks; ++r) + { + data_bufs[r] = reinterpret_cast(workspace[2 * NRanks + r]) + data_offset * comm_size; + } + clear_buf = reinterpret_cast(workspace[2 * NRanks + rank]) + clear_offset * comm_size; + __syncthreads(); + if (threadIdx.x == 0) + { + atomicAdd(counter_ptr, 1); + } + } + + __device__ __forceinline__ void update(int64_t new_clear_size) + { + if (blockIdx.x == 0 && threadIdx.x == 0) + { + while (*reinterpret_cast(counter_ptr) != gridDim.x) + { + } + *flag_ptr = (flag_value + 1) % 3; + *clear_ptr = new_clear_size; + *counter_ptr = 0; + } + } + + int* counter_ptr; + int* flag_ptr; + int64_t* clear_ptr; + uint8_t* data_bufs[NRanks]; + uint8_t* clear_buf; + int64_t clear_size; + int flag_value; +}; + +__device__ __forceinline__ bool is_neg_zero(float v) +{ + return *reinterpret_cast(&v) == 0x80000000; +} + +__device__ __forceinline__ bool is_neg_zero(float4 v) +{ + return is_neg_zero(v.x) || is_neg_zero(v.y) || is_neg_zero(v.z) || is_neg_zero(v.w); +} + +__device__ __forceinline__ float4 get_neg_zero() +{ + float4 vec; +#pragma unroll + for (int i = 0; i < 4; ++i) + { + reinterpret_cast(&vec)[i] = 0x80000000; + } + return vec; +} + +// __device__ __forceinline__ float4 ld_global_volatile(float4* addr) +// { +// float4 val; +// asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];" +// : "=f"(val.x), "=f"(val.y), "=f"(val.z), "=f"(val.w) +// : "l"(addr)); +// return val; +// } + +__device__ __forceinline__ float ld_global_volatile(float* addr) +{ + float val; + asm volatile("ld.volatile.global.f32 %0, [%1];" : "=f"(val) : "l"(addr)); + return val; +} + +template +class IndexHelper +{ +public: + __device__ __forceinline__ IndexHelper(MiniMaxReduceRMSParams const& params) + { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + namespace cg = cooperative_groups; + cg::cluster_group cluster = cg::this_cluster(); + cg::grid_group grid = cg::this_grid(); + token_id = grid.cluster_rank(); + access_id_in_token = cluster.thread_rank(); + token_stride = grid.num_clusters(); +#else + token_id = blockIdx.x; + access_id_in_token = threadIdx.x; + token_stride = gridDim.x; +#endif + access_id = token_id * params.hidden_dim / kElemsPerAccess + access_id_in_token; + access_stride = token_stride * params.hidden_dim / kElemsPerAccess; + tot_access = params.size / kElemsPerAccess; + } + + int token_id; + int access_id_in_token; + int token_stride; + int access_id; + int access_stride; + int tot_access; +}; + +/** +* this kernel is used to for minimax attention module +* input tensor [total_tokens, hidden_dim / tp_size], fp32 +* rms weight [hidden_dim / tp_size], bf16 +step 1: reduce from single rank to get the variance sum (reduce(input^2, dim=-1)) +step 2: reduce from all ranks to get the variance sum (all_reduce(variance_sum)) +step 3: calculate the rms norm (input * rsqrt(variance + eps)) +in this case, max hidden_dim is 6144, for each token, we only need 6144 / 4 / tp_size = (1536 / tp_size) threads +so we can assume cluster size is 1 (tp_size >= 2) + */ +template +__global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport(MiniMaxReduceRMSParams params) +{ + IndexHelper index_helper(params); + int token_id = index_helper.token_id; + int access_id_in_token = index_helper.access_id_in_token; + int token_stride = index_helper.token_stride; + int access_id = index_helper.access_id; + int access_stride = index_helper.access_stride; + int tot_access = index_helper.tot_access; + float4 clear_vec = get_neg_zero(); + // FusedOp fused_op(params, access_id, access_id_in_token); + + alignas(16) float vals[4]; + float sum_variance = 0.F; + *reinterpret_cast(vals) = reinterpret_cast(params.allreduce_in)[access_id]; +#pragma unroll + for (int i = 0; i < 4; ++i) + { + if (is_neg_zero(vals[i])) + { + vals[i] = 0.F; + } + sum_variance += vals[i] * vals[i]; + } + // step 1: reduce from single rank to get the variance sum + tensorrt_llm::common::blockReduceSumV2(&sum_variance); + + // step 2: reduce from all ranks to get the variance sum + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); + if constexpr (!TriggerCompletionAtEnd) + { + cudaTriggerProgrammaticLaunchCompletion(); + } +#endif + LamportComm comm(params.workspace, params.rank); + int clear_access = comm.clear_size / kElemsPerAccess; + // be careful, we only use float to load and store variance sum + // but we use float4 to load input tensor + // constexpr int StrideGap = kElemsPerAccess; + // Push data to other ranks + for (int r = 0; r < NRanks; ++r) + { + reinterpret_cast(comm.data_bufs[r])[params.rank * tot_access + access_id] = (sum_variance); + } + + for (int idx = access_id; idx < clear_access; idx += access_stride) + { + // Clear comm buffer that previous kernel used + reinterpret_cast(comm.clear_buf)[idx] = clear_vec; + } + + // Load data from other ranks + bool done = false; + float4 vars_all_ranks[NRanks]; + while (!done) + { + done = true; +#pragma unroll + for (int r = 0; r < NRanks; ++r) + { + vars_all_ranks[r] = ld_global_volatile( + &reinterpret_cast(comm.data_bufs[r])[params.rank * tot_access + access_id]); + done &= !is_neg_zero(vars_all_ranks[r]); + } + } + sum_variance = 0.F; +#pragma unroll + for (int r = 0; r < NRanks; ++r) + { + sum_variance += vars_all_ranks[r]; + } + + // step 3: calculate the rms norm (input * rsqrt(variance + eps)) + + // load norm weight + // TODO: correct the access_id_in_token + __nv_bfloat16 norm_weight = reinterpret_cast<__nv_bfloat16*>(params.rms_gamma)[access_id_in_token]; + +#pragma unroll + for (int i = 0; i < 4; ++i) + { + vals[i] = vals[i] * rsqrtf(sum_variance + params.rms_eps) * static_cast(norm_weight); + } + + // step 4: store the rms norm + reinterpret_cast(params.rms_norm_out)[access_id] = *reinterpret_cast(vals); + + comm.update(params.size * NRanks); + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + if constexpr (TriggerCompletionAtEnd) + { + cudaTriggerProgrammaticLaunchCompletion(); + } +#endif +} + +} // namespace kernels::minimax_ar + +TRTLLM_NAMESPACE_END diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.h b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.h new file mode 100644 index 000000000000..2dba6059147f --- /dev/null +++ b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.h @@ -0,0 +1,61 @@ +#pragma once +#include "tensorrt_llm/common/assert.h" +#include +#include +#include + +#include "tensorrt_llm/common/config.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/kernels/quantization.h" +#include "tensorrt_llm/runtime/ipcUtils.h" + +TRTLLM_NAMESPACE_BEGIN + +namespace kernels::minimax_ar +{ +template +struct ElemsPerAccess; + +template <> +struct ElemsPerAccess +{ + static constexpr int value = 8; +}; + +template <> +struct ElemsPerAccess +{ + static constexpr int value = 8; +}; + +template <> +struct ElemsPerAccess +{ + static constexpr int value = 4; +}; + +template +static constexpr int kElemsPerAccess = ElemsPerAccess::value; + +struct MiniMaxReduceRMSParams +{ + int nranks{}; + int rank{}; + nvinfer1::DataType dtype; + int size{}; + int hidden_dim{}; + void** workspace{}; + void* allreduce_in{}; + void* rms_norm_out{}; + void* rms_gamma{}; + float rms_eps{}; + float* scale_factor{}; + cudaStream_t stream{}; + bool trigger_completion_at_end = true; +}; + +void allreduce_fusion_op(AllReduceFusionParams const& params); + +} // namespace kernels::minimax_ar + +TRTLLM_NAMESPACE_END From 462f88214d0023869beec5fcabb2c390db9ff5fb Mon Sep 17 00:00:00 2001 From: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> Date: Wed, 4 Feb 2026 14:47:50 +0800 Subject: [PATCH 04/20] chore: add pytorch wrapper Signed-off-by: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> --- .../MiniMaxReduceRMSKernel.cu | 224 +++++++++++++----- .../MiniMaxReduceRMSKernel.h | 3 +- cpp/tensorrt_llm/thop/allreduceOp.cpp | 37 +++ .../_torch/models/modeling_minimaxm2.py | 15 +- 4 files changed, 220 insertions(+), 59 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu index c2c7fd9dfc4b..ce908e9d4164 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu +++ b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu @@ -9,6 +9,9 @@ TRTLLM_NAMESPACE_BEGIN namespace kernels::minimax_ar { +namespace +{ // anonymous namespace + template struct LamportComm { @@ -144,26 +147,10 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport(MiniMa int access_id = index_helper.access_id; int access_stride = index_helper.access_stride; int tot_access = index_helper.tot_access; + int tot_tokens = params.size / params.hidden_dim; float4 clear_vec = get_neg_zero(); // FusedOp fused_op(params, access_id, access_id_in_token); - alignas(16) float vals[4]; - float sum_variance = 0.F; - *reinterpret_cast(vals) = reinterpret_cast(params.allreduce_in)[access_id]; -#pragma unroll - for (int i = 0; i < 4; ++i) - { - if (is_neg_zero(vals[i])) - { - vals[i] = 0.F; - } - sum_variance += vals[i] * vals[i]; - } - // step 1: reduce from single rank to get the variance sum - tensorrt_llm::common::blockReduceSumV2(&sum_variance); - - // step 2: reduce from all ranks to get the variance sum - #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) cudaGridDependencySynchronize(); if constexpr (!TriggerCompletionAtEnd) @@ -173,59 +160,82 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport(MiniMa #endif LamportComm comm(params.workspace, params.rank); int clear_access = comm.clear_size / kElemsPerAccess; - // be careful, we only use float to load and store variance sum - // but we use float4 to load input tensor - // constexpr int StrideGap = kElemsPerAccess; - // Push data to other ranks - for (int r = 0; r < NRanks; ++r) + for (int idx = access_id; idx < tot_access; idx += access_stride, token_id += token_stride) { - reinterpret_cast(comm.data_bufs[r])[params.rank * tot_access + access_id] = (sum_variance); - } + alignas(16) float vals[4]; + float sum_variance = 0.F; + *reinterpret_cast(vals) = reinterpret_cast(params.allreduce_in)[access_id]; +#pragma unroll + for (int i = 0; i < 4; ++i) + { + if (is_neg_zero(vals[i])) + { + vals[i] = 0.F; + } + sum_variance += vals[i] * vals[i]; + } + // step 1: reduce from single rank to get the variance sum + tensorrt_llm::common::blockReduceSumV2(&sum_variance); - for (int idx = access_id; idx < clear_access; idx += access_stride) - { - // Clear comm buffer that previous kernel used - reinterpret_cast(comm.clear_buf)[idx] = clear_vec; - } + // step 2: reduce from all ranks to get the variance sum + // be careful, we only use float to load and store variance sum + // but we use float4 to load input tensor + // constexpr int StrideGap = kElemsPerAccess; + // Push data to other ranks + // we only need the first thread to push data to other ranks + if (threadIdx.x == 0) + { + for (int r = 0; r < NRanks; ++r) + { + // temp data buffer [nranks, total_tokens, 1] + reinterpret_cast(comm.data_bufs[r])[(params.rank * tot_tokens) + token_id] = (sum_variance); + } + } - // Load data from other ranks - bool done = false; - float4 vars_all_ranks[NRanks]; - while (!done) - { - done = true; + // Load data from other ranks + bool done = false; + float vars_all_ranks[NRanks]; + while (!done) + { + done = true; +#pragma unroll + for (int r = 0; r < NRanks; ++r) + { + vars_all_ranks[r] = ld_global_volatile( + &reinterpret_cast(comm.data_bufs[params.rank])[(r * tot_tokens) + token_id]); + done &= !is_neg_zero(vars_all_ranks[r]); + } + } + sum_variance = 0.F; #pragma unroll for (int r = 0; r < NRanks; ++r) { - vars_all_ranks[r] = ld_global_volatile( - &reinterpret_cast(comm.data_bufs[r])[params.rank * tot_access + access_id]); - done &= !is_neg_zero(vars_all_ranks[r]); + sum_variance += vars_all_ranks[r]; } - } - sum_variance = 0.F; -#pragma unroll - for (int r = 0; r < NRanks; ++r) - { - sum_variance += vars_all_ranks[r]; - } - // step 3: calculate the rms norm (input * rsqrt(variance + eps)) + // step 3: calculate the rms norm (input * rsqrt(variance + eps)) - // load norm weight - // TODO: correct the access_id_in_token - __nv_bfloat16 norm_weight = reinterpret_cast<__nv_bfloat16*>(params.rms_gamma)[access_id_in_token]; + // load norm weight + // TODO: correct the access_id_in_token + __nv_bfloat16 norm_weight[4]; + *reinterpret_cast<__nv_bfloat164*>(norm_weight) + = reinterpret_cast<__nv_bfloat164*>(params.rms_gamma)[access_id_in_token]; #pragma unroll - for (int i = 0; i < 4; ++i) + for (int i = 0; i < 4; ++i) + { + vals[i] = vals[i] * rsqrtf(sum_variance + params.rms_eps) * static_cast(norm_weight[i]); + } + + // step 4: store the rms norm + reinterpret_cast(params.rms_norm_out)[access_id] = *reinterpret_cast(vals); + } + for (int idx = access_id; idx < clear_access; idx += access_stride) { - vals[i] = vals[i] * rsqrtf(sum_variance + params.rms_eps) * static_cast(norm_weight); + // Clear comm buffer that previous kernel used + reinterpret_cast(comm.clear_buf)[idx] = clear_vec; } - - // step 4: store the rms norm - reinterpret_cast(params.rms_norm_out)[access_id] = *reinterpret_cast(vals); - comm.update(params.size * NRanks); - #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) if constexpr (TriggerCompletionAtEnd) { @@ -234,6 +244,108 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport(MiniMa #endif } +int get_sm_count() +{ + static int sm_count = 0; + if (sm_count == 0) + { + int device_id; + TLLM_CUDA_CHECK(cudaGetDevice(&device_id)); + cudaDeviceProp device_prop; + cudaGetDeviceProperties(&device_prop, device_id); + sm_count = device_prop.multiProcessorCount; + } + return sm_count; +} + +template +void minimax_reduce_rms_kernel_launcher(MiniMaxReduceRMSParams const& params) +{ + TLLM_CHECK(params.size % params.hidden_dim == 0); + TLLM_CHECK(params.hidden_dim % kElemsPerAccess == 0); + static int SM = tensorrt_llm::common::getSMVersion(); + int token_num = params.size / params.hidden_dim; + // for current problem size, we only need one cluster + int sm_count = get_sm_count(); + int cluster_size = 1; + int cluster_num = token_num; + int threads_per_token = params.hidden_dim / kElemsPerAccess; + int block_size = threads_per_token; + int grid_size = (std::min(sm_count, cluster_num * cluster_size) / cluster_size) * cluster_size; + + cudaLaunchConfig_t cfg; + cfg.gridDim = grid_size; + cfg.blockDim = block_size; + cfg.dynamicSmemBytes = 0; + cfg.stream = params.stream; + + cudaLaunchAttribute attribute[2]; + attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attribute[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL() ? 1 : 0; + attribute[1].id = cudaLaunchAttributeClusterDimension; + attribute[1].val.clusterDim.x = cluster_size; + attribute[1].val.clusterDim.y = 1; + attribute[1].val.clusterDim.z = 1; + cfg.attrs = attribute; + cfg.numAttrs = SM >= 90 ? 2 : 0; + + bool trigger_completion_at_end = params.trigger_completion_at_end; + if (trigger_completion_at_end) + { + TLLM_CUDA_CHECK(cudaLaunchKernelEx(&cfg, minimax_reduce_rms_kernel_lamport, params)); + } + else + { + TLLM_CUDA_CHECK(cudaLaunchKernelEx(&cfg, minimax_reduce_rms_kernel_lamport, params)); + } +} + +template +void dispatch_dtype(MiniMaxReduceRMSParams const& params) +{ + if (params.dtype == nvinfer1::DataType::kHALF) + { + minimax_reduce_rms_kernel_launcher(params); + } + else if (params.dtype == nvinfer1::DataType::kBF16) + { + minimax_reduce_rms_kernel_launcher<__nv_bfloat16, NRanks>(params); + } + else if (params.dtype == nvinfer1::DataType::kFLOAT) + { + minimax_reduce_rms_kernel_launcher(params); + } + else + { + TLLM_CHECK_WITH_INFO(false, "Unsupported data type for minimax_reduce_rms_op"); + } +} +} // namespace + +void minimax_reduce_rms_op(MiniMaxReduceRMSParams const& params) +{ + if (params.nranks == 2) + { + dispatch_dtype<2>(params); + } + else if (params.nranks == 4) + { + dispatch_dtype<4>(params); + } + else if (params.nranks == 8) + { + dispatch_dtype<8>(params); + } + else if (params.nranks == 16) + { + dispatch_dtype<16>(params); + } + else + { + TLLM_CHECK_WITH_INFO(false, "minimax_reduce_rms_op: unsupported ranks number!"); + } +} + } // namespace kernels::minimax_ar TRTLLM_NAMESPACE_END diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.h b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.h index 2dba6059147f..67df39ef6ef5 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.h +++ b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.h @@ -49,12 +49,11 @@ struct MiniMaxReduceRMSParams void* rms_norm_out{}; void* rms_gamma{}; float rms_eps{}; - float* scale_factor{}; cudaStream_t stream{}; bool trigger_completion_at_end = true; }; -void allreduce_fusion_op(AllReduceFusionParams const& params); +void minimax_reduce_rms_op(MiniMaxReduceRMSParams const& params); } // namespace kernels::minimax_ar diff --git a/cpp/tensorrt_llm/thop/allreduceOp.cpp b/cpp/tensorrt_llm/thop/allreduceOp.cpp index e5d875fb52e5..87b61def8b32 100644 --- a/cpp/tensorrt_llm/thop/allreduceOp.cpp +++ b/cpp/tensorrt_llm/thop/allreduceOp.cpp @@ -23,6 +23,7 @@ #include "tensorrt_llm/common/ncclUtils.h" #include "tensorrt_llm/common/nvmlWrapper.h" #include "tensorrt_llm/common/opUtils.h" +#include "tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.h" #include "tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h" #include "tensorrt_llm/kernels/communicationKernels/customLowPrecisionAllReduceKernels.h" #include "tensorrt_llm/kernels/communicationKernels/mnnvlAllreduceKernels.h" @@ -1822,6 +1823,32 @@ std::vector mnnvlFusionAllReduce(torch::Tensor& input, torch::opt return {output, residualOut}; } +torch::Tensor minimax_allreduce_rms(torch::Tensor const& input, torch::Tensor const& norm_weight, + torch::Tensor workspace, int64_t const rank, int64_t const nranks, double const eps, + bool const trigger_completion_at_end_) +{ + auto allreduce_params = tensorrt_llm::kernels::minimax_ar::MiniMaxReduceRMSParams(); + + allreduce_params.nranks = static_cast(nranks); + allreduce_params.rank = static_cast(rank); + allreduce_params.dtype = tensorrt_llm::runtime::TorchUtils::dataType(input.scalar_type()); + allreduce_params.size = static_cast(input.numel()); + allreduce_params.hidden_dim = static_cast(input.size(-1)); + allreduce_params.workspace = reinterpret_cast(workspace.mutable_data_ptr()); + allreduce_params.allreduce_in = input.data_ptr(); + // allreduce_params.rms_norm_out = nullptr; + allreduce_params.rms_gamma = norm_weight.data_ptr(); + allreduce_params.rms_eps = static_cast(eps); + allreduce_params.stream = at::cuda::getCurrentCUDAStream(input.get_device()); + + torch::Tensor rms_norm_out = torch::empty_like(input); + allreduce_params.rms_norm_out = rms_norm_out.mutable_data_ptr(); + + tensorrt_llm::kernels::minimax_ar::minimax_reduce_rms_op(allreduce_params); + + return rms_norm_out; +} + } // namespace torch_ext TRTLLM_NAMESPACE_END @@ -1886,6 +1913,15 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m) "int nranks," "float eps) -> Tensor[]"); m.def("preallocate_nccl_window_buffer(Tensor input, int[] group, int count) -> ()"); + m.def( + "minimax_allreduce_rms(" + "Tensor input," + "Tensor norm_weight," + "Tensor workspace," + "int rank," + "int nranks," + "float eps," + "bool trigger_completion_at_end) -> Tensor"); } TORCH_LIBRARY_IMPL(trtllm, CUDA, m) @@ -1896,6 +1932,7 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m) m.impl("moe_allreduce", &tensorrt_llm::torch_ext::moe_allreduce); m.impl("moe_finalize_allreduce", &tensorrt_llm::torch_ext::moe_finalize_allreduce); m.impl("preallocate_nccl_window_buffer", &tensorrt_llm::torch_ext::preallocateNCCLWindowBuffer); + m.impl("minimax_allreduce_rms", &tensorrt_llm::torch_ext::minimax_allreduce_rms); } TORCH_LIBRARY_IMPL(trtllm, CPU, m) diff --git a/tensorrt_llm/_torch/models/modeling_minimaxm2.py b/tensorrt_llm/_torch/models/modeling_minimaxm2.py index de0dc21c02e6..602a7b9e24d2 100644 --- a/tensorrt_llm/_torch/models/modeling_minimaxm2.py +++ b/tensorrt_llm/_torch/models/modeling_minimaxm2.py @@ -139,6 +139,7 @@ def load_weights(self, weights: Dict): self.weight.copy_(weights[0]["weight"][slice_start:slice_end].to(self.weight.dtype)) def forward(self, hidden_states: torch.Tensor): + """ input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) @@ -147,7 +148,19 @@ def forward(self, hidden_states: torch.Tensor): hidden_states = hidden_states * torch.rsqrt(variance + self.eps) hidden_states = self.weight * hidden_states.to(input_dtype) - return hidden_states + """ + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + rms_norm_out = torch.ops.trtllm.minimax_allreduce_rms( + hidden_states, + self.weight, + self.workspace, + self.mapping.tp_rank, + self.mapping.tp_size, + self.eps, + False, + ) + return rms_norm_out.to(input_dtype) # It's a little bit tricky to implement special qk norm From 83c35033f9d6ec38d34a5f77fa873656e0c1a992 Mon Sep 17 00:00:00 2001 From: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> Date: Thu, 5 Feb 2026 15:47:55 +0800 Subject: [PATCH 05/20] test: add unit test case Signed-off-by: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> --- tensorrt_llm/_torch/distributed/ops.py | 16 ++++ .../_torch/multi_gpu/test_allreduce.py | 83 ++++++++++++++++++- 2 files changed, 98 insertions(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py index dd71bb00ec9e..98cc3a4fd367 100644 --- a/tensorrt_llm/_torch/distributed/ops.py +++ b/tensorrt_llm/_torch/distributed/ops.py @@ -1178,3 +1178,19 @@ def all_to_all_5d( gathered_heads = heads * world_size return out.reshape(batch, sharded_seq, qkv_count, gathered_heads, head_dim) + + +class MiniMaxAllReduceRMS(nn.Module): + + def __init__(self, mapping: Mapping): + super().__init__() + self.mapping = mapping + self.workspace = get_allreduce_workspace(self.mapping) + + def forward(self, input: torch.Tensor, rms_weights: torch.Tensor, + eps: float): + return torch.ops.trtllm.minimax_allreduce_rms(input, rms_weights, + self.workspace, + self.mapping.tp_rank, + self.mapping.tp_size, eps, + False) diff --git a/tests/unittest/_torch/multi_gpu/test_allreduce.py b/tests/unittest/_torch/multi_gpu/test_allreduce.py index a3ba58faaed4..070af052c080 100644 --- a/tests/unittest/_torch/multi_gpu/test_allreduce.py +++ b/tests/unittest/_torch/multi_gpu/test_allreduce.py @@ -27,7 +27,8 @@ from tensorrt_llm._torch.autotuner import AutoTuner, autotune from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp, AllReduceParams, AllReduceStrategy, - MoEAllReduce, MoEAllReduceParams) + MiniMaxAllReduceRMS, MoEAllReduce, + MoEAllReduceParams) from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode from tensorrt_llm._torch.modules.rms_norm import RMSNorm from tensorrt_llm.mapping import Mapping @@ -697,3 +698,83 @@ def test_moe_finalize_allreduce_no_residual(mpi_pool_executor): ) for r in results: assert r is True + + +@torch.inference_mode() +def run_minimax_allreduce_rms_op(input: torch.Tensor, tensor_parallel_size: int, + tensor_parallel_rank: int, + rms_weights: torch.Tensor, eps: float): + torch.manual_seed(42) + + total_tokens = input.shape[0] + origin_dtype = input.dtype + + input = input.cuda() + rms_weights = rms_weights.cuda() + + input = input.reshape(total_tokens, tensor_parallel_size, + -1).to(torch.float32) + rms_weights = rms_weights.reshape(tensor_parallel_size, + -1).to(torch.float32) + rank_input = input[:, tensor_parallel_rank, :] + rank_rms_weights = rms_weights[tensor_parallel_rank, :] + # firstly, calculate the reference output for each rank + ref_output = rms_norm(input, rms_weights, eps) + ref_output = ref_output.reshape(total_tokens, tensor_parallel_size, + -1).to(origin_dtype) + ref_output = ref_output[:, tensor_parallel_rank, :] + + # then, calculate the minimax allreduce output + minimax_allreduce_rms = MiniMaxAllReduceRMS(mapping=Mapping( + world_size=tensor_parallel_size, + tp_size=tensor_parallel_size, + rank=tensor_parallel_rank, + )) + minimax_output = minimax_allreduce_rms( + input=rank_input, + rms_weights=rank_rms_weights, + eps=eps, + ) + + # finally, verify the results + torch.testing.assert_close(minimax_output, ref_output, rtol=0.2, atol=0.2) + + return rank_input + + +@torch.inference_mode() +def run_minimax_allreduce_rms_single_rank(tensor_parallel_size, + single_rank_forward_func, input, + rms_weights, eps): + rank = tensorrt_llm.mpi_rank() + torch.cuda.set_device(rank) + try: + single_rank_forward_func(input, tensor_parallel_size, rank, rms_weights, + eps) + except Exception: + traceback.print_exc() + raise + return True + + +@torch.inference_mode() +@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True) +def test_minimax_allreduce_rms(mpi_pool_executor): + torch.manual_seed(42) + + seq_len = 16 + hidden_size = 7168 + dtype = torch.bfloat16 + tensor_parallel_size = mpi_pool_executor.num_workers + + token_input = torch.randn((seq_len, hidden_size), dtype=dtype) + rms_weights = torch.randn((hidden_size, ), dtype=dtype, device="cuda") + eps = 1e-5 + + results = mpi_pool_executor.map( + run_minimax_allreduce_rms_single_rank, + *zip(*[(tensor_parallel_size, run_minimax_allreduce_rms_op, token_input, + rms_weights, eps)] * tensor_parallel_size), + ) + for r in results: + assert r is True From 5a32a80525527aa7f4571e67a27d10640b11174a Mon Sep 17 00:00:00 2001 From: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> Date: Thu, 12 Feb 2026 14:20:28 +0800 Subject: [PATCH 06/20] fix: bug fix for unit test Signed-off-by: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> --- .../communicationKernels/MiniMaxReduceRMSKernel.cu | 7 ++++--- tensorrt_llm/_torch/distributed/__init__.py | 5 +++-- tensorrt_llm/_torch/models/modeling_minimaxm2.py | 14 ++++---------- tests/unittest/_torch/multi_gpu/test_allreduce.py | 10 +++++----- 4 files changed, 16 insertions(+), 20 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu index ce908e9d4164..e505e4744bd8 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu +++ b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu @@ -164,7 +164,7 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport(MiniMa { alignas(16) float vals[4]; float sum_variance = 0.F; - *reinterpret_cast(vals) = reinterpret_cast(params.allreduce_in)[access_id]; + *reinterpret_cast(vals) = reinterpret_cast(params.allreduce_in)[idx]; #pragma unroll for (int i = 0; i < 4; ++i) { @@ -224,11 +224,12 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport(MiniMa #pragma unroll for (int i = 0; i < 4; ++i) { - vals[i] = vals[i] * rsqrtf(sum_variance + params.rms_eps) * static_cast(norm_weight[i]); + vals[i] = vals[i] * rsqrtf((sum_variance / static_cast(params.hidden_dim) / NRanks) + params.rms_eps) + * static_cast(norm_weight[i]); } // step 4: store the rms norm - reinterpret_cast(params.rms_norm_out)[access_id] = *reinterpret_cast(vals); + reinterpret_cast(params.rms_norm_out)[idx] = *reinterpret_cast(vals); } for (int idx = access_id; idx < clear_access; idx += access_stride) { diff --git a/tensorrt_llm/_torch/distributed/__init__.py b/tensorrt_llm/_torch/distributed/__init__.py index 5e18d0d7b77a..9f42f9dd676c 100644 --- a/tensorrt_llm/_torch/distributed/__init__.py +++ b/tensorrt_llm/_torch/distributed/__init__.py @@ -3,8 +3,8 @@ from .communicator import Distributed, MPIDist, TorchDist from .moe_alltoall import MoeAlltoAll from .ops import (AllReduce, AllReduceParams, AllReduceStrategy, - HelixAllToAllNative, MoEAllReduce, MoEAllReduceParams, - all_to_all_4d, all_to_all_5d, allgather, alltoall_helix, + HelixAllToAllNative, MiniMaxAllReduceRMS, MoEAllReduce, MoEAllReduceParams, + all_to_all_4d, all_to_all_5d, allgather, alltoall_helix, cp_allgather, reducescatter, userbuffers_allreduce_finalize) __all__ = [ @@ -22,6 +22,7 @@ "HelixAllToAllNative", "MoEAllReduce", "MoEAllReduceParams", + "MiniMaxAllReduceRMS", "MoeAlltoAll", "TorchDist", "MPIDist", diff --git a/tensorrt_llm/_torch/models/modeling_minimaxm2.py b/tensorrt_llm/_torch/models/modeling_minimaxm2.py index 602a7b9e24d2..3f3bbd5eb7a9 100644 --- a/tensorrt_llm/_torch/models/modeling_minimaxm2.py +++ b/tensorrt_llm/_torch/models/modeling_minimaxm2.py @@ -24,7 +24,7 @@ from ..attention_backend import AttentionMetadata from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams -from ..distributed import AllReduce +from ..distributed import AllReduce, MiniMaxAllReduceRMS from ..models.modeling_utils import ModelConfig from ..modules.attention import Attention from ..modules.decoder_layer import DecoderLayer @@ -130,6 +130,8 @@ def __init__( self.dtype = dtype self.all_reduce = AllReduce(mapping=self.mapping, strategy=AllReduceStrategy.NCCL) + self.minimax_all_reduce_rms = MiniMaxAllReduceRMS(mapping=self.mapping) + # TODO: add load weights method def load_weights(self, weights: Dict): assert len(weights) == 1 @@ -151,15 +153,7 @@ def forward(self, hidden_states: torch.Tensor): """ input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) - rms_norm_out = torch.ops.trtllm.minimax_allreduce_rms( - hidden_states, - self.weight, - self.workspace, - self.mapping.tp_rank, - self.mapping.tp_size, - self.eps, - False, - ) + rms_norm_out = self.minimax_all_reduce_rms(hidden_states, self.weight, self.eps) return rms_norm_out.to(input_dtype) diff --git a/tests/unittest/_torch/multi_gpu/test_allreduce.py b/tests/unittest/_torch/multi_gpu/test_allreduce.py index 070af052c080..e41131382a79 100644 --- a/tests/unittest/_torch/multi_gpu/test_allreduce.py +++ b/tests/unittest/_torch/multi_gpu/test_allreduce.py @@ -716,8 +716,8 @@ def run_minimax_allreduce_rms_op(input: torch.Tensor, tensor_parallel_size: int, -1).to(torch.float32) rms_weights = rms_weights.reshape(tensor_parallel_size, -1).to(torch.float32) - rank_input = input[:, tensor_parallel_rank, :] - rank_rms_weights = rms_weights[tensor_parallel_rank, :] + rank_input = input[:, tensor_parallel_rank, :].contiguous() + rank_rms_weights = rms_weights[tensor_parallel_rank, :].contiguous() # firstly, calculate the reference output for each rank ref_output = rms_norm(input, rms_weights, eps) ref_output = ref_output.reshape(total_tokens, tensor_parallel_size, @@ -732,9 +732,9 @@ def run_minimax_allreduce_rms_op(input: torch.Tensor, tensor_parallel_size: int, )) minimax_output = minimax_allreduce_rms( input=rank_input, - rms_weights=rank_rms_weights, + rms_weights=rank_rms_weights.to(torch.bfloat16), eps=eps, - ) + ).to(origin_dtype) # finally, verify the results torch.testing.assert_close(minimax_output, ref_output, rtol=0.2, atol=0.2) @@ -763,7 +763,7 @@ def test_minimax_allreduce_rms(mpi_pool_executor): torch.manual_seed(42) seq_len = 16 - hidden_size = 7168 + hidden_size = 6144 dtype = torch.bfloat16 tensor_parallel_size = mpi_pool_executor.num_workers From 49dec208828c7893b27a491860bfa157ba26aaf1 Mon Sep 17 00:00:00 2001 From: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> Date: Fri, 13 Feb 2026 13:34:22 +0800 Subject: [PATCH 07/20] chore: support bf16 or fp16 as input tensor Signed-off-by: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> --- .../MiniMaxReduceRMSKernel.cu | 34 +++++++++---------- .../MiniMaxReduceRMSKernel.h | 3 ++ .../_torch/multi_gpu/test_allreduce.py | 5 ++- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu index e505e4744bd8..1e150b8ae6bc 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu +++ b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu @@ -134,8 +134,8 @@ public: step 1: reduce from single rank to get the variance sum (reduce(input^2, dim=-1)) step 2: reduce from all ranks to get the variance sum (all_reduce(variance_sum)) step 3: calculate the rms norm (input * rsqrt(variance + eps)) -in this case, max hidden_dim is 6144, for each token, we only need 6144 / 4 / tp_size = (1536 / tp_size) threads -so we can assume cluster size is 1 (tp_size >= 2) +in this case, max hidden_dim is 6144 (float data), for each token, we only need 6144 / 4 / tp_size = (1536 / tp_size) +threads so we can assume cluster size is 1 (tp_size >= 2) */ template __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport(MiniMaxReduceRMSParams params) @@ -162,25 +162,24 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport(MiniMa int clear_access = comm.clear_size / kElemsPerAccess; for (int idx = access_id; idx < tot_access; idx += access_stride, token_id += token_stride) { - alignas(16) float vals[4]; + alignas(16) DType vals[kElemsPerAccess]; + // we use float to load and store variance sum float sum_variance = 0.F; *reinterpret_cast(vals) = reinterpret_cast(params.allreduce_in)[idx]; #pragma unroll - for (int i = 0; i < 4; ++i) + for (int i = 0; i < kElemsPerAccess; ++i) { - if (is_neg_zero(vals[i])) - { - vals[i] = 0.F; - } - sum_variance += vals[i] * vals[i]; + sum_variance += static_cast(vals[i]) * static_cast(vals[i]); } // step 1: reduce from single rank to get the variance sum tensorrt_llm::common::blockReduceSumV2(&sum_variance); - + if (is_neg_zero(sum_variance)) + { + sum_variance = 0.F; + } // step 2: reduce from all ranks to get the variance sum // be careful, we only use float to load and store variance sum // but we use float4 to load input tensor - // constexpr int StrideGap = kElemsPerAccess; // Push data to other ranks // we only need the first thread to push data to other ranks if (threadIdx.x == 0) @@ -217,15 +216,16 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport(MiniMa // load norm weight // TODO: correct the access_id_in_token - __nv_bfloat16 norm_weight[4]; - *reinterpret_cast<__nv_bfloat164*>(norm_weight) - = reinterpret_cast<__nv_bfloat164*>(params.rms_gamma)[access_id_in_token]; + __nv_bfloat16 norm_weight[kElemsPerAccess]; + *reinterpret_cast::norm_weight_type*>(norm_weight) + = reinterpret_cast::norm_weight_type*>(params.rms_gamma)[access_id_in_token]; #pragma unroll - for (int i = 0; i < 4; ++i) + for (int i = 0; i < kElemsPerAccess; ++i) { - vals[i] = vals[i] * rsqrtf((sum_variance / static_cast(params.hidden_dim) / NRanks) + params.rms_eps) - * static_cast(norm_weight[i]); + vals[i] = static_cast(static_cast(vals[i]) + * rsqrtf((sum_variance / static_cast(params.hidden_dim) / NRanks) + params.rms_eps) + * static_cast(norm_weight[i])); } // step 4: store the rms norm diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.h b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.h index 67df39ef6ef5..224db393d01f 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.h +++ b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.h @@ -20,18 +20,21 @@ template <> struct ElemsPerAccess { static constexpr int value = 8; + using norm_weight_type = common::__nv_bfloat168; }; template <> struct ElemsPerAccess { static constexpr int value = 8; + using norm_weight_type = common::__nv_bfloat168; }; template <> struct ElemsPerAccess { static constexpr int value = 4; + using norm_weight_type = common::__nv_bfloat164; }; template diff --git a/tests/unittest/_torch/multi_gpu/test_allreduce.py b/tests/unittest/_torch/multi_gpu/test_allreduce.py index e41131382a79..d10e878e2fcf 100644 --- a/tests/unittest/_torch/multi_gpu/test_allreduce.py +++ b/tests/unittest/_torch/multi_gpu/test_allreduce.py @@ -731,11 +731,10 @@ def run_minimax_allreduce_rms_op(input: torch.Tensor, tensor_parallel_size: int, rank=tensor_parallel_rank, )) minimax_output = minimax_allreduce_rms( - input=rank_input, + input=rank_input.to(torch.bfloat16), rms_weights=rank_rms_weights.to(torch.bfloat16), eps=eps, - ).to(origin_dtype) - + ) # finally, verify the results torch.testing.assert_close(minimax_output, ref_output, rtol=0.2, atol=0.2) From e934b06059550d57076edf61a889ab74274980f4 Mon Sep 17 00:00:00 2001 From: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> Date: Fri, 13 Feb 2026 15:22:53 +0800 Subject: [PATCH 08/20] chore: use origin dtype for input, just use fp32 acc in kernel Signed-off-by: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> --- tensorrt_llm/_torch/models/modeling_minimaxm2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_minimaxm2.py b/tensorrt_llm/_torch/models/modeling_minimaxm2.py index 3f3bbd5eb7a9..33bd0aee07c0 100644 --- a/tensorrt_llm/_torch/models/modeling_minimaxm2.py +++ b/tensorrt_llm/_torch/models/modeling_minimaxm2.py @@ -151,10 +151,10 @@ def forward(self, hidden_states: torch.Tensor): hidden_states = hidden_states * torch.rsqrt(variance + self.eps) hidden_states = self.weight * hidden_states.to(input_dtype) """ - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) + # input_dtype = hidden_states.dtype + # hidden_states = hidden_states.to(torch.float32) rms_norm_out = self.minimax_all_reduce_rms(hidden_states, self.weight, self.eps) - return rms_norm_out.to(input_dtype) + return rms_norm_out # It's a little bit tricky to implement special qk norm From 9e9d357a55f1b94a29acc3462cc2c9b419023c24 Mon Sep 17 00:00:00 2001 From: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> Date: Thu, 12 Mar 2026 13:12:15 +0800 Subject: [PATCH 09/20] draft: add float4 kernel, add benchmark script Signed-off-by: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> --- .../MiniMaxReduceRMSKernel.cu | 246 +++++++++++++++++- tests/microbenchmarks/minimax_all_reduce.py | 218 ++++++++++++++++ 2 files changed, 453 insertions(+), 11 deletions(-) create mode 100644 tests/microbenchmarks/minimax_all_reduce.py diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu index 1e150b8ae6bc..0023c953147b 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu +++ b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu @@ -80,14 +80,14 @@ __device__ __forceinline__ float4 get_neg_zero() return vec; } -// __device__ __forceinline__ float4 ld_global_volatile(float4* addr) -// { -// float4 val; -// asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];" -// : "=f"(val.x), "=f"(val.y), "=f"(val.z), "=f"(val.w) -// : "l"(addr)); -// return val; -// } +__device__ __forceinline__ float4 ld_global_volatile(float4* addr) +{ + float4 val; + asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];" + : "=f"(val.x), "=f"(val.y), "=f"(val.z), "=f"(val.w) + : "l"(addr)); + return val; +} __device__ __forceinline__ float ld_global_volatile(float* addr) { @@ -245,6 +245,157 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport(MiniMa #endif } +/** + * Float4 variant: process 4 rows at once, allreduce variance sums as float4 for better memory coalescing. + * sum_variance is always float; applies to all DTypes (half, bf16, float). + * When tot_tokens % 4 != 0, the last group pads rows with zeros; padded rows are not written to rms_norm_out. + */ +template +__global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport_float4(MiniMaxReduceRMSParams params) +{ + int tot_tokens = params.size / params.hidden_dim; + int tot_groups = (tot_tokens + 3) / 4; // ceiling: last group may have 1-3 valid rows + int access_per_row = params.hidden_dim / kElemsPerAccess; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + namespace cg = cooperative_groups; + cg::cluster_group cluster = cg::this_cluster(); + cg::grid_group grid = cg::this_grid(); + int group_id = grid.cluster_rank(); + int access_id_in_token = cluster.thread_rank(); + int group_stride = grid.num_clusters(); +#else + int group_id = blockIdx.x; + int access_id_in_token = threadIdx.x; + int group_stride = gridDim.x; +#endif + float4 clear_vec = get_neg_zero(); + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); + if constexpr (!TriggerCompletionAtEnd) + { + cudaTriggerProgrammaticLaunchCompletion(); + } +#endif + LamportComm comm(params.workspace, params.rank); + + for (int g = group_id; g < tot_groups; g += group_stride) + { + alignas(16) DType vals[4][kElemsPerAccess]; + float sum_variance[4] = {0.F, 0.F, 0.F, 0.F}; + + // Load 4 rows and compute partial sum of squares per row (sum_variance always float) + for (int r = 0; r < 4; ++r) + { + int token_r = g * 4 + r; + if (token_r < tot_tokens) + { + int idx_r = token_r * access_per_row + access_id_in_token; + *reinterpret_cast(&vals[r][0]) = reinterpret_cast(params.allreduce_in)[idx_r]; +#pragma unroll + for (int i = 0; i < kElemsPerAccess; ++i) + { + sum_variance[r] += static_cast(vals[r][i]) * static_cast(vals[r][i]); + } + } + else + { + *reinterpret_cast(&vals[r][0]) = make_float4(0.F, 0.F, 0.F, 0.F); + sum_variance[r] = 0.F; + } + } + + tensorrt_llm::common::blockReduceSumV2(sum_variance); +#pragma unroll + for (int r = 0; r < 4; ++r) + { + if (is_neg_zero(sum_variance[r])) + { + sum_variance[r] = 0.F; + } + } + + // Allreduce: write float4, volatile read float4 from each rank, component-wise sum + if (threadIdx.x == 0) + { + float4 sum4; + sum4.x = sum_variance[0]; + sum4.y = sum_variance[1]; + sum4.z = sum_variance[2]; + sum4.w = sum_variance[3]; + for (int r = 0; r < NRanks; ++r) + { + reinterpret_cast(comm.data_bufs[r])[(params.rank * tot_groups) + g] = sum4; + } + } + + bool done = false; + float4 vars_all_ranks[NRanks]; + while (!done) + { + done = true; +#pragma unroll + for (int r = 0; r < NRanks; ++r) + { + vars_all_ranks[r] + = ld_global_volatile(&reinterpret_cast(comm.data_bufs[params.rank])[(r * tot_groups) + g]); + done &= !is_neg_zero(vars_all_ranks[r]); + } + } + + sum_variance[0] = 0.F; + sum_variance[1] = 0.F; + sum_variance[2] = 0.F; + sum_variance[3] = 0.F; +#pragma unroll + for (int r = 0; r < NRanks; ++r) + { + sum_variance[0] += vars_all_ranks[r].x; + sum_variance[1] += vars_all_ranks[r].y; + sum_variance[2] += vars_all_ranks[r].z; + sum_variance[3] += vars_all_ranks[r].w; + } + + // Load norm weight (same column for all 4 rows) + __nv_bfloat16 norm_weight[kElemsPerAccess]; + *reinterpret_cast::norm_weight_type*>(norm_weight) + = reinterpret_cast::norm_weight_type const*>( + params.rms_gamma)[access_id_in_token]; + + // RMS norm and store 4 rows (skip write for padded rows) + for (int r = 0; r < 4; ++r) + { + int token_r = g * 4 + r; + if (token_r >= tot_tokens) + { + continue; + } + float scale = rsqrtf((sum_variance[r] / static_cast(params.hidden_dim) / NRanks) + params.rms_eps); +#pragma unroll + for (int i = 0; i < kElemsPerAccess; ++i) + { + vals[r][i] + = static_cast(static_cast(vals[r][i]) * scale * static_cast(norm_weight[i])); + } + int idx_out = token_r * access_per_row + access_id_in_token; + reinterpret_cast(params.rms_norm_out)[idx_out] = *reinterpret_cast(&vals[r][0]); + } + } + + // Clear comm buffer (tot_groups float4s per rank) + for (int g = group_id; g < tot_groups; g += group_stride) + { + reinterpret_cast(comm.clear_buf)[g] = clear_vec; + } + comm.update(params.size * NRanks); +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + if constexpr (TriggerCompletionAtEnd) + { + cudaTriggerProgrammaticLaunchCompletion(); + } +#endif +} + int get_sm_count() { static int sm_count = 0; @@ -301,20 +452,93 @@ void minimax_reduce_rms_kernel_launcher(MiniMaxReduceRMSParams const& params) } } +template +void minimax_reduce_rms_kernel_launcher_float4(MiniMaxReduceRMSParams const& params) +{ + TLLM_CHECK(params.size % params.hidden_dim == 0); + TLLM_CHECK(params.hidden_dim % kElemsPerAccess == 0); + int token_num = params.size / params.hidden_dim; + int tot_groups = (token_num + 3) / 4; // ceiling + if (tot_groups == 0) + { + return; + } + static int SM = tensorrt_llm::common::getSMVersion(); + int sm_count = get_sm_count(); + int cluster_size = 1; + int cluster_num = tot_groups; + int threads_per_token = params.hidden_dim / kElemsPerAccess; + int block_size = threads_per_token; + int grid_size = (std::min(sm_count, cluster_num * cluster_size) / cluster_size) * cluster_size; + + cudaLaunchConfig_t cfg; + cfg.gridDim = grid_size; + cfg.blockDim = block_size; + cfg.dynamicSmemBytes = 0; + cfg.stream = params.stream; + + cudaLaunchAttribute attribute[2]; + attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attribute[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL() ? 1 : 0; + attribute[1].id = cudaLaunchAttributeClusterDimension; + attribute[1].val.clusterDim.x = cluster_size; + attribute[1].val.clusterDim.y = 1; + attribute[1].val.clusterDim.z = 1; + cfg.attrs = attribute; + cfg.numAttrs = SM >= 90 ? 2 : 0; + + bool trigger_completion_at_end = params.trigger_completion_at_end; + if (trigger_completion_at_end) + { + TLLM_CUDA_CHECK( + cudaLaunchKernelEx(&cfg, minimax_reduce_rms_kernel_lamport_float4, params)); + } + else + { + TLLM_CUDA_CHECK( + cudaLaunchKernelEx(&cfg, minimax_reduce_rms_kernel_lamport_float4, params)); + } +} + template void dispatch_dtype(MiniMaxReduceRMSParams const& params) { + int token_num = params.size / params.hidden_dim; + constexpr int kFloat4MinTokens = 32; + bool use_float4 = (token_num >= kFloat4MinTokens); + if (params.dtype == nvinfer1::DataType::kHALF) { - minimax_reduce_rms_kernel_launcher(params); + if (use_float4) + { + minimax_reduce_rms_kernel_launcher_float4(params); + } + else + { + minimax_reduce_rms_kernel_launcher(params); + } } else if (params.dtype == nvinfer1::DataType::kBF16) { - minimax_reduce_rms_kernel_launcher<__nv_bfloat16, NRanks>(params); + if (use_float4) + { + minimax_reduce_rms_kernel_launcher_float4<__nv_bfloat16, NRanks>(params); + } + else + { + minimax_reduce_rms_kernel_launcher<__nv_bfloat16, NRanks>(params); + } } else if (params.dtype == nvinfer1::DataType::kFLOAT) { - minimax_reduce_rms_kernel_launcher(params); + if (use_float4) + { + minimax_reduce_rms_kernel_launcher_float4(params); + } + else + { + minimax_reduce_rms_kernel_launcher(params); + } } else { diff --git a/tests/microbenchmarks/minimax_all_reduce.py b/tests/microbenchmarks/minimax_all_reduce.py new file mode 100644 index 000000000000..cfc3898efe8c --- /dev/null +++ b/tests/microbenchmarks/minimax_all_reduce.py @@ -0,0 +1,218 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from argparse import ArgumentParser +from itertools import product + +# isort: off +import torch +import pandas as pd + +# isort: on +try: + from cuda.bindings import runtime as cudart +except ImportError: + from cuda import cudart + +import tensorrt_llm as tllm +from tensorrt_llm import Mapping +from tensorrt_llm._torch.distributed import MiniMaxAllReduceRMS +from tensorrt_llm._utils import local_mpi_rank, local_mpi_size, nvtx_range +from tensorrt_llm.bindings.internal.runtime import delay_kernel +from tensorrt_llm.logger import logger +from tensorrt_llm.plugin.plugin import CustomAllReduceHelper + + +def profile_minimax_allreduce_rms( + mapping: Mapping, + op: MiniMaxAllReduceRMS, + enable_cudagraph: bool = False, + inner_loop: int = 200, + outer_loop: int = 10, + input_tensor=None, + norm_weight=None, + eps: float = 1e-5, +): + def func(loop_num=inner_loop): + out = None + for _ in range(loop_num): + out = op(input_tensor, norm_weight, eps) + return out + + start = [torch.cuda.Event(enable_timing=True) for _ in range(outer_loop)] + stop = [torch.cuda.Event(enable_timing=True) for _ in range(outer_loop)] + graph = torch.cuda.CUDAGraph() + + stream = torch.cuda.Stream() + with ( + torch.cuda.stream(stream), + nvtx_range(f"minimax_allreduce_rms: shape={input_tensor.size(0)}x{input_tensor.size(1)}"), + ): + func(loop_num=1) + + if enable_cudagraph: + for i in range(2): + func(loop_num=1) + with torch.cuda.graph(graph, stream=stream): + _ = func() + + delay_kernel(20000, stream) + + torch.cuda.synchronize() + torch.cuda.profiler.start() + + for i in range(outer_loop): + start[i].record(stream) + if enable_cudagraph: + graph.replay() + else: + _ = func() + stop[i].record(stream) + + torch.cuda.synchronize() + torch.cuda.profiler.stop() + runtimes = [start[i].elapsed_time(stop[i]) for i in range(outer_loop)] + median_ms = sorted(runtimes)[len(runtimes) // 2] / inner_loop + return median_ms + + +def minimax_allreduce_benchmark( + dtype: str = "bfloat16", + test_range: str = "256,256000000,10", + enable_cudagraph: bool = False, + explore_2d: bool = False, + save_csv: str = None, +): + world_size = tllm.mpi_world_size() + rank = tllm.mpi_rank() + local_rank = local_mpi_rank() + gpus_per_node = local_mpi_size() + + torch.cuda.set_device(local_rank) + cudart.cudaSetDevice(local_rank) + + mapping = Mapping(world_size, rank, gpus_per_node, tp_size=world_size) + logger.set_rank(mapping.rank) + + if world_size == 1: + raise RuntimeError("Benchmark must run with mpi_world_size > 1") + + torch_dtype = tllm._utils.str_dtype_to_torch(dtype) + + inner_loop = 200 + outer_loop = 10 + eps = 1e-5 + + shape_list = [] + if explore_2d: + num_tokens_list = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384] + hidden_size_list = [128, 256, 512, 1024, 2048, 4096, 8192] + for num_tokens, hidden_size in product(num_tokens_list, hidden_size_list): + shape_list.append((num_tokens, hidden_size)) + else: + min_size, max_size, ratio = [int(i) for i in test_range.split(",")] + size = min_size + hidden_size = min_size + num_tokens = 1 + while size < max_size: + size *= ratio + shape_list.append((num_tokens, hidden_size)) + if hidden_size * ratio > 4096: + num_tokens *= ratio + else: + hidden_size *= ratio + assert size == num_tokens * hidden_size + + op = MiniMaxAllReduceRMS(mapping=mapping) + max_workspace = CustomAllReduceHelper.max_workspace_size_auto( + mapping.tp_size, support_deterministic=False + ) + + df = pd.DataFrame() + for num_tokens, hidden_size in shape_list: + message_size_bytes = num_tokens * hidden_size * torch.finfo(torch_dtype).bits // 8 + if message_size_bytes > max_workspace: + continue + + input_tensor = torch.ones((num_tokens, hidden_size), dtype=torch_dtype, device="cuda") + norm_weight = torch.randn((hidden_size,), dtype=torch_dtype, device="cuda") + + median_ms = profile_minimax_allreduce_rms( + mapping=mapping, + op=op, + enable_cudagraph=enable_cudagraph, + inner_loop=inner_loop, + outer_loop=outer_loop, + input_tensor=input_tensor, + norm_weight=norm_weight, + eps=eps, + ) + + if mapping.rank == 0: + df = pd.concat( + [ + df, + pd.DataFrame( + { + "world_size": [mapping.world_size], + "dtype": [dtype], + "message_size_bytes": [message_size_bytes], + "num_tokens": [num_tokens], + "hidden_size": [hidden_size], + "time (us)": [median_ms * 1000], + } + ), + ] + ) + print( + f"num_tokens: {num_tokens}, hidden_size: {hidden_size}, " + f"time (us): {median_ms * 1000}" + ) + + if mapping.rank == 0: + pd.set_option("display.max_rows", None) + pd.set_option("display.max_columns", None) + pd.set_option("display.width", None) + pd.set_option("display.max_colwidth", None) + print(df) + + if mapping.rank == 0 and save_csv is not None: + df.to_csv(save_csv, index=False) + + return df + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--dtype", "-t", default="bfloat16") + parser.add_argument( + "--range", + "-r", + default="256,256000000,10", + help="min_size,max_size,multiplicative_ratio", + ) + parser.add_argument("--explore_2d", action="store_true", default=False) + parser.add_argument("--enable_cudagraph", action="store_true") + parser.add_argument("--save_csv", type=str, default=None) + + args = parser.parse_args() + + minimax_allreduce_benchmark( + args.dtype, + args.range, + args.enable_cudagraph, + args.explore_2d, + args.save_csv, + ) From 2eaa0e78076ae51540580710a500e9d11b1c35fc Mon Sep 17 00:00:00 2001 From: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> Date: Thu, 12 Mar 2026 17:43:48 +0800 Subject: [PATCH 10/20] test: add benchmark code, fix clean lamport buffer bug Signed-off-by: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> --- .../MiniMaxReduceRMSKernel.cu | 240 +++++++++++++++++- tests/microbenchmarks/minimax_all_reduce.py | 24 +- 2 files changed, 249 insertions(+), 15 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu index 0023c953147b..aa69cd7793e3 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu +++ b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu @@ -285,6 +285,7 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport_float4 float sum_variance[4] = {0.F, 0.F, 0.F, 0.F}; // Load 4 rows and compute partial sum of squares per row (sum_variance always float) +#pragma unroll for (int r = 0; r < 4; ++r) { int token_r = g * 4 + r; @@ -382,12 +383,243 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport_float4 } } - // Clear comm buffer (tot_groups float4s per rank) - for (int g = group_id; g < tot_groups; g += group_stride) + // Clear comm buffer: clear full size set by previous kernel (same as non-float4 kernel). + int clear_access = static_cast(comm.clear_size / kElemsPerAccess); + int clear_stride = group_stride * access_per_row; + for (int idx = group_id * access_per_row + threadIdx.x; idx < clear_access; idx += clear_stride) { - reinterpret_cast(comm.clear_buf)[g] = clear_vec; + reinterpret_cast(comm.clear_buf)[idx] = clear_vec; } - comm.update(params.size * NRanks); + comm.update( + tot_groups * 8 * NRanks); // this size is bf16 elem size, for each token, we use fp32 to store variance sum +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + if constexpr (TriggerCompletionAtEnd) + { + cudaTriggerProgrammaticLaunchCompletion(); + } +#endif +} + +/** + * Pipelined float4 variant: pre-compute first group's local sum, then in loop process + * "previous" group (allreduce + RMS) while computing "current" group; use fixed + * sum_prev/sum_curr and vals_prev/vals_curr with explicit copy (no iter%2) to avoid + * non-constant indexing and register spill. + */ +template +__global__ void __launch_bounds__(1024) + minimax_reduce_rms_kernel_lamport_float4_pipelined(MiniMaxReduceRMSParams params) +{ + int tot_tokens = params.size / params.hidden_dim; + int tot_groups = (tot_tokens + 3) / 4; + int access_per_row = params.hidden_dim / kElemsPerAccess; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + namespace cg = cooperative_groups; + cg::cluster_group cluster = cg::this_cluster(); + cg::grid_group grid = cg::this_grid(); + int group_id = grid.cluster_rank(); + int access_id_in_token = cluster.thread_rank(); + int group_stride = grid.num_clusters(); +#else + int group_id = blockIdx.x; + int access_id_in_token = threadIdx.x; + int group_stride = gridDim.x; +#endif + float4 clear_vec = get_neg_zero(); + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); + if constexpr (!TriggerCompletionAtEnd) + { + cudaTriggerProgrammaticLaunchCompletion(); + } +#endif + LamportComm comm(params.workspace, params.rank); + + __shared__ float sum_global[4]; + + int g_first = group_id; + if (g_first >= tot_groups) + { + return; + } + + // Fixed double-buffer variables (constant indexing, no spill) + alignas(16) DType vals_prev[4][kElemsPerAccess]; + alignas(16) DType vals_curr[4][kElemsPerAccess]; + float sum_prev[4] = {0.F, 0.F, 0.F, 0.F}; + float sum_curr[4] = {0.F, 0.F, 0.F, 0.F}; + + // Pre phase: compute first group into sum_prev, vals_prev + { + int g = g_first; +#pragma unroll + for (int r = 0; r < 4; ++r) + { + int token_r = g * 4 + r; + if (token_r < tot_tokens) + { + int idx_r = token_r * access_per_row + access_id_in_token; + *reinterpret_cast(&vals_prev[r][0]) + = reinterpret_cast(params.allreduce_in)[idx_r]; +#pragma unroll + for (int i = 0; i < kElemsPerAccess; ++i) + { + sum_prev[r] += static_cast(vals_prev[r][i]) * static_cast(vals_prev[r][i]); + } + } + else + { + *reinterpret_cast(&vals_prev[r][0]) = make_float4(0.F, 0.F, 0.F, 0.F); + sum_prev[r] = 0.F; + } + } + tensorrt_llm::common::blockReduceSumV2(sum_prev); +#pragma unroll + for (int r = 0; r < 4; ++r) + { + if (is_neg_zero(sum_prev[r])) + { + sum_prev[r] = 0.F; + } + } + } + + for (int g = g_first; g < tot_groups; g += group_stride) + { + // 1. Process group g: allreduce sum_prev, broadcast, RMS vals_prev and store + if (threadIdx.x == 0) + { + float4 sum4; + sum4.x = sum_prev[0]; + sum4.y = sum_prev[1]; + sum4.z = sum_prev[2]; + sum4.w = sum_prev[3]; + for (int r = 0; r < NRanks; ++r) + { + reinterpret_cast(comm.data_bufs[r])[(params.rank * tot_groups) + g] = sum4; + } + } + + bool done = false; + float4 vars_all_ranks[NRanks]; + while (!done) + { + done = true; +#pragma unroll + for (int r = 0; r < NRanks; ++r) + { + vars_all_ranks[r] + = ld_global_volatile(&reinterpret_cast(comm.data_bufs[params.rank])[(r * tot_groups) + g]); + done &= !is_neg_zero(vars_all_ranks[r]); + } + } + + if (threadIdx.x == 0) + { + sum_global[0] = 0.F; + sum_global[1] = 0.F; + sum_global[2] = 0.F; + sum_global[3] = 0.F; +#pragma unroll + for (int r = 0; r < NRanks; ++r) + { + sum_global[0] += vars_all_ranks[r].x; + sum_global[1] += vars_all_ranks[r].y; + sum_global[2] += vars_all_ranks[r].z; + sum_global[3] += vars_all_ranks[r].w; + } + } + __syncthreads(); + + __nv_bfloat16 norm_weight[kElemsPerAccess]; + *reinterpret_cast::norm_weight_type*>(norm_weight) + = reinterpret_cast::norm_weight_type const*>( + params.rms_gamma)[access_id_in_token]; + +#pragma unroll + for (int r = 0; r < 4; ++r) + { + int token_r = g * 4 + r; + if (token_r >= tot_tokens) + { + continue; + } + float scale = rsqrtf((sum_global[r] / static_cast(params.hidden_dim) / NRanks) + params.rms_eps); +#pragma unroll + for (int i = 0; i < kElemsPerAccess; ++i) + { + vals_prev[r][i] = static_cast( + static_cast(vals_prev[r][i]) * scale * static_cast(norm_weight[i])); + } + int idx_out = token_r * access_per_row + access_id_in_token; + reinterpret_cast(params.rms_norm_out)[idx_out] = *reinterpret_cast(&vals_prev[r][0]); + } + + // 2. If there is a next group: compute into sum_curr, vals_curr then copy curr -> prev + int g_next = g + group_stride; + if (g_next < tot_groups) + { + sum_curr[0] = 0.F; + sum_curr[1] = 0.F; + sum_curr[2] = 0.F; + sum_curr[3] = 0.F; +#pragma unroll + for (int r = 0; r < 4; ++r) + { + int token_r = g_next * 4 + r; + if (token_r < tot_tokens) + { + int idx_r = token_r * access_per_row + access_id_in_token; + *reinterpret_cast(&vals_curr[r][0]) + = reinterpret_cast(params.allreduce_in)[idx_r]; +#pragma unroll + for (int i = 0; i < kElemsPerAccess; ++i) + { + sum_curr[r] += static_cast(vals_curr[r][i]) * static_cast(vals_curr[r][i]); + } + } + else + { + *reinterpret_cast(&vals_curr[r][0]) = make_float4(0.F, 0.F, 0.F, 0.F); + sum_curr[r] = 0.F; + } + } + tensorrt_llm::common::blockReduceSumV2(sum_curr); +#pragma unroll + for (int r = 0; r < 4; ++r) + { + if (is_neg_zero(sum_curr[r])) + { + sum_curr[r] = 0.F; + } + } + + // Explicit copy curr -> prev (constant indexing) + sum_prev[0] = sum_curr[0]; + sum_prev[1] = sum_curr[1]; + sum_prev[2] = sum_curr[2]; + sum_prev[3] = sum_curr[3]; +#pragma unroll + for (int r = 0; r < 4; ++r) + { +#pragma unroll + for (int i = 0; i < kElemsPerAccess; ++i) + { + vals_prev[r][i] = vals_curr[r][i]; + } + } + } + } + + // Clear comm buffer: clear full size set by previous kernel (same as non-float4 kernel). + int clear_access = static_cast(comm.clear_size / kElemsPerAccess); + int clear_stride = group_stride * access_per_row; + for (int idx = group_id * access_per_row + threadIdx.x; idx < clear_access; idx += clear_stride) + { + reinterpret_cast(comm.clear_buf)[idx] = clear_vec; + } + comm.update(tot_groups * 8 * NRanks); #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) if constexpr (TriggerCompletionAtEnd) { diff --git a/tests/microbenchmarks/minimax_all_reduce.py b/tests/microbenchmarks/minimax_all_reduce.py index cfc3898efe8c..808cab4e9e3c 100644 --- a/tests/microbenchmarks/minimax_all_reduce.py +++ b/tests/microbenchmarks/minimax_all_reduce.py @@ -34,6 +34,9 @@ from tensorrt_llm.logger import logger from tensorrt_llm.plugin.plugin import CustomAllReduceHelper +# MiniMax all-reduce only uses D (hidden_size) 128 and 1536 in practice. +ALLOWED_HIDDEN_SIZES = (128, 1536) + def profile_minimax_allreduce_rms( mapping: Mapping, @@ -118,22 +121,21 @@ def minimax_allreduce_benchmark( shape_list = [] if explore_2d: num_tokens_list = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384] - hidden_size_list = [128, 256, 512, 1024, 2048, 4096, 8192] + hidden_size_list = list(ALLOWED_HIDDEN_SIZES) for num_tokens, hidden_size in product(num_tokens_list, hidden_size_list): shape_list.append((num_tokens, hidden_size)) else: min_size, max_size, ratio = [int(i) for i in test_range.split(",")] - size = min_size - hidden_size = min_size - num_tokens = 1 - while size < max_size: - size *= ratio - shape_list.append((num_tokens, hidden_size)) - if hidden_size * ratio > 4096: + for hidden_size in ALLOWED_HIDDEN_SIZES: + n_min = max(1, (min_size + hidden_size - 1) // hidden_size) + n_max = max_size // hidden_size + num_tokens = n_min + while num_tokens <= n_max: + shape_list.append((num_tokens, hidden_size)) num_tokens *= ratio - else: - hidden_size *= ratio - assert size == num_tokens * hidden_size + num_tokens = max(num_tokens, 1) + # Only test D (hidden_size) = 128 and 1536 (no-op when explore_2d already uses them) + shape_list = [(n, d) for n, d in shape_list if d in ALLOWED_HIDDEN_SIZES] op = MiniMaxAllReduceRMS(mapping=mapping) max_workspace = CustomAllReduceHelper.max_workspace_size_auto( From 05885e705dcb974e9af10c20bb85e90d6fc26843 Mon Sep 17 00:00:00 2001 From: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> Date: Thu, 12 Mar 2026 18:39:50 +0800 Subject: [PATCH 11/20] feature: fuse q and k rms norm kernel Signed-off-by: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> --- .../MiniMaxReduceRMSKernel.cu | 386 +++++++----------- .../MiniMaxReduceRMSKernel.h | 15 +- cpp/tensorrt_llm/thop/allreduceOp.cpp | 52 ++- tensorrt_llm/_torch/distributed/ops.py | 17 + .../_torch/models/modeling_minimaxm2.py | 27 +- tests/microbenchmarks/minimax_all_reduce.py | 115 ++++++ 6 files changed, 350 insertions(+), 262 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu index aa69cd7793e3..940156fad021 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu +++ b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu @@ -116,7 +116,7 @@ public: #endif access_id = token_id * params.hidden_dim / kElemsPerAccess + access_id_in_token; access_stride = token_stride * params.hidden_dim / kElemsPerAccess; - tot_access = params.size / kElemsPerAccess; + tot_access = params.size_q / kElemsPerAccess; } int token_id; @@ -147,7 +147,7 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport(MiniMa int access_id = index_helper.access_id; int access_stride = index_helper.access_stride; int tot_access = index_helper.tot_access; - int tot_tokens = params.size / params.hidden_dim; + int tot_tokens = params.size_q / params.hidden_dim; float4 clear_vec = get_neg_zero(); // FusedOp fused_op(params, access_id, access_id_in_token); @@ -236,7 +236,7 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport(MiniMa // Clear comm buffer that previous kernel used reinterpret_cast(comm.clear_buf)[idx] = clear_vec; } - comm.update(params.size * NRanks); + comm.update(params.size_q * NRanks); #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) if constexpr (TriggerCompletionAtEnd) { @@ -249,13 +249,15 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport(MiniMa * Float4 variant: process 4 rows at once, allreduce variance sums as float4 for better memory coalescing. * sum_variance is always float; applies to all DTypes (half, bf16, float). * When tot_tokens % 4 != 0, the last group pads rows with zeros; padded rows are not written to rms_norm_out. + * IsQK: when true, process Q+K in one loop with doubled comm buffer; when false, single-matrix (Q only). */ -template +template __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport_float4(MiniMaxReduceRMSParams params) { - int tot_tokens = params.size / params.hidden_dim; + int tot_tokens = params.size_q / params.hidden_dim; int tot_groups = (tot_tokens + 3) / 4; // ceiling: last group may have 1-3 valid rows - int access_per_row = params.hidden_dim / kElemsPerAccess; + int access_per_row_q = params.hidden_dim / kElemsPerAccess; + int access_per_row_k = IsQK ? (params.hidden_dim_k / kElemsPerAccess) : 0; #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) namespace cg = cooperative_groups; cg::cluster_group cluster = cg::this_cluster(); @@ -283,15 +285,16 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport_float4 { alignas(16) DType vals[4][kElemsPerAccess]; float sum_variance[4] = {0.F, 0.F, 0.F, 0.F}; + float sum_variance_k[4] = {0.F, 0.F, 0.F, 0.F}; - // Load 4 rows and compute partial sum of squares per row (sum_variance always float) + // Load 4 rows of Q and compute partial sum of squares per row #pragma unroll for (int r = 0; r < 4; ++r) { int token_r = g * 4 + r; if (token_r < tot_tokens) { - int idx_r = token_r * access_per_row + access_id_in_token; + int idx_r = token_r * access_per_row_q + access_id_in_token; *reinterpret_cast(&vals[r][0]) = reinterpret_cast(params.allreduce_in)[idx_r]; #pragma unroll for (int i = 0; i < kElemsPerAccess; ++i) @@ -306,7 +309,32 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport_float4 } } + // Load 4 rows of K only when this thread is in K column range (access_id_in_token < access_per_row_k) + if constexpr (IsQK) + { +#pragma unroll + for (int r = 0; r < 4; ++r) + { + int token_r = g * 4 + r; + if (token_r < tot_tokens && access_id_in_token < access_per_row_k) + { + int idx_r = token_r * access_per_row_k + access_id_in_token; + alignas(16) DType vals_k[kElemsPerAccess]; + *reinterpret_cast(vals_k) = reinterpret_cast(params.allreduce_in_k)[idx_r]; +#pragma unroll + for (int i = 0; i < kElemsPerAccess; ++i) + { + sum_variance_k[r] += static_cast(vals_k[i]) * static_cast(vals_k[i]); + } + } + } + } + tensorrt_llm::common::blockReduceSumV2(sum_variance); + if constexpr (IsQK) + { + tensorrt_llm::common::blockReduceSumV2(sum_variance_k); + } #pragma unroll for (int r = 0; r < 4; ++r) { @@ -314,9 +342,16 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport_float4 { sum_variance[r] = 0.F; } + if constexpr (IsQK) + { + if (is_neg_zero(sum_variance_k[r])) + { + sum_variance_k[r] = 0.F; + } + } } - // Allreduce: write float4, volatile read float4 from each rank, component-wise sum + // Allreduce: write float4(s) to comm if (threadIdx.x == 0) { float4 sum4; @@ -326,10 +361,29 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport_float4 sum4.w = sum_variance[3]; for (int r = 0; r < NRanks; ++r) { - reinterpret_cast(comm.data_bufs[r])[(params.rank * tot_groups) + g] = sum4; + if constexpr (IsQK) + { + reinterpret_cast(comm.data_bufs[r])[(params.rank * 2 * tot_groups) + 2 * g] = sum4; + } + else + { + reinterpret_cast(comm.data_bufs[r])[(params.rank * tot_groups) + g] = sum4; + } + } + if constexpr (IsQK) + { + sum4.x = sum_variance_k[0]; + sum4.y = sum_variance_k[1]; + sum4.z = sum_variance_k[2]; + sum4.w = sum_variance_k[3]; + for (int r = 0; r < NRanks; ++r) + { + reinterpret_cast(comm.data_bufs[r])[(params.rank * 2 * tot_groups) + 2 * g + 1] = sum4; + } } } + // Read Q from buffer first, sum, then RMS and store Q; then read K, sum, RMS and store K bool done = false; float4 vars_all_ranks[NRanks]; while (!done) @@ -338,12 +392,19 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport_float4 #pragma unroll for (int r = 0; r < NRanks; ++r) { - vars_all_ranks[r] - = ld_global_volatile(&reinterpret_cast(comm.data_bufs[params.rank])[(r * tot_groups) + g]); + if constexpr (IsQK) + { + vars_all_ranks[r] = ld_global_volatile( + &reinterpret_cast(comm.data_bufs[params.rank])[(r * 2 * tot_groups) + 2 * g]); + } + else + { + vars_all_ranks[r] = ld_global_volatile( + &reinterpret_cast(comm.data_bufs[params.rank])[(r * tot_groups) + g]); + } done &= !is_neg_zero(vars_all_ranks[r]); } } - sum_variance[0] = 0.F; sum_variance[1] = 0.F; sum_variance[2] = 0.F; @@ -357,13 +418,13 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport_float4 sum_variance[3] += vars_all_ranks[r].w; } - // Load norm weight (same column for all 4 rows) + // Load norm weight for Q (same column for all 4 rows) __nv_bfloat16 norm_weight[kElemsPerAccess]; *reinterpret_cast::norm_weight_type*>(norm_weight) = reinterpret_cast::norm_weight_type const*>( params.rms_gamma)[access_id_in_token]; - // RMS norm and store 4 rows (skip write for padded rows) + // RMS norm and store 4 rows of Q (skip write for padded rows) for (int r = 0; r < 4; ++r) { int token_r = g * 4 + r; @@ -378,235 +439,64 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport_float4 vals[r][i] = static_cast(static_cast(vals[r][i]) * scale * static_cast(norm_weight[i])); } - int idx_out = token_r * access_per_row + access_id_in_token; + int idx_out = token_r * access_per_row_q + access_id_in_token; reinterpret_cast(params.rms_norm_out)[idx_out] = *reinterpret_cast(&vals[r][0]); } - } - - // Clear comm buffer: clear full size set by previous kernel (same as non-float4 kernel). - int clear_access = static_cast(comm.clear_size / kElemsPerAccess); - int clear_stride = group_stride * access_per_row; - for (int idx = group_id * access_per_row + threadIdx.x; idx < clear_access; idx += clear_stride) - { - reinterpret_cast(comm.clear_buf)[idx] = clear_vec; - } - comm.update( - tot_groups * 8 * NRanks); // this size is bf16 elem size, for each token, we use fp32 to store variance sum -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - if constexpr (TriggerCompletionAtEnd) - { - cudaTriggerProgrammaticLaunchCompletion(); - } -#endif -} -/** - * Pipelined float4 variant: pre-compute first group's local sum, then in loop process - * "previous" group (allreduce + RMS) while computing "current" group; use fixed - * sum_prev/sum_curr and vals_prev/vals_curr with explicit copy (no iter%2) to avoid - * non-constant indexing and register spill. - */ -template -__global__ void __launch_bounds__(1024) - minimax_reduce_rms_kernel_lamport_float4_pipelined(MiniMaxReduceRMSParams params) -{ - int tot_tokens = params.size / params.hidden_dim; - int tot_groups = (tot_tokens + 3) / 4; - int access_per_row = params.hidden_dim / kElemsPerAccess; -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - namespace cg = cooperative_groups; - cg::cluster_group cluster = cg::this_cluster(); - cg::grid_group grid = cg::this_grid(); - int group_id = grid.cluster_rank(); - int access_id_in_token = cluster.thread_rank(); - int group_stride = grid.num_clusters(); -#else - int group_id = blockIdx.x; - int access_id_in_token = threadIdx.x; - int group_stride = gridDim.x; -#endif - float4 clear_vec = get_neg_zero(); - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - cudaGridDependencySynchronize(); - if constexpr (!TriggerCompletionAtEnd) - { - cudaTriggerProgrammaticLaunchCompletion(); - } -#endif - LamportComm comm(params.workspace, params.rank); - - __shared__ float sum_global[4]; - - int g_first = group_id; - if (g_first >= tot_groups) - { - return; - } - - // Fixed double-buffer variables (constant indexing, no spill) - alignas(16) DType vals_prev[4][kElemsPerAccess]; - alignas(16) DType vals_curr[4][kElemsPerAccess]; - float sum_prev[4] = {0.F, 0.F, 0.F, 0.F}; - float sum_curr[4] = {0.F, 0.F, 0.F, 0.F}; - - // Pre phase: compute first group into sum_prev, vals_prev - { - int g = g_first; -#pragma unroll - for (int r = 0; r < 4; ++r) + // Then read K from buffer, sum, RMS and store K (only when IsQK and access_id_in_token < access_per_row_k) + if constexpr (IsQK) { - int token_r = g * 4 + r; - if (token_r < tot_tokens) + float4 vars_k_all_ranks[NRanks]; + done = false; + while (!done) { - int idx_r = token_r * access_per_row + access_id_in_token; - *reinterpret_cast(&vals_prev[r][0]) - = reinterpret_cast(params.allreduce_in)[idx_r]; + done = true; #pragma unroll - for (int i = 0; i < kElemsPerAccess; ++i) + for (int r = 0; r < NRanks; ++r) { - sum_prev[r] += static_cast(vals_prev[r][i]) * static_cast(vals_prev[r][i]); + vars_k_all_ranks[r] = ld_global_volatile( + &reinterpret_cast(comm.data_bufs[params.rank])[(r * 2 * tot_groups) + 2 * g + 1]); + done &= !is_neg_zero(vars_k_all_ranks[r]); } } - else - { - *reinterpret_cast(&vals_prev[r][0]) = make_float4(0.F, 0.F, 0.F, 0.F); - sum_prev[r] = 0.F; - } - } - tensorrt_llm::common::blockReduceSumV2(sum_prev); + sum_variance_k[0] = 0.F; + sum_variance_k[1] = 0.F; + sum_variance_k[2] = 0.F; + sum_variance_k[3] = 0.F; #pragma unroll - for (int r = 0; r < 4; ++r) - { - if (is_neg_zero(sum_prev[r])) - { - sum_prev[r] = 0.F; - } - } - } - - for (int g = g_first; g < tot_groups; g += group_stride) - { - // 1. Process group g: allreduce sum_prev, broadcast, RMS vals_prev and store - if (threadIdx.x == 0) - { - float4 sum4; - sum4.x = sum_prev[0]; - sum4.y = sum_prev[1]; - sum4.z = sum_prev[2]; - sum4.w = sum_prev[3]; for (int r = 0; r < NRanks; ++r) { - reinterpret_cast(comm.data_bufs[r])[(params.rank * tot_groups) + g] = sum4; + sum_variance_k[0] += vars_k_all_ranks[r].x; + sum_variance_k[1] += vars_k_all_ranks[r].y; + sum_variance_k[2] += vars_k_all_ranks[r].z; + sum_variance_k[3] += vars_k_all_ranks[r].w; } - } - bool done = false; - float4 vars_all_ranks[NRanks]; - while (!done) - { - done = true; -#pragma unroll - for (int r = 0; r < NRanks; ++r) - { - vars_all_ranks[r] - = ld_global_volatile(&reinterpret_cast(comm.data_bufs[params.rank])[(r * tot_groups) + g]); - done &= !is_neg_zero(vars_all_ranks[r]); - } - } - - if (threadIdx.x == 0) - { - sum_global[0] = 0.F; - sum_global[1] = 0.F; - sum_global[2] = 0.F; - sum_global[3] = 0.F; -#pragma unroll - for (int r = 0; r < NRanks; ++r) - { - sum_global[0] += vars_all_ranks[r].x; - sum_global[1] += vars_all_ranks[r].y; - sum_global[2] += vars_all_ranks[r].z; - sum_global[3] += vars_all_ranks[r].w; - } - } - __syncthreads(); - - __nv_bfloat16 norm_weight[kElemsPerAccess]; - *reinterpret_cast::norm_weight_type*>(norm_weight) - = reinterpret_cast::norm_weight_type const*>( - params.rms_gamma)[access_id_in_token]; - -#pragma unroll - for (int r = 0; r < 4; ++r) - { - int token_r = g * 4 + r; - if (token_r >= tot_tokens) + if (access_id_in_token < access_per_row_k) { - continue; - } - float scale = rsqrtf((sum_global[r] / static_cast(params.hidden_dim) / NRanks) + params.rms_eps); -#pragma unroll - for (int i = 0; i < kElemsPerAccess; ++i) - { - vals_prev[r][i] = static_cast( - static_cast(vals_prev[r][i]) * scale * static_cast(norm_weight[i])); - } - int idx_out = token_r * access_per_row + access_id_in_token; - reinterpret_cast(params.rms_norm_out)[idx_out] = *reinterpret_cast(&vals_prev[r][0]); - } - - // 2. If there is a next group: compute into sum_curr, vals_curr then copy curr -> prev - int g_next = g + group_stride; - if (g_next < tot_groups) - { - sum_curr[0] = 0.F; - sum_curr[1] = 0.F; - sum_curr[2] = 0.F; - sum_curr[3] = 0.F; -#pragma unroll - for (int r = 0; r < 4; ++r) - { - int token_r = g_next * 4 + r; - if (token_r < tot_tokens) + __nv_bfloat16 norm_weight_k[kElemsPerAccess]; + *reinterpret_cast::norm_weight_type*>(norm_weight_k) + = reinterpret_cast::norm_weight_type const*>( + params.rms_gamma_k)[access_id_in_token]; + for (int r = 0; r < 4; ++r) { - int idx_r = token_r * access_per_row + access_id_in_token; - *reinterpret_cast(&vals_curr[r][0]) - = reinterpret_cast(params.allreduce_in)[idx_r]; + int token_r = g * 4 + r; + if (token_r >= tot_tokens) + { + continue; + } + alignas(16) DType vals_k[kElemsPerAccess]; + int idx_r = token_r * access_per_row_k + access_id_in_token; + *reinterpret_cast(vals_k) = reinterpret_cast(params.allreduce_in_k)[idx_r]; + float scale_k = rsqrtf( + (sum_variance_k[r] / static_cast(params.hidden_dim_k) / NRanks) + params.rms_eps); #pragma unroll for (int i = 0; i < kElemsPerAccess; ++i) { - sum_curr[r] += static_cast(vals_curr[r][i]) * static_cast(vals_curr[r][i]); + vals_k[i] = static_cast( + static_cast(vals_k[i]) * scale_k * static_cast(norm_weight_k[i])); } - } - else - { - *reinterpret_cast(&vals_curr[r][0]) = make_float4(0.F, 0.F, 0.F, 0.F); - sum_curr[r] = 0.F; - } - } - tensorrt_llm::common::blockReduceSumV2(sum_curr); -#pragma unroll - for (int r = 0; r < 4; ++r) - { - if (is_neg_zero(sum_curr[r])) - { - sum_curr[r] = 0.F; - } - } - - // Explicit copy curr -> prev (constant indexing) - sum_prev[0] = sum_curr[0]; - sum_prev[1] = sum_curr[1]; - sum_prev[2] = sum_curr[2]; - sum_prev[3] = sum_curr[3]; -#pragma unroll - for (int r = 0; r < 4; ++r) - { -#pragma unroll - for (int i = 0; i < kElemsPerAccess; ++i) - { - vals_prev[r][i] = vals_curr[r][i]; + reinterpret_cast(params.rms_norm_out_k)[idx_r] = *reinterpret_cast(vals_k); } } } @@ -614,12 +504,12 @@ __global__ void __launch_bounds__(1024) // Clear comm buffer: clear full size set by previous kernel (same as non-float4 kernel). int clear_access = static_cast(comm.clear_size / kElemsPerAccess); - int clear_stride = group_stride * access_per_row; - for (int idx = group_id * access_per_row + threadIdx.x; idx < clear_access; idx += clear_stride) + int clear_stride = group_stride * access_per_row_q; + for (int idx = group_id * access_per_row_q + threadIdx.x; idx < clear_access; idx += clear_stride) { reinterpret_cast(comm.clear_buf)[idx] = clear_vec; } - comm.update(tot_groups * 8 * NRanks); + comm.update(IsQK ? (2 * tot_groups * 8 * NRanks) : (tot_groups * 8 * NRanks)); #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) if constexpr (TriggerCompletionAtEnd) { @@ -645,10 +535,10 @@ int get_sm_count() template void minimax_reduce_rms_kernel_launcher(MiniMaxReduceRMSParams const& params) { - TLLM_CHECK(params.size % params.hidden_dim == 0); + TLLM_CHECK(params.size_q % params.hidden_dim == 0); TLLM_CHECK(params.hidden_dim % kElemsPerAccess == 0); static int SM = tensorrt_llm::common::getSMVersion(); - int token_num = params.size / params.hidden_dim; + int token_num = params.size_q / params.hidden_dim; // for current problem size, we only need one cluster int sm_count = get_sm_count(); int cluster_size = 1; @@ -687,9 +577,16 @@ void minimax_reduce_rms_kernel_launcher(MiniMaxReduceRMSParams const& params) template void minimax_reduce_rms_kernel_launcher_float4(MiniMaxReduceRMSParams const& params) { - TLLM_CHECK(params.size % params.hidden_dim == 0); + TLLM_CHECK(params.size_q % params.hidden_dim == 0); TLLM_CHECK(params.hidden_dim % kElemsPerAccess == 0); - int token_num = params.size / params.hidden_dim; + if (params.allreduce_in_k != nullptr) + { + TLLM_CHECK(params.hidden_dim >= params.hidden_dim_k); + TLLM_CHECK(params.size_k % params.hidden_dim_k == 0); + TLLM_CHECK(params.hidden_dim_k % kElemsPerAccess == 0); + TLLM_CHECK(params.size_q / params.hidden_dim == params.size_k / params.hidden_dim_k); + } + int token_num = params.size_q / params.hidden_dim; int tot_groups = (token_num + 3) / 4; // ceiling if (tot_groups == 0) { @@ -720,22 +617,39 @@ void minimax_reduce_rms_kernel_launcher_float4(MiniMaxReduceRMSParams const& par cfg.numAttrs = SM >= 90 ? 2 : 0; bool trigger_completion_at_end = params.trigger_completion_at_end; + bool is_qk = (params.allreduce_in_k != nullptr); if (trigger_completion_at_end) { - TLLM_CUDA_CHECK( - cudaLaunchKernelEx(&cfg, minimax_reduce_rms_kernel_lamport_float4, params)); + if (is_qk) + { + TLLM_CUDA_CHECK( + cudaLaunchKernelEx(&cfg, minimax_reduce_rms_kernel_lamport_float4, params)); + } + else + { + TLLM_CUDA_CHECK( + cudaLaunchKernelEx(&cfg, minimax_reduce_rms_kernel_lamport_float4, params)); + } } else { - TLLM_CUDA_CHECK( - cudaLaunchKernelEx(&cfg, minimax_reduce_rms_kernel_lamport_float4, params)); + if (is_qk) + { + TLLM_CUDA_CHECK( + cudaLaunchKernelEx(&cfg, minimax_reduce_rms_kernel_lamport_float4, params)); + } + else + { + TLLM_CUDA_CHECK(cudaLaunchKernelEx( + &cfg, minimax_reduce_rms_kernel_lamport_float4, params)); + } } } template void dispatch_dtype(MiniMaxReduceRMSParams const& params) { - int token_num = params.size / params.hidden_dim; + int token_num = params.size_q / params.hidden_dim; constexpr int kFloat4MinTokens = 32; bool use_float4 = (token_num >= kFloat4MinTokens); diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.h b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.h index 224db393d01f..1b0ddd7feef6 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.h +++ b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.h @@ -45,12 +45,17 @@ struct MiniMaxReduceRMSParams int nranks{}; int rank{}; nvinfer1::DataType dtype; - int size{}; - int hidden_dim{}; + int size_q{}; // numel of Q (num_token * head_dim_q) + int hidden_dim{}; // head_dim_q + int size_k{}; // numel of K (num_token * head_dim_k) + int hidden_dim_k{}; // head_dim_k; must have head_dim_q >= head_dim_k void** workspace{}; - void* allreduce_in{}; - void* rms_norm_out{}; - void* rms_gamma{}; + void* allreduce_in{}; // Q input + void* rms_norm_out{}; // Q output + void* rms_gamma{}; // Q norm weight + void* allreduce_in_k{}; // K input (nullptr for single-matrix path) + void* rms_norm_out_k{}; // K output + void* rms_gamma_k{}; // K norm weight float rms_eps{}; cudaStream_t stream{}; bool trigger_completion_at_end = true; diff --git a/cpp/tensorrt_llm/thop/allreduceOp.cpp b/cpp/tensorrt_llm/thop/allreduceOp.cpp index 87b61def8b32..ddedfb6b30cd 100644 --- a/cpp/tensorrt_llm/thop/allreduceOp.cpp +++ b/cpp/tensorrt_llm/thop/allreduceOp.cpp @@ -1832,7 +1832,7 @@ torch::Tensor minimax_allreduce_rms(torch::Tensor const& input, torch::Tensor co allreduce_params.nranks = static_cast(nranks); allreduce_params.rank = static_cast(rank); allreduce_params.dtype = tensorrt_llm::runtime::TorchUtils::dataType(input.scalar_type()); - allreduce_params.size = static_cast(input.numel()); + allreduce_params.size_q = static_cast(input.numel()); allreduce_params.hidden_dim = static_cast(input.size(-1)); allreduce_params.workspace = reinterpret_cast(workspace.mutable_data_ptr()); allreduce_params.allreduce_in = input.data_ptr(); @@ -1849,6 +1849,44 @@ torch::Tensor minimax_allreduce_rms(torch::Tensor const& input, torch::Tensor co return rms_norm_out; } +std::vector minimax_allreduce_rms_qk(torch::Tensor const& q, torch::Tensor const& k, + torch::Tensor const& norm_weight_q, torch::Tensor const& norm_weight_k, torch::Tensor workspace, int64_t const rank, + int64_t const nranks, double const eps, bool const trigger_completion_at_end_) +{ + TORCH_CHECK(q.scalar_type() == k.scalar_type(), "minimax_allreduce_rms_qk: q and k must have same dtype"); + TORCH_CHECK(q.dim() == 2 && k.dim() == 2, "minimax_allreduce_rms_qk: q and k must be 2D"); + TORCH_CHECK(q.size(0) == k.size(0), "minimax_allreduce_rms_qk: q and k must have same num_token"); + int64_t head_dim_q = q.size(-1); + int64_t head_dim_k = k.size(-1); + TORCH_CHECK(head_dim_q >= head_dim_k, "minimax_allreduce_rms_qk: head_dim_q must be >= head_dim_k"); + + auto params = tensorrt_llm::kernels::minimax_ar::MiniMaxReduceRMSParams(); + params.nranks = static_cast(nranks); + params.rank = static_cast(rank); + params.dtype = tensorrt_llm::runtime::TorchUtils::dataType(q.scalar_type()); + params.size_q = static_cast(q.numel()); + params.hidden_dim = static_cast(head_dim_q); + params.size_k = static_cast(k.numel()); + params.hidden_dim_k = static_cast(head_dim_k); + params.workspace = reinterpret_cast(workspace.mutable_data_ptr()); + params.allreduce_in = q.data_ptr(); + params.rms_gamma = norm_weight_q.data_ptr(); + params.allreduce_in_k = k.data_ptr(); + params.rms_gamma_k = norm_weight_k.data_ptr(); + params.rms_eps = static_cast(eps); + params.stream = at::cuda::getCurrentCUDAStream(q.get_device()); + params.trigger_completion_at_end = trigger_completion_at_end_; + + torch::Tensor rms_norm_out_q = torch::empty_like(q); + torch::Tensor rms_norm_out_k = torch::empty_like(k); + params.rms_norm_out = rms_norm_out_q.mutable_data_ptr(); + params.rms_norm_out_k = rms_norm_out_k.mutable_data_ptr(); + + tensorrt_llm::kernels::minimax_ar::minimax_reduce_rms_op(params); + + return {rms_norm_out_q, rms_norm_out_k}; +} + } // namespace torch_ext TRTLLM_NAMESPACE_END @@ -1922,6 +1960,17 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m) "int nranks," "float eps," "bool trigger_completion_at_end) -> Tensor"); + m.def( + "minimax_allreduce_rms_qk(" + "Tensor q," + "Tensor k," + "Tensor norm_weight_q," + "Tensor norm_weight_k," + "Tensor workspace," + "int rank," + "int nranks," + "float eps," + "bool trigger_completion_at_end) -> Tensor[]"); } TORCH_LIBRARY_IMPL(trtllm, CUDA, m) @@ -1933,6 +1982,7 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m) m.impl("moe_finalize_allreduce", &tensorrt_llm::torch_ext::moe_finalize_allreduce); m.impl("preallocate_nccl_window_buffer", &tensorrt_llm::torch_ext::preallocateNCCLWindowBuffer); m.impl("minimax_allreduce_rms", &tensorrt_llm::torch_ext::minimax_allreduce_rms); + m.impl("minimax_allreduce_rms_qk", &tensorrt_llm::torch_ext::minimax_allreduce_rms_qk); } TORCH_LIBRARY_IMPL(trtllm, CPU, m) diff --git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py index 98cc3a4fd367..b7d84fc1e03d 100644 --- a/tensorrt_llm/_torch/distributed/ops.py +++ b/tensorrt_llm/_torch/distributed/ops.py @@ -1194,3 +1194,20 @@ def forward(self, input: torch.Tensor, rms_weights: torch.Tensor, self.mapping.tp_rank, self.mapping.tp_size, eps, False) + + def forward_qk(self, q: torch.Tensor, k: torch.Tensor, + rms_weights_q: torch.Tensor, rms_weights_k: torch.Tensor, + eps: float): + """Fused Q+K RMS norm with allreduce. Returns (q_out, k_out).""" + out_list = torch.ops.trtllm.minimax_allreduce_rms_qk( + q, + k, + rms_weights_q, + rms_weights_k, + self.workspace, + self.mapping.tp_rank, + self.mapping.tp_size, + eps, + False, + ) + return (out_list[0], out_list[1]) diff --git a/tensorrt_llm/_torch/models/modeling_minimaxm2.py b/tensorrt_llm/_torch/models/modeling_minimaxm2.py index 33bd0aee07c0..b496febacdfb 100644 --- a/tensorrt_llm/_torch/models/modeling_minimaxm2.py +++ b/tensorrt_llm/_torch/models/modeling_minimaxm2.py @@ -217,26 +217,13 @@ def __init__( ) def apply_qk_norm(self, q, k): - q = self.q_norm(q) - k = self.k_norm(k) - # if self.qkv_proj.mapping.tp_size > 1: - # # collect q and k from all gpus - # from ..distributed import allgather - - # temp_q = allgather(q, self.qkv_proj.mapping) - # temp_k = allgather(k, self.qkv_proj.mapping) - # temp_q = self.q_norm(temp_q) - # temp_k = self.k_norm(temp_k) - # q = temp_q.reshape(-1, self.tp_size, self.q_size)[:, self.tp_rank, :].reshape( - # -1, self.q_size - # ) - # k = temp_k.reshape(-1, self.tp_size, self.kv_size)[:, self.tp_rank, :].reshape( - # -1, self.kv_size - # ) - # else: - # q = self.q_norm(q) - # k = self.k_norm(k) - + if self.qkv_proj.mapping.tp_size > 1: + q, k = self.q_norm.minimax_all_reduce_rms.forward_qk( + q, k, self.q_norm.weight, self.k_norm.weight, self.q_norm.eps + ) + else: + q = self.q_norm(q) + k = self.k_norm(k) return q, k def apply_rope( diff --git a/tests/microbenchmarks/minimax_all_reduce.py b/tests/microbenchmarks/minimax_all_reduce.py index 808cab4e9e3c..7c7e86ee5768 100644 --- a/tests/microbenchmarks/minimax_all_reduce.py +++ b/tests/microbenchmarks/minimax_all_reduce.py @@ -37,6 +37,10 @@ # MiniMax all-reduce only uses D (hidden_size) 128 and 1536 in practice. ALLOWED_HIDDEN_SIZES = (128, 1536) +# Q+K fused API benchmark dimensions +QK_Q_DIM = 1536 +QK_K_DIM = 128 + def profile_minimax_allreduce_rms( mapping: Mapping, @@ -91,6 +95,64 @@ def func(loop_num=inner_loop): return median_ms +def profile_minimax_allreduce_rms_qk( + mapping: Mapping, + op: MiniMaxAllReduceRMS, + enable_cudagraph: bool = False, + inner_loop: int = 200, + outer_loop: int = 10, + q_tensor=None, + k_tensor=None, + norm_weight_q=None, + norm_weight_k=None, + eps: float = 1e-5, +): + """Profile the fused Q+K minimax allreduce RMS API (forward_qk).""" + + def func(loop_num=inner_loop): + out_q, out_k = None, None + for _ in range(loop_num): + out_q, out_k = op.forward_qk(q_tensor, k_tensor, norm_weight_q, norm_weight_k, eps) + return (out_q, out_k) + + start = [torch.cuda.Event(enable_timing=True) for _ in range(outer_loop)] + stop = [torch.cuda.Event(enable_timing=True) for _ in range(outer_loop)] + graph = torch.cuda.CUDAGraph() + + stream = torch.cuda.Stream() + n_tok, q_d, k_d = q_tensor.size(0), q_tensor.size(1), k_tensor.size(1) + with ( + torch.cuda.stream(stream), + nvtx_range(f"minimax_allreduce_rms_qk: shape={n_tok}x{q_d}+{n_tok}x{k_d}"), + ): + func(loop_num=1) + + if enable_cudagraph: + for i in range(2): + func(loop_num=1) + with torch.cuda.graph(graph, stream=stream): + _ = func() + + delay_kernel(20000, stream) + + torch.cuda.synchronize() + torch.cuda.profiler.start() + + for i in range(outer_loop): + start[i].record(stream) + if enable_cudagraph: + graph.replay() + else: + _ = func() + stop[i].record(stream) + + torch.cuda.synchronize() + torch.cuda.profiler.stop() + runtimes = [start[i].elapsed_time(stop[i]) for i in range(outer_loop)] + median_ms = sorted(runtimes)[len(runtimes) // 2] / inner_loop + return median_ms + + def minimax_allreduce_benchmark( dtype: str = "bfloat16", test_range: str = "256,256000000,10", @@ -170,9 +232,12 @@ def minimax_allreduce_benchmark( { "world_size": [mapping.world_size], "dtype": [dtype], + "api": ["single"], "message_size_bytes": [message_size_bytes], "num_tokens": [num_tokens], "hidden_size": [hidden_size], + "q_dim": [pd.NA], + "k_dim": [pd.NA], "time (us)": [median_ms * 1000], } ), @@ -183,6 +248,56 @@ def minimax_allreduce_benchmark( f"time (us): {median_ms * 1000}" ) + # Q+K fused API benchmark: q_dim=1536, k_dim=128 + num_tokens_qk = sorted({n for n, _ in shape_list}) + for num_tokens in num_tokens_qk: + q_tensor = torch.ones((num_tokens, QK_Q_DIM), dtype=torch_dtype, device="cuda") + k_tensor = torch.ones((num_tokens, QK_K_DIM), dtype=torch_dtype, device="cuda") + norm_weight_q = torch.randn((QK_Q_DIM,), dtype=torch_dtype, device="cuda") + norm_weight_k = torch.randn((QK_K_DIM,), dtype=torch_dtype, device="cuda") + message_size_bytes_qk = ( + num_tokens * (QK_Q_DIM + QK_K_DIM) * torch.finfo(torch_dtype).bits // 8 + ) + if message_size_bytes_qk > max_workspace: + continue + + median_ms_qk = profile_minimax_allreduce_rms_qk( + mapping=mapping, + op=op, + enable_cudagraph=enable_cudagraph, + inner_loop=inner_loop, + outer_loop=outer_loop, + q_tensor=q_tensor, + k_tensor=k_tensor, + norm_weight_q=norm_weight_q, + norm_weight_k=norm_weight_k, + eps=eps, + ) + + if mapping.rank == 0: + df = pd.concat( + [ + df, + pd.DataFrame( + { + "world_size": [mapping.world_size], + "dtype": [dtype], + "api": ["qk"], + "message_size_bytes": [message_size_bytes_qk], + "num_tokens": [num_tokens], + "hidden_size": [pd.NA], + "q_dim": [QK_Q_DIM], + "k_dim": [QK_K_DIM], + "time (us)": [median_ms_qk * 1000], + } + ), + ] + ) + print( + f"qk: num_tokens: {num_tokens}, q_dim: {QK_Q_DIM}, k_dim: {QK_K_DIM}, " + f"time (us): {median_ms_qk * 1000}" + ) + if mapping.rank == 0: pd.set_option("display.max_rows", None) pd.set_option("display.max_columns", None) From 78e38cfb403cf93eadd3175768c6e5f3f00e4bd9 Mon Sep 17 00:00:00 2001 From: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> Date: Thu, 12 Mar 2026 22:15:36 +0800 Subject: [PATCH 12/20] feature: q and k use different thread idx Signed-off-by: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> --- .../MiniMaxReduceRMSKernel.cu | 239 ++++++++++-------- 1 file changed, 128 insertions(+), 111 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu index 940156fad021..0a9054acc32d 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu +++ b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu @@ -13,6 +13,8 @@ namespace { // anonymous namespace template + +#define MINIMAX_REDUCE_RMS_WARP_SIZE 32 struct LamportComm { __device__ __forceinline__ LamportComm(void** workspace, int rank) @@ -258,6 +260,8 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport_float4 int tot_groups = (tot_tokens + 3) / 4; // ceiling: last group may have 1-3 valid rows int access_per_row_q = params.hidden_dim / kElemsPerAccess; int access_per_row_k = IsQK ? (params.hidden_dim_k / kElemsPerAccess) : 0; + int q_warps = (access_per_row_q + MINIMAX_REDUCE_RMS_WARP_SIZE - 1) / MINIMAX_REDUCE_RMS_WARP_SIZE; + int k_warps = IsQK ? ((access_per_row_k + MINIMAX_REDUCE_RMS_WARP_SIZE - 1) / MINIMAX_REDUCE_RMS_WARP_SIZE) : 0; #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) namespace cg = cooperative_groups; cg::cluster_group cluster = cg::this_cluster(); @@ -270,6 +274,9 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport_float4 int access_id_in_token = threadIdx.x; int group_stride = gridDim.x; #endif + bool is_q = (access_id_in_token < q_warps * MINIMAX_REDUCE_RMS_WARP_SIZE); + int k_thread_idx = IsQK ? (access_id_in_token - q_warps * MINIMAX_REDUCE_RMS_WARP_SIZE) : 0; + bool is_valid_token = is_q ? (access_id_in_token < access_per_row_q) : (k_thread_idx < access_per_row_k); float4 clear_vec = get_neg_zero(); #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) @@ -283,53 +290,57 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport_float4 for (int g = group_id; g < tot_groups; g += group_stride) { - alignas(16) DType vals[4][kElemsPerAccess]; + alignas(16) DType vals[4][kElemsPerAccess]{}; float sum_variance[4] = {0.F, 0.F, 0.F, 0.F}; float sum_variance_k[4] = {0.F, 0.F, 0.F, 0.F}; - // Load 4 rows of Q and compute partial sum of squares per row -#pragma unroll - for (int r = 0; r < 4; ++r) + if (is_q) { - int token_r = g * 4 + r; - if (token_r < tot_tokens) +// Q branch: each thread only covers 128bit +#pragma unroll + for (int r = 0; r < 4; ++r) { + int token_r = g * 4 + r; + if (token_r >= tot_tokens || (!is_valid_token)) + { + continue; + } int idx_r = token_r * access_per_row_q + access_id_in_token; *reinterpret_cast(&vals[r][0]) = reinterpret_cast(params.allreduce_in)[idx_r]; #pragma unroll for (int i = 0; i < kElemsPerAccess; ++i) { - sum_variance[r] += static_cast(vals[r][i]) * static_cast(vals[r][i]); + float x = static_cast(vals[r][i]); + sum_variance[r] += x * x; } } - else - { - *reinterpret_cast(&vals[r][0]) = make_float4(0.F, 0.F, 0.F, 0.F); - sum_variance[r] = 0.F; - } } - - // Load 4 rows of K only when this thread is in K column range (access_id_in_token < access_per_row_k) - if constexpr (IsQK) + else if constexpr (IsQK) // k branch { +// K branch: k_thread_idx = threadIdx.x - q_warps, each thread covers 32 K columns #pragma unroll for (int r = 0; r < 4; ++r) { int token_r = g * 4 + r; - if (token_r < tot_tokens && access_id_in_token < access_per_row_k) + if (token_r >= tot_tokens || k_thread_idx >= access_per_row_k) { - int idx_r = token_r * access_per_row_k + access_id_in_token; - alignas(16) DType vals_k[kElemsPerAccess]; - *reinterpret_cast(vals_k) = reinterpret_cast(params.allreduce_in_k)[idx_r]; + continue; + } + + int idx_r = token_r * access_per_row_k + k_thread_idx; + *reinterpret_cast(&vals[r][0]) = reinterpret_cast(params.allreduce_in_k)[idx_r]; #pragma unroll - for (int i = 0; i < kElemsPerAccess; ++i) - { - sum_variance_k[r] += static_cast(vals_k[i]) * static_cast(vals_k[i]); - } + for (int i = 0; i < kElemsPerAccess; ++i) + { + float x = static_cast(vals[r][i]); + sum_variance_k[r] += x * x; } } } + // Local reduce: only Q segment contributes to sum_variance, only K segment to sum_variance_k + // here we use all threads to reduce sum_variance and sum_variance_k + // TODO: we can do local reduce only within q threads and k threads respectively tensorrt_llm::common::blockReduceSumV2(sum_variance); if constexpr (IsQK) { @@ -351,31 +362,37 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport_float4 } } - // Allreduce: write float4(s) to comm - if (threadIdx.x == 0) + // Allreduce: write float4(s) to comm (thread 0 has both after broadcast) + if (threadIdx.x == 0 || threadIdx.x == q_warps * MINIMAX_REDUCE_RMS_WARP_SIZE) { - float4 sum4; - sum4.x = sum_variance[0]; - sum4.y = sum_variance[1]; - sum4.z = sum_variance[2]; - sum4.w = sum_variance[3]; - for (int r = 0; r < NRanks; ++r) + if (is_q) { - if constexpr (IsQK) - { - reinterpret_cast(comm.data_bufs[r])[(params.rank * 2 * tot_groups) + 2 * g] = sum4; - } - else + float4 sum4; + sum4.x = sum_variance[0]; + sum4.y = sum_variance[1]; + sum4.z = sum_variance[2]; + sum4.w = sum_variance[3]; +#pragma unroll + for (int r = 0; r < NRanks; ++r) { - reinterpret_cast(comm.data_bufs[r])[(params.rank * tot_groups) + g] = sum4; + if constexpr (IsQK) + { + reinterpret_cast(comm.data_bufs[r])[(params.rank * 2 * tot_groups) + 2 * g] = sum4; + } + else + { + reinterpret_cast(comm.data_bufs[r])[(params.rank * tot_groups) + g] = sum4; + } } } - if constexpr (IsQK) + else if constexpr (IsQK) { + float4 sum4; sum4.x = sum_variance_k[0]; sum4.y = sum_variance_k[1]; sum4.z = sum_variance_k[2]; sum4.w = sum_variance_k[3]; +#pragma unroll for (int r = 0; r < NRanks; ++r) { reinterpret_cast(comm.data_bufs[r])[(params.rank * 2 * tot_groups) + 2 * g + 1] = sum4; @@ -383,28 +400,45 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport_float4 } } - // Read Q from buffer first, sum, then RMS and store Q; then read K, sum, RMS and store K + // Read Q from buffer, sum, then RMS and store Q bool done = false; float4 vars_all_ranks[NRanks]; - while (!done) + if (is_q) { - done = true; -#pragma unroll - for (int r = 0; r < NRanks; ++r) + while (!done) { - if constexpr (IsQK) + done = true; +#pragma unroll + for (int r = 0; r < NRanks; ++r) { - vars_all_ranks[r] = ld_global_volatile( - &reinterpret_cast(comm.data_bufs[params.rank])[(r * 2 * tot_groups) + 2 * g]); + if constexpr (IsQK) + { + vars_all_ranks[r] = ld_global_volatile( + &reinterpret_cast(comm.data_bufs[params.rank])[(r * 2 * tot_groups) + 2 * g]); + } + else + { + vars_all_ranks[r] = ld_global_volatile( + &reinterpret_cast(comm.data_bufs[params.rank])[(r * tot_groups) + g]); + } + done &= !is_neg_zero(vars_all_ranks[r]); } - else + } + } + else if constexpr (IsQK) + { + while (!done) + { + done = true; + for (int r = 0; r < NRanks; ++r) { vars_all_ranks[r] = ld_global_volatile( - &reinterpret_cast(comm.data_bufs[params.rank])[(r * tot_groups) + g]); + &reinterpret_cast(comm.data_bufs[params.rank])[(r * 2 * tot_groups) + 2 * g + 1]); + done &= !is_neg_zero(vars_all_ranks[r]); } - done &= !is_neg_zero(vars_all_ranks[r]); } } + sum_variance[0] = 0.F; sum_variance[1] = 0.F; sum_variance[2] = 0.F; @@ -418,66 +452,48 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport_float4 sum_variance[3] += vars_all_ranks[r].w; } - // Load norm weight for Q (same column for all 4 rows) - __nv_bfloat16 norm_weight[kElemsPerAccess]; - *reinterpret_cast::norm_weight_type*>(norm_weight) - = reinterpret_cast::norm_weight_type const*>( - params.rms_gamma)[access_id_in_token]; - - // RMS norm and store 4 rows of Q (skip write for padded rows) - for (int r = 0; r < 4; ++r) + // RMS norm and store 4 rows of Q (Q branch only, reload and store per column) + if (is_q) { - int token_r = g * 4 + r; - if (token_r >= tot_tokens) - { - continue; - } - float scale = rsqrtf((sum_variance[r] / static_cast(params.hidden_dim) / NRanks) + params.rms_eps); -#pragma unroll - for (int i = 0; i < kElemsPerAccess; ++i) + if (access_id_in_token < access_per_row_q) { - vals[r][i] - = static_cast(static_cast(vals[r][i]) * scale * static_cast(norm_weight[i])); - } - int idx_out = token_r * access_per_row_q + access_id_in_token; - reinterpret_cast(params.rms_norm_out)[idx_out] = *reinterpret_cast(&vals[r][0]); - } - // Then read K from buffer, sum, RMS and store K (only when IsQK and access_id_in_token < access_per_row_k) - if constexpr (IsQK) - { - float4 vars_k_all_ranks[NRanks]; - done = false; - while (!done) - { - done = true; + __nv_bfloat16 norm_weight[kElemsPerAccess]; + *reinterpret_cast::norm_weight_type*>(norm_weight) + = reinterpret_cast::norm_weight_type const*>( + params.rms_gamma)[access_id_in_token]; #pragma unroll - for (int r = 0; r < NRanks; ++r) + for (int r = 0; r < 4; ++r) { - vars_k_all_ranks[r] = ld_global_volatile( - &reinterpret_cast(comm.data_bufs[params.rank])[(r * 2 * tot_groups) + 2 * g + 1]); - done &= !is_neg_zero(vars_k_all_ranks[r]); - } - } - sum_variance_k[0] = 0.F; - sum_variance_k[1] = 0.F; - sum_variance_k[2] = 0.F; - sum_variance_k[3] = 0.F; + int token_r = g * 4 + r; + if (token_r >= tot_tokens) + { + continue; + } + float scale + = rsqrtf((sum_variance[r] / static_cast(params.hidden_dim) / NRanks) + params.rms_eps); + #pragma unroll - for (int r = 0; r < NRanks; ++r) - { - sum_variance_k[0] += vars_k_all_ranks[r].x; - sum_variance_k[1] += vars_k_all_ranks[r].y; - sum_variance_k[2] += vars_k_all_ranks[r].z; - sum_variance_k[3] += vars_k_all_ranks[r].w; + for (int i = 0; i < kElemsPerAccess; ++i) + { + vals[r][i] = static_cast( + static_cast(vals[r][i]) * scale * static_cast(norm_weight[i])); + } + int idx_out = token_r * access_per_row_q + access_id_in_token; + reinterpret_cast(params.rms_norm_out)[idx_out] = *reinterpret_cast(&vals[r][0]); + } } + } + else if constexpr (IsQK) + { - if (access_id_in_token < access_per_row_k) + if (k_thread_idx < access_per_row_k) { __nv_bfloat16 norm_weight_k[kElemsPerAccess]; *reinterpret_cast::norm_weight_type*>(norm_weight_k) = reinterpret_cast::norm_weight_type const*>( - params.rms_gamma_k)[access_id_in_token]; + params.rms_gamma_k)[k_thread_idx]; +#pragma unroll for (int r = 0; r < 4; ++r) { int token_r = g * 4 + r; @@ -485,30 +501,30 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport_float4 { continue; } - alignas(16) DType vals_k[kElemsPerAccess]; - int idx_r = token_r * access_per_row_k + access_id_in_token; - *reinterpret_cast(vals_k) = reinterpret_cast(params.allreduce_in_k)[idx_r]; float scale_k = rsqrtf( (sum_variance_k[r] / static_cast(params.hidden_dim_k) / NRanks) + params.rms_eps); #pragma unroll for (int i = 0; i < kElemsPerAccess; ++i) { - vals_k[i] = static_cast( - static_cast(vals_k[i]) * scale_k * static_cast(norm_weight_k[i])); + vals[r][i] = static_cast( + static_cast(vals[r][i]) * scale_k * static_cast(norm_weight_k[i])); } - reinterpret_cast(params.rms_norm_out_k)[idx_r] = *reinterpret_cast(vals_k); + int idx_out = token_r * access_per_row_k + k_thread_idx; + reinterpret_cast(params.rms_norm_out_k)[idx_out] = *reinterpret_cast(&vals[r][0]); } } } } - // Clear comm buffer: clear full size set by previous kernel (same as non-float4 kernel). + // Clear comm buffer int clear_access = static_cast(comm.clear_size / kElemsPerAccess); - int clear_stride = group_stride * access_per_row_q; - for (int idx = group_id * access_per_row_q + threadIdx.x; idx < clear_access; idx += clear_stride) + int clear_stride = group_stride * blockDim.x; + + for (int idx = group_id * blockDim.x + threadIdx.x; idx < clear_access; idx += clear_stride) { reinterpret_cast(comm.clear_buf)[idx] = clear_vec; } + comm.update(IsQK ? (2 * tot_groups * 8 * NRanks) : (tot_groups * 8 * NRanks)); #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) if constexpr (TriggerCompletionAtEnd) @@ -596,8 +612,11 @@ void minimax_reduce_rms_kernel_launcher_float4(MiniMaxReduceRMSParams const& par int sm_count = get_sm_count(); int cluster_size = 1; int cluster_num = tot_groups; - int threads_per_token = params.hidden_dim / kElemsPerAccess; - int block_size = threads_per_token; + int access_per_row_q = params.hidden_dim / kElemsPerAccess; + int access_per_row_k = (params.allreduce_in_k != nullptr) ? (params.hidden_dim_k / kElemsPerAccess) : 0; + auto divUp = [](int a, int b) { return (a + b - 1) / b * b; }; // round up to the nearest multiple of b + int block_size = divUp(access_per_row_q, MINIMAX_REDUCE_RMS_WARP_SIZE) + + ((params.allreduce_in_k != nullptr) ? divUp(access_per_row_k, MINIMAX_REDUCE_RMS_WARP_SIZE) : 0); int grid_size = (std::min(sm_count, cluster_num * cluster_size) / cluster_size) * cluster_size; cudaLaunchConfig_t cfg; @@ -649,9 +668,7 @@ void minimax_reduce_rms_kernel_launcher_float4(MiniMaxReduceRMSParams const& par template void dispatch_dtype(MiniMaxReduceRMSParams const& params) { - int token_num = params.size_q / params.hidden_dim; - constexpr int kFloat4MinTokens = 32; - bool use_float4 = (token_num >= kFloat4MinTokens); + bool use_float4 = true; if (params.dtype == nvinfer1::DataType::kHALF) { From a610f728979f3c3bb07fd27693cdc37355bf0183 Mon Sep 17 00:00:00 2001 From: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> Date: Fri, 13 Mar 2026 18:46:56 +0800 Subject: [PATCH 13/20] test: add test case for qk norm, add range reduce func Signed-off-by: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> --- .../MiniMaxReduceRMSKernel.cu | 73 +++++++++- .../_torch/multi_gpu/test_allreduce.py | 125 ++++++++++++++++++ 2 files changed, 195 insertions(+), 3 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu index 0a9054acc32d..c58a36b16e2c 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu +++ b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu @@ -98,6 +98,70 @@ __device__ __forceinline__ float ld_global_volatile(float* addr) return val; } +template +__device__ __forceinline__ void blockReduceSumRange(T* val, int rangeStart, int rangeEnd) +{ + constexpr int kWarpSize = 32; + constexpr unsigned kFullMask = 0xffffffffu; + static __shared__ T shared[NUM][33]; + + int const activeThreadCount = max(rangeEnd - rangeStart, 0); + bool const isActive = threadIdx.x >= rangeStart && threadIdx.x < rangeEnd; + int const lane = threadIdx.x & (kWarpSize - 1); + unsigned const activeMask = __ballot_sync(kFullMask, isActive); + + if (isActive) + { +#pragma unroll + for (int i = 0; i < NUM; ++i) + { + T sum = val[i]; +#pragma unroll + for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) + { + sum += __shfl_down_sync(activeMask, sum, offset, kWarpSize); + } + val[i] = sum; + } + } + + if (isActive && lane == 0) + { + int const localWarpId = (threadIdx.x - rangeStart) >> 5; +#pragma unroll + for (int i = 0; i < NUM; ++i) + { + shared[i][localWarpId] = val[i]; + } + } + + __syncthreads(); + + int const shiftedTid = threadIdx.x - rangeStart; + int const warpCount = (activeThreadCount + kWarpSize - 1) / kWarpSize; + bool const inLeaderWarp = shiftedTid >= 0 && shiftedTid < kWarpSize; + bool const leaderLaneIsValid = inLeaderWarp && shiftedTid < warpCount; + unsigned const leaderMask = __ballot_sync(kFullMask, leaderLaneIsValid); + + if (inLeaderWarp) + { +#pragma unroll + for (int i = 0; i < NUM; ++i) + { + T sum = leaderLaneIsValid ? shared[i][shiftedTid] : static_cast(0); +#pragma unroll + for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) + { + sum += __shfl_down_sync(leaderMask, sum, offset, kWarpSize); + } + if (threadIdx.x == rangeStart) + { + val[i] = sum; + } + } + } +} + template class IndexHelper { @@ -344,8 +408,11 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport_float4 tensorrt_llm::common::blockReduceSumV2(sum_variance); if constexpr (IsQK) { - tensorrt_llm::common::blockReduceSumV2(sum_variance_k); + int const kStartThread = q_warps * MINIMAX_REDUCE_RMS_WARP_SIZE; + int const kEndThread = (q_warps + k_warps) * MINIMAX_REDUCE_RMS_WARP_SIZE; + blockReduceSumRange(sum_variance_k, kStartThread, kEndThread); } + #pragma unroll for (int r = 0; r < 4; ++r) { @@ -501,8 +568,8 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport_float4 { continue; } - float scale_k = rsqrtf( - (sum_variance_k[r] / static_cast(params.hidden_dim_k) / NRanks) + params.rms_eps); + float scale_k + = rsqrtf((sum_variance[r] / static_cast(params.hidden_dim_k) / NRanks) + params.rms_eps); #pragma unroll for (int i = 0; i < kElemsPerAccess; ++i) { diff --git a/tests/unittest/_torch/multi_gpu/test_allreduce.py b/tests/unittest/_torch/multi_gpu/test_allreduce.py index d10e878e2fcf..91736d471f9a 100644 --- a/tests/unittest/_torch/multi_gpu/test_allreduce.py +++ b/tests/unittest/_torch/multi_gpu/test_allreduce.py @@ -777,3 +777,128 @@ def test_minimax_allreduce_rms(mpi_pool_executor): ) for r in results: assert r is True + + +@torch.inference_mode() +def run_minimax_allreduce_rms_qk_op(q_input: torch.Tensor, + k_input: torch.Tensor, + tensor_parallel_size: int, + tensor_parallel_rank: int, + rms_weights_q: torch.Tensor, + rms_weights_k: torch.Tensor, eps: float): + torch.manual_seed(42) + + num_tokens = q_input.shape[0] + origin_dtype = q_input.dtype + + q_input = q_input.cuda() + k_input = k_input.cuda() + rms_weights_q = rms_weights_q.cuda() + rms_weights_k = rms_weights_k.cuda() + + # firstly, calculate the reference output for each rank + # Reference: each rank computes RMS norm on its local Q/K independently, + # then all-reduce is applied to the variance before normalization + # The all-reduce sum happens across TP ranks, so we need to simulate that + q_input_fp32 = q_input.to(torch.float32) + k_input_fp32 = k_input.to(torch.float32) + + # Compute reference: RMS norm with all-reduced variance + q_variance_sum = q_input_fp32.pow(2).mean(-1, keepdim=True) + k_variance_sum = k_input_fp32.pow(2).mean(-1, keepdim=True) + + ref_q_output = q_input_fp32 * torch.rsqrt(q_variance_sum + eps) + ref_k_output = k_input_fp32 * torch.rsqrt(k_variance_sum + eps) + + # Apply weights + ref_q_output = ref_q_output * rms_weights_q.to(torch.float32) + ref_k_output = ref_k_output * rms_weights_k.to(torch.float32) + + ref_q_output = ref_q_output.to(origin_dtype) + ref_k_output = ref_k_output.to(origin_dtype) + + # we only need to compare the reference output of the current rank + ref_q_output = ref_q_output.reshape(num_tokens, tensor_parallel_size, -1) + ref_k_output = ref_k_output.reshape(num_tokens, tensor_parallel_size, -1) + ref_q_output = ref_q_output[:, tensor_parallel_rank, :].contiguous() + ref_k_output = ref_k_output[:, tensor_parallel_rank, :].contiguous() + + # minimax input should be sliced by rank + q_input = q_input.reshape(num_tokens, tensor_parallel_size, -1) + k_input = k_input.reshape(num_tokens, tensor_parallel_size, -1) + rank_q_input = q_input[:, tensor_parallel_rank, :].contiguous() + rank_k_input = k_input[:, tensor_parallel_rank, :].contiguous() + + # rms weights should be sliced by rank + rms_weights_q = rms_weights_q.reshape(tensor_parallel_size, -1) + rms_weights_k = rms_weights_k.reshape(tensor_parallel_size, -1) + rank_rms_weights_q = rms_weights_q[tensor_parallel_rank, :].contiguous() + rank_rms_weights_k = rms_weights_k[tensor_parallel_rank, :].contiguous() + + # then, calculate the minimax allreduce output + minimax_allreduce_rms = MiniMaxAllReduceRMS(mapping=Mapping( + world_size=tensor_parallel_size, + tp_size=tensor_parallel_size, + rank=tensor_parallel_rank, + )) + minimax_q_output, minimax_k_output = minimax_allreduce_rms.forward_qk( + q=rank_q_input, + k=rank_k_input, + rms_weights_q=rank_rms_weights_q, + rms_weights_k=rank_rms_weights_k, + eps=eps, + ) + + # finally, verify the results + torch.testing.assert_close(minimax_q_output, + ref_q_output, + rtol=0.2, + atol=0.2) + torch.testing.assert_close(minimax_k_output, + ref_k_output, + rtol=0.2, + atol=0.2) + + return q_input + + +@torch.inference_mode() +def run_minimax_allreduce_rms_qk_single_rank(tensor_parallel_size, + single_rank_forward_func, q_input, + k_input, rms_weights_q, + rms_weights_k, eps): + rank = tensorrt_llm.mpi_rank() + torch.cuda.set_device(rank) + try: + single_rank_forward_func(q_input, k_input, tensor_parallel_size, rank, + rms_weights_q, rms_weights_k, eps) + except Exception: + traceback.print_exc() + raise + return True + + +@pytest.mark.parametrize("mpi_pool_executor", [4], indirect=True) +def test_minimax_allreduce_rms_qk(mpi_pool_executor): + torch.manual_seed(42) + + seq_len = 1024 + q_size = 6144 + k_size = 1024 + dtype = torch.bfloat16 + tensor_parallel_size = mpi_pool_executor.num_workers + + q_input = torch.randn((seq_len, q_size), dtype=dtype) + k_input = torch.randn((seq_len, k_size), dtype=dtype) + rms_weights_q = torch.randn((q_size, ), dtype=dtype, device="cuda") + rms_weights_k = torch.randn((k_size, ), dtype=dtype, device="cuda") + eps = 1e-5 + + results = mpi_pool_executor.map( + run_minimax_allreduce_rms_qk_single_rank, + *zip(*[(tensor_parallel_size, run_minimax_allreduce_rms_qk_op, q_input, + k_input, rms_weights_q, rms_weights_k, eps)] * + tensor_parallel_size), + ) + for r in results: + assert r is True From 55739c24cca318fde5c76b336fc6aaca10dbae42 Mon Sep 17 00:00:00 2001 From: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> Date: Fri, 27 Mar 2026 13:19:09 +0800 Subject: [PATCH 14/20] feature: modify float4 kernel, split q, k to different warp, enable pdl Signed-off-by: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> --- .../MiniMaxReduceRMSKernel.cu | 484 ++++++++++-------- cpp/tensorrt_llm/thop/allreduceOp.cpp | 1 + tensorrt_llm/_torch/distributed/ops.py | 4 +- tests/microbenchmarks/minimax_all_reduce.py | 172 +++---- 4 files changed, 353 insertions(+), 308 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu index c58a36b16e2c..78697f5cc74c 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu +++ b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu @@ -82,6 +82,25 @@ __device__ __forceinline__ float4 get_neg_zero() return vec; } +template +__device__ __forceinline__ float rms_rsqrt(float& v, float eps) +{ + constexpr float kInvDim = 1.0F / static_cast(Dim); + v = rsqrtf((v * kInvDim) + eps); + return v; +} + +template +__device__ __forceinline__ float4 rms_rsqrt(float4& v, float eps) +{ + constexpr float kInvDim = 1.0F / static_cast(Dim); + v.x = rsqrtf((v.x * kInvDim) + eps); + v.y = rsqrtf((v.y * kInvDim) + eps); + v.z = rsqrtf((v.z * kInvDim) + eps); + v.w = rsqrtf((v.w * kInvDim) + eps); + return v; +} + __device__ __forceinline__ float4 ld_global_volatile(float4* addr) { float4 val; @@ -162,6 +181,44 @@ __device__ __forceinline__ void blockReduceSumRange(T* val, int rangeStart, int } } +// from sglang python/sglang/jit_kernel/include/sgl_kernel/warp.cuh +template +__device__ __forceinline__ void local_warp_reduce_sum(T& value, uint32_t active_mask = 0xffffffffu) +{ + static_assert(kNumThreads >= 1 && kNumThreads <= MINIMAX_REDUCE_RMS_WARP_SIZE); +#pragma unroll + for (int mask = kNumThreads / 2; mask > 0; mask >>= 1) + { + value += __shfl_xor_sync(active_mask, value, mask, MINIMAX_REDUCE_RMS_WARP_SIZE); + } +} + +// for float4 version +template +__device__ __forceinline__ void local_warp_reduce_sum_array(T* value_ptr, uint32_t active_mask = 0xffffffffu) +{ + static_assert(kNumThreads >= 1 && kNumThreads <= MINIMAX_REDUCE_RMS_WARP_SIZE); +#pragma unroll + for (int i = 0; i < ArraySize; ++i) + { +#pragma unroll + for (int mask = kNumThreads / 2; mask > 0; mask >>= 1) + { + value_ptr[i] += __shfl_xor_sync(active_mask, value_ptr[i], mask, MINIMAX_REDUCE_RMS_WARP_SIZE); + } + } +} + +constexpr int next_pow2(int val) +{ + int result = 1; + while (result < val) + { + result <<= 1; + } + return result; +} + template class IndexHelper { @@ -216,7 +273,7 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport(MiniMa int tot_tokens = params.size_q / params.hidden_dim; float4 clear_vec = get_neg_zero(); // FusedOp fused_op(params, access_id, access_id_in_token); - + __shared__ float shared_vars_all_ranks; #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) cudaGridDependencySynchronize(); if constexpr (!TriggerCompletionAtEnd) @@ -250,33 +307,39 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport(MiniMa // we only need the first thread to push data to other ranks if (threadIdx.x == 0) { +#pragma unroll for (int r = 0; r < NRanks; ++r) { // temp data buffer [nranks, total_tokens, 1] reinterpret_cast(comm.data_bufs[r])[(params.rank * tot_tokens) + token_id] = (sum_variance); } - } + // we only use the first thread to pull data from other ranks + bool done = false; + float vals_all_ranks[NRanks]; + while (!done) + { + done = true; +#pragma unroll + for (int r = 0; r < NRanks; ++r) + { + vals_all_ranks[r] = ld_global_volatile( + &reinterpret_cast(comm.data_bufs[params.rank])[(r * tot_tokens) + token_id]); + done &= !is_neg_zero(vals_all_ranks[r]); + } + } - // Load data from other ranks - bool done = false; - float vars_all_ranks[NRanks]; - while (!done) - { - done = true; + sum_variance = 0.F; #pragma unroll for (int r = 0; r < NRanks; ++r) { - vars_all_ranks[r] = ld_global_volatile( - &reinterpret_cast(comm.data_bufs[params.rank])[(r * tot_tokens) + token_id]); - done &= !is_neg_zero(vars_all_ranks[r]); + sum_variance += vals_all_ranks[r]; } + sum_variance = sqrtf(sum_variance / NRanks / static_cast(params.hidden_dim) + params.rms_eps); + shared_vars_all_ranks = sum_variance; } - sum_variance = 0.F; -#pragma unroll - for (int r = 0; r < NRanks; ++r) - { - sum_variance += vars_all_ranks[r]; - } + + __syncthreads(); + sum_variance = shared_vars_all_ranks; // step 3: calculate the rms norm (input * rsqrt(variance + eps)) @@ -289,9 +352,8 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport(MiniMa #pragma unroll for (int i = 0; i < kElemsPerAccess; ++i) { - vals[i] = static_cast(static_cast(vals[i]) - * rsqrtf((sum_variance / static_cast(params.hidden_dim) / NRanks) + params.rms_eps) - * static_cast(norm_weight[i])); + vals[i] + = static_cast(static_cast(vals[i]) * sum_variance * static_cast(norm_weight[i])); } // step 4: store the rms norm @@ -317,15 +379,22 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport(MiniMa * When tot_tokens % 4 != 0, the last group pads rows with zeros; padded rows are not written to rms_norm_out. * IsQK: when true, process Q+K in one loop with doubled comm buffer; when false, single-matrix (Q only). */ -template -__global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport_float4(MiniMaxReduceRMSParams params) +template +__global__ void __launch_bounds__(1024) minimax_reduce_qk_rms_kernel_lamport_float4(MiniMaxReduceRMSParams params) { - int tot_tokens = params.size_q / params.hidden_dim; - int tot_groups = (tot_tokens + 3) / 4; // ceiling: last group may have 1-3 valid rows - int access_per_row_q = params.hidden_dim / kElemsPerAccess; - int access_per_row_k = IsQK ? (params.hidden_dim_k / kElemsPerAccess) : 0; - int q_warps = (access_per_row_q + MINIMAX_REDUCE_RMS_WARP_SIZE - 1) / MINIMAX_REDUCE_RMS_WARP_SIZE; - int k_warps = IsQK ? ((access_per_row_k + MINIMAX_REDUCE_RMS_WARP_SIZE - 1) / MINIMAX_REDUCE_RMS_WARP_SIZE) : 0; + static_assert(TokenPerBlock == 1 || TokenPerBlock == 4, "TokenPerBlock must be 1 or 4"); + constexpr int RankQDim = OriginQDim / NRanks; + constexpr int RankKDim = OriginKDim / NRanks; + constexpr int ThreadsPerRowQ = RankQDim / kElemsPerAccess; + constexpr int ThreadsPerRowK = RankKDim / kElemsPerAccess; + constexpr int NumWarpQ = (ThreadsPerRowQ + MINIMAX_REDUCE_RMS_WARP_SIZE - 1) / MINIMAX_REDUCE_RMS_WARP_SIZE; + constexpr int NumWarpK = (ThreadsPerRowK + MINIMAX_REDUCE_RMS_WARP_SIZE - 1) / MINIMAX_REDUCE_RMS_WARP_SIZE; + int tot_tokens = params.size_q / RankQDim; + int tot_groups = (tot_tokens + TokenPerBlock - 1) / TokenPerBlock; // ceiling: last group may have 1-3 valid rows + + using AccumType = std::conditional_t; + #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) namespace cg = cooperative_groups; cg::cluster_group cluster = cg::this_cluster(); @@ -338,11 +407,16 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport_float4 int access_id_in_token = threadIdx.x; int group_stride = gridDim.x; #endif - bool is_q = (access_id_in_token < q_warps * MINIMAX_REDUCE_RMS_WARP_SIZE); - int k_thread_idx = IsQK ? (access_id_in_token - q_warps * MINIMAX_REDUCE_RMS_WARP_SIZE) : 0; - bool is_valid_token = is_q ? (access_id_in_token < access_per_row_q) : (k_thread_idx < access_per_row_k); + bool is_q = (access_id_in_token < NumWarpQ * MINIMAX_REDUCE_RMS_WARP_SIZE); + int q_thread_idx = access_id_in_token; + int k_thread_idx = (access_id_in_token - (NumWarpQ * MINIMAX_REDUCE_RMS_WARP_SIZE)); + bool is_valid_token = is_q ? (access_id_in_token < ThreadsPerRowQ) : (k_thread_idx < ThreadsPerRowK); float4 clear_vec = get_neg_zero(); + __shared__ float block_reduce_sum[TokenPerBlock][MINIMAX_REDUCE_RMS_WARP_SIZE + 1]; // 33 > warpQ + warpK + __shared__ float global_scale_q[TokenPerBlock]; + __shared__ float global_scale_k[TokenPerBlock]; + #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) cudaGridDependencySynchronize(); if constexpr (!TriggerCompletionAtEnd) @@ -352,239 +426,247 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport_float4 #endif LamportComm comm(params.workspace, params.rank); + // first step load rms params scale + __nv_bfloat16 norm_weight[kElemsPerAccess]{}; + if (access_id_in_token < NumWarpQ * MINIMAX_REDUCE_RMS_WARP_SIZE) // Q branch + { + // load rms params scale + if (is_valid_token) + { + *reinterpret_cast::norm_weight_type*>(norm_weight) + = reinterpret_cast::norm_weight_type const*>( + params.rms_gamma)[access_id_in_token]; + } + } + else // K branch + { + // load rms params scale + if (is_valid_token) + { + *reinterpret_cast::norm_weight_type*>(norm_weight) + = reinterpret_cast::norm_weight_type const*>( + params.rms_gamma_k)[k_thread_idx]; + } + } + for (int g = group_id; g < tot_groups; g += group_stride) { - alignas(16) DType vals[4][kElemsPerAccess]{}; - float sum_variance[4] = {0.F, 0.F, 0.F, 0.F}; - float sum_variance_k[4] = {0.F, 0.F, 0.F, 0.F}; + alignas(16) DType vals[TokenPerBlock][kElemsPerAccess]{}; + float warp_sum_variance[TokenPerBlock]{0.F}; if (is_q) { -// Q branch: each thread only covers 128bit + // Q branch: each thread only covers 128bit * TokenPerBlock #pragma unroll - for (int r = 0; r < 4; ++r) + for (int row = 0; row < TokenPerBlock; ++row) { - int token_r = g * 4 + r; + int token_r = (g * TokenPerBlock) + row; if (token_r >= tot_tokens || (!is_valid_token)) { continue; } - int idx_r = token_r * access_per_row_q + access_id_in_token; - *reinterpret_cast(&vals[r][0]) = reinterpret_cast(params.allreduce_in)[idx_r]; + int idx_r = (token_r * ThreadsPerRowQ) + access_id_in_token; + *reinterpret_cast(&vals[row][0]) = reinterpret_cast(params.allreduce_in)[idx_r]; #pragma unroll for (int i = 0; i < kElemsPerAccess; ++i) { - float x = static_cast(vals[r][i]); - sum_variance[r] += x * x; + auto x = static_cast(vals[row][i]); + warp_sum_variance[row] += x * x; } } } - else if constexpr (IsQK) // k branch + else // k branch { // K branch: k_thread_idx = threadIdx.x - q_warps, each thread covers 32 K columns #pragma unroll - for (int r = 0; r < 4; ++r) + for (int row = 0; row < TokenPerBlock; ++row) { - int token_r = g * 4 + r; - if (token_r >= tot_tokens || k_thread_idx >= access_per_row_k) + int token_r = (g * TokenPerBlock) + row; + if (token_r >= tot_tokens || (!is_valid_token)) { continue; } - int idx_r = token_r * access_per_row_k + k_thread_idx; - *reinterpret_cast(&vals[r][0]) = reinterpret_cast(params.allreduce_in_k)[idx_r]; + int idx_r = (token_r * ThreadsPerRowK) + k_thread_idx; + *reinterpret_cast(&vals[row][0]) + = reinterpret_cast(params.allreduce_in_k)[idx_r]; #pragma unroll for (int i = 0; i < kElemsPerAccess; ++i) { - float x = static_cast(vals[r][i]); - sum_variance_k[r] += x * x; + auto x = static_cast(vals[row][i]); + warp_sum_variance[row] += x * x; } } } - // Local reduce: only Q segment contributes to sum_variance, only K segment to sum_variance_k - // here we use all threads to reduce sum_variance and sum_variance_k - // TODO: we can do local reduce only within q threads and k threads respectively - tensorrt_llm::common::blockReduceSumV2(sum_variance); - if constexpr (IsQK) + // Local warp reduce: + // here we use all threads to reduce warp_sum_variance + local_warp_reduce_sum_array(warp_sum_variance); + // each warp write the warp reduce result to the shared memory + int line = threadIdx.x & (MINIMAX_REDUCE_RMS_WARP_SIZE - 1); + if (line == 0) { - int const kStartThread = q_warps * MINIMAX_REDUCE_RMS_WARP_SIZE; - int const kEndThread = (q_warps + k_warps) * MINIMAX_REDUCE_RMS_WARP_SIZE; - blockReduceSumRange(sum_variance_k, kStartThread, kEndThread); - } - #pragma unroll - for (int r = 0; r < 4; ++r) - { - if (is_neg_zero(sum_variance[r])) + for (int _ = 0; _ < TokenPerBlock; ++_) { - sum_variance[r] = 0.F; - } - if constexpr (IsQK) - { - if (is_neg_zero(sum_variance_k[r])) - { - sum_variance_k[r] = 0.F; - } + block_reduce_sum[_][threadIdx.x / MINIMAX_REDUCE_RMS_WARP_SIZE] = warp_sum_variance[_]; } } + __syncthreads(); + int tid = threadIdx.x; + // then two warps process q block reduce and k block reduce respectively - // Allreduce: write float4(s) to comm (thread 0 has both after broadcast) - if (threadIdx.x == 0 || threadIdx.x == q_warps * MINIMAX_REDUCE_RMS_WARP_SIZE) + if (tid < MINIMAX_REDUCE_RMS_WARP_SIZE) { - if (is_q) + constexpr int kNumWarpQPow2 = next_pow2(NumWarpQ) > NRanks ? next_pow2(NumWarpQ) : NRanks; + float local_sum[TokenPerBlock]; +#pragma unroll + for (int _ = 0; _ < TokenPerBlock; ++_) + { + local_sum[_] = tid < NumWarpQ ? block_reduce_sum[_][tid] : 0.F; + } + local_warp_reduce_sum_array(local_sum); + // for thread [0, NRanks), we need to push data to comm buffer + if (tid < NRanks) { - float4 sum4; - sum4.x = sum_variance[0]; - sum4.y = sum_variance[1]; - sum4.z = sum_variance[2]; - sum4.w = sum_variance[3]; #pragma unroll - for (int r = 0; r < NRanks; ++r) + for (int _ = 0; _ < TokenPerBlock; ++_) { - if constexpr (IsQK) - { - reinterpret_cast(comm.data_bufs[r])[(params.rank * 2 * tot_groups) + 2 * g] = sum4; - } - else + if (is_neg_zero(local_sum[_])) { - reinterpret_cast(comm.data_bufs[r])[(params.rank * tot_groups) + g] = sum4; + local_sum[_] = 0.F; } } - } - else if constexpr (IsQK) - { - float4 sum4; - sum4.x = sum_variance_k[0]; - sum4.y = sum_variance_k[1]; - sum4.z = sum_variance_k[2]; - sum4.w = sum_variance_k[3]; -#pragma unroll - for (int r = 0; r < NRanks; ++r) + // push data to comm buffer, for each thread, we only need to push data to one rank + + reinterpret_cast(comm.data_bufs[tid])[(params.rank * tot_groups * 2) + (2 * g)] + = *reinterpret_cast(local_sum); + // pull data from other ranks + bool done = false; + AccumType var_all_ranks; + while (!done) + { + done = true; + var_all_ranks = ld_global_volatile( + &reinterpret_cast(comm.data_bufs[params.rank])[(tid * tot_groups * 2) + (2 * g)]); + done &= !is_neg_zero(var_all_ranks); + } + // local reduce + constexpr uint32_t kActiveMask = (1 << NRanks) - 1; + local_warp_reduce_sum_array( + reinterpret_cast(&var_all_ranks), kActiveMask); + if (tid == 0) { - reinterpret_cast(comm.data_bufs[r])[(params.rank * 2 * tot_groups) + 2 * g + 1] = sum4; + *reinterpret_cast(global_scale_q) + = rms_rsqrt(var_all_ranks, params.rms_eps); } } } - - // Read Q from buffer, sum, then RMS and store Q - bool done = false; - float4 vars_all_ranks[NRanks]; - if (is_q) + // k branch + else if (threadIdx.x >= MINIMAX_REDUCE_RMS_WARP_SIZE * NumWarpQ + && threadIdx.x < MINIMAX_REDUCE_RMS_WARP_SIZE * (NumWarpQ + 1)) { - while (!done) + constexpr int kNumWarpKPow2 = next_pow2(NumWarpK) > NRanks ? next_pow2(NumWarpK) : NRanks; + float local_sum[TokenPerBlock]; +#pragma unroll + for (int _ = 0; _ < TokenPerBlock; ++_) + { + local_sum[_] = k_thread_idx < NumWarpK ? block_reduce_sum[_][NumWarpQ + k_thread_idx] : 0.F; + } + local_warp_reduce_sum_array(local_sum); + // for thread [0, NRanks), we need to push data to comm buffer + if (k_thread_idx < NRanks) { - done = true; #pragma unroll - for (int r = 0; r < NRanks; ++r) + for (int _ = 0; _ < TokenPerBlock; ++_) { - if constexpr (IsQK) + if (is_neg_zero(local_sum[_])) { - vars_all_ranks[r] = ld_global_volatile( - &reinterpret_cast(comm.data_bufs[params.rank])[(r * 2 * tot_groups) + 2 * g]); + local_sum[_] = 0.F; } - else - { - vars_all_ranks[r] = ld_global_volatile( - &reinterpret_cast(comm.data_bufs[params.rank])[(r * tot_groups) + g]); - } - done &= !is_neg_zero(vars_all_ranks[r]); } - } - } - else if constexpr (IsQK) - { - while (!done) - { - done = true; - for (int r = 0; r < NRanks; ++r) + // push data to comm buffer, for each thread, we only need to push data to one rank + reinterpret_cast(comm.data_bufs[k_thread_idx])[(params.rank * tot_groups * 2) + (2 * g + 1)] + = *reinterpret_cast(local_sum); + // pull data from other ranks + bool done = false; + AccumType var_all_ranks; + while (!done) { - vars_all_ranks[r] = ld_global_volatile( - &reinterpret_cast(comm.data_bufs[params.rank])[(r * 2 * tot_groups) + 2 * g + 1]); - done &= !is_neg_zero(vars_all_ranks[r]); + done = true; + var_all_ranks = ld_global_volatile(&reinterpret_cast( + comm.data_bufs[params.rank])[(k_thread_idx * tot_groups * 2) + (2 * g + 1)]); + done &= !is_neg_zero(var_all_ranks); + } + // local reduce + constexpr uint32_t kActiveMask = (1 << NRanks) - 1; + local_warp_reduce_sum_array( + reinterpret_cast(&var_all_ranks), kActiveMask); + if (k_thread_idx == 0) + { + *reinterpret_cast(global_scale_k) + = rms_rsqrt(var_all_ranks, params.rms_eps); } } } - - sum_variance[0] = 0.F; - sum_variance[1] = 0.F; - sum_variance[2] = 0.F; - sum_variance[3] = 0.F; -#pragma unroll - for (int r = 0; r < NRanks; ++r) - { - sum_variance[0] += vars_all_ranks[r].x; - sum_variance[1] += vars_all_ranks[r].y; - sum_variance[2] += vars_all_ranks[r].z; - sum_variance[3] += vars_all_ranks[r].w; - } - - // RMS norm and store 4 rows of Q (Q branch only, reload and store per column) + __syncthreads(); + // final part if (is_q) { - if (access_id_in_token < access_per_row_q) +#pragma unroll + for (int _ = 0; _ < TokenPerBlock; ++_) { - - __nv_bfloat16 norm_weight[kElemsPerAccess]; - *reinterpret_cast::norm_weight_type*>(norm_weight) - = reinterpret_cast::norm_weight_type const*>( - params.rms_gamma)[access_id_in_token]; + warp_sum_variance[_] = global_scale_q[_]; + } #pragma unroll - for (int r = 0; r < 4; ++r) - { - int token_r = g * 4 + r; - if (token_r >= tot_tokens) - { - continue; - } - float scale - = rsqrtf((sum_variance[r] / static_cast(params.hidden_dim) / NRanks) + params.rms_eps); - + for (int r = 0; r < TokenPerBlock; ++r) + { #pragma unroll - for (int i = 0; i < kElemsPerAccess; ++i) - { - vals[r][i] = static_cast( - static_cast(vals[r][i]) * scale * static_cast(norm_weight[i])); - } - int idx_out = token_r * access_per_row_q + access_id_in_token; - reinterpret_cast(params.rms_norm_out)[idx_out] = *reinterpret_cast(&vals[r][0]); + for (int i = 0; i < kElemsPerAccess; ++i) + { + vals[r][i] = static_cast( + static_cast(vals[r][i]) * warp_sum_variance[r] * static_cast(norm_weight[i])); + } + // store to rms_norm_out + int token_r = (g * TokenPerBlock) + r; + if (token_r >= tot_tokens || (!is_valid_token)) + { + continue; } + int idx_r = (token_r * ThreadsPerRowQ) + access_id_in_token; + reinterpret_cast(params.rms_norm_out)[idx_r] = *reinterpret_cast(&vals[r][0]); } } - else if constexpr (IsQK) + else { - - if (k_thread_idx < access_per_row_k) +#pragma unroll + for (int _ = 0; _ < TokenPerBlock; ++_) { - __nv_bfloat16 norm_weight_k[kElemsPerAccess]; - *reinterpret_cast::norm_weight_type*>(norm_weight_k) - = reinterpret_cast::norm_weight_type const*>( - params.rms_gamma_k)[k_thread_idx]; + warp_sum_variance[_] = global_scale_k[_]; + } #pragma unroll - for (int r = 0; r < 4; ++r) - { - int token_r = g * 4 + r; - if (token_r >= tot_tokens) - { - continue; - } - float scale_k - = rsqrtf((sum_variance[r] / static_cast(params.hidden_dim_k) / NRanks) + params.rms_eps); + for (int r = 0; r < TokenPerBlock; ++r) + { #pragma unroll - for (int i = 0; i < kElemsPerAccess; ++i) - { - vals[r][i] = static_cast( - static_cast(vals[r][i]) * scale_k * static_cast(norm_weight_k[i])); - } - int idx_out = token_r * access_per_row_k + k_thread_idx; - reinterpret_cast(params.rms_norm_out_k)[idx_out] = *reinterpret_cast(&vals[r][0]); + for (int i = 0; i < kElemsPerAccess; ++i) + { + vals[r][i] = static_cast( + static_cast(vals[r][i]) * warp_sum_variance[r] * static_cast(norm_weight[i])); + } + // store to rms_norm_out + int token_r = (g * TokenPerBlock) + r; + if (token_r >= tot_tokens || (!is_valid_token)) + { + continue; } + int idx_r = (token_r * ThreadsPerRowK) + k_thread_idx; + reinterpret_cast(params.rms_norm_out_k)[idx_r] = *reinterpret_cast(&vals[r][0]); } } } // Clear comm buffer - int clear_access = static_cast(comm.clear_size / kElemsPerAccess); + int clear_access = static_cast(comm.clear_size / (sizeof(float4) / sizeof(DType))); int clear_stride = group_stride * blockDim.x; for (int idx = group_id * blockDim.x + threadIdx.x; idx < clear_access; idx += clear_stride) @@ -592,7 +674,7 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport_float4 reinterpret_cast(comm.clear_buf)[idx] = clear_vec; } - comm.update(IsQK ? (2 * tot_groups * 8 * NRanks) : (tot_groups * 8 * NRanks)); + comm.update((2 * tot_groups * TokenPerBlock * sizeof(float) / sizeof(DType) * NRanks)); #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) if constexpr (TriggerCompletionAtEnd) { @@ -645,7 +727,6 @@ void minimax_reduce_rms_kernel_launcher(MiniMaxReduceRMSParams const& params) attribute[1].val.clusterDim.z = 1; cfg.attrs = attribute; cfg.numAttrs = SM >= 90 ? 2 : 0; - bool trigger_completion_at_end = params.trigger_completion_at_end; if (trigger_completion_at_end) { @@ -657,7 +738,7 @@ void minimax_reduce_rms_kernel_launcher(MiniMaxReduceRMSParams const& params) } } -template +template void minimax_reduce_rms_kernel_launcher_float4(MiniMaxReduceRMSParams const& params) { TLLM_CHECK(params.size_q % params.hidden_dim == 0); @@ -708,26 +789,26 @@ void minimax_reduce_rms_kernel_launcher_float4(MiniMaxReduceRMSParams const& par { if (is_qk) { - TLLM_CUDA_CHECK( - cudaLaunchKernelEx(&cfg, minimax_reduce_rms_kernel_lamport_float4, params)); + TLLM_CUDA_CHECK(cudaLaunchKernelEx(&cfg, + minimax_reduce_qk_rms_kernel_lamport_float4, params)); } else { - TLLM_CUDA_CHECK( - cudaLaunchKernelEx(&cfg, minimax_reduce_rms_kernel_lamport_float4, params)); + TLLM_CUDA_CHECK(cudaLaunchKernelEx(&cfg, + minimax_reduce_qk_rms_kernel_lamport_float4, params)); } } else { if (is_qk) { - TLLM_CUDA_CHECK( - cudaLaunchKernelEx(&cfg, minimax_reduce_rms_kernel_lamport_float4, params)); + TLLM_CUDA_CHECK(cudaLaunchKernelEx(&cfg, + minimax_reduce_qk_rms_kernel_lamport_float4, params)); } else { - TLLM_CUDA_CHECK(cudaLaunchKernelEx( - &cfg, minimax_reduce_rms_kernel_lamport_float4, params)); + TLLM_CUDA_CHECK(cudaLaunchKernelEx(&cfg, + minimax_reduce_qk_rms_kernel_lamport_float4, params)); } } } @@ -735,13 +816,14 @@ void minimax_reduce_rms_kernel_launcher_float4(MiniMaxReduceRMSParams const& par template void dispatch_dtype(MiniMaxReduceRMSParams const& params) { - bool use_float4 = true; + bool use_float4 = (params.allreduce_in_k != nullptr) && (params.hidden_dim * params.nranks == 6144) + && (params.hidden_dim_k * params.nranks == 1024); if (params.dtype == nvinfer1::DataType::kHALF) { if (use_float4) { - minimax_reduce_rms_kernel_launcher_float4(params); + minimax_reduce_rms_kernel_launcher_float4(params); } else { @@ -752,7 +834,7 @@ void dispatch_dtype(MiniMaxReduceRMSParams const& params) { if (use_float4) { - minimax_reduce_rms_kernel_launcher_float4<__nv_bfloat16, NRanks>(params); + minimax_reduce_rms_kernel_launcher_float4<__nv_bfloat16, NRanks, 6144, 1024>(params); } else { @@ -763,7 +845,7 @@ void dispatch_dtype(MiniMaxReduceRMSParams const& params) { if (use_float4) { - minimax_reduce_rms_kernel_launcher_float4(params); + minimax_reduce_rms_kernel_launcher_float4(params); } else { diff --git a/cpp/tensorrt_llm/thop/allreduceOp.cpp b/cpp/tensorrt_llm/thop/allreduceOp.cpp index ddedfb6b30cd..581656eb54c5 100644 --- a/cpp/tensorrt_llm/thop/allreduceOp.cpp +++ b/cpp/tensorrt_llm/thop/allreduceOp.cpp @@ -1843,6 +1843,7 @@ torch::Tensor minimax_allreduce_rms(torch::Tensor const& input, torch::Tensor co torch::Tensor rms_norm_out = torch::empty_like(input); allreduce_params.rms_norm_out = rms_norm_out.mutable_data_ptr(); + allreduce_params.trigger_completion_at_end = trigger_completion_at_end_; tensorrt_llm::kernels::minimax_ar::minimax_reduce_rms_op(allreduce_params); diff --git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py index b7d84fc1e03d..892a916f7fb5 100644 --- a/tensorrt_llm/_torch/distributed/ops.py +++ b/tensorrt_llm/_torch/distributed/ops.py @@ -1193,7 +1193,7 @@ def forward(self, input: torch.Tensor, rms_weights: torch.Tensor, self.workspace, self.mapping.tp_rank, self.mapping.tp_size, eps, - False) + True) def forward_qk(self, q: torch.Tensor, k: torch.Tensor, rms_weights_q: torch.Tensor, rms_weights_k: torch.Tensor, @@ -1208,6 +1208,6 @@ def forward_qk(self, q: torch.Tensor, k: torch.Tensor, self.mapping.tp_rank, self.mapping.tp_size, eps, - False, + True, ) return (out_list[0], out_list[1]) diff --git a/tests/microbenchmarks/minimax_all_reduce.py b/tests/microbenchmarks/minimax_all_reduce.py index 7c7e86ee5768..21c4c8739a03 100644 --- a/tests/microbenchmarks/minimax_all_reduce.py +++ b/tests/microbenchmarks/minimax_all_reduce.py @@ -29,78 +29,59 @@ import tensorrt_llm as tllm from tensorrt_llm import Mapping from tensorrt_llm._torch.distributed import MiniMaxAllReduceRMS -from tensorrt_llm._utils import local_mpi_rank, local_mpi_size, nvtx_range -from tensorrt_llm.bindings.internal.runtime import delay_kernel +from tensorrt_llm._utils import local_mpi_rank, local_mpi_size, mpi_barrier from tensorrt_llm.logger import logger from tensorrt_llm.plugin.plugin import CustomAllReduceHelper # MiniMax all-reduce only uses D (hidden_size) 128 and 1536 in practice. -ALLOWED_HIDDEN_SIZES = (128, 1536) +ALLOWED_HIDDEN_SIZES = (256, 1536) # Q+K fused API benchmark dimensions QK_Q_DIM = 1536 -QK_K_DIM = 128 +QK_K_DIM = 256 def profile_minimax_allreduce_rms( mapping: Mapping, op: MiniMaxAllReduceRMS, - enable_cudagraph: bool = False, - inner_loop: int = 200, - outer_loop: int = 10, + warmup: int = 10, + iters: int = 100, + inner_loop: int = 8, input_tensor=None, norm_weight=None, eps: float = 1e-5, ): - def func(loop_num=inner_loop): - out = None - for _ in range(loop_num): - out = op(input_tensor, norm_weight, eps) - return out - - start = [torch.cuda.Event(enable_timing=True) for _ in range(outer_loop)] - stop = [torch.cuda.Event(enable_timing=True) for _ in range(outer_loop)] - graph = torch.cuda.CUDAGraph() - - stream = torch.cuda.Stream() - with ( - torch.cuda.stream(stream), - nvtx_range(f"minimax_allreduce_rms: shape={input_tensor.size(0)}x{input_tensor.size(1)}"), - ): - func(loop_num=1) - - if enable_cudagraph: - for i in range(2): - func(loop_num=1) - with torch.cuda.graph(graph, stream=stream): - _ = func() - - delay_kernel(20000, stream) + def func(): + for _ in range(inner_loop): + op(input_tensor, norm_weight, eps) - torch.cuda.synchronize() - torch.cuda.profiler.start() - - for i in range(outer_loop): - start[i].record(stream) - if enable_cudagraph: - graph.replay() - else: - _ = func() - stop[i].record(stream) + for _ in range(warmup): + for i in range(inner_loop): + op(input_tensor, norm_weight, eps) + torch.cuda.synchronize() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + func() + + graph.replay() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + graph.replay() + start.record() + for _ in range(iters): + graph.replay() + end.record() torch.cuda.synchronize() - torch.cuda.profiler.stop() - runtimes = [start[i].elapsed_time(stop[i]) for i in range(outer_loop)] - median_ms = sorted(runtimes)[len(runtimes) // 2] / inner_loop - return median_ms + return start.elapsed_time(end) * 1000.0 / (iters * inner_loop) def profile_minimax_allreduce_rms_qk( mapping: Mapping, op: MiniMaxAllReduceRMS, - enable_cudagraph: bool = False, - inner_loop: int = 200, - outer_loop: int = 10, + warmup: int = 10, + iters: int = 100, + inner_loop: int = 8, q_tensor=None, k_tensor=None, norm_weight_q=None, @@ -109,56 +90,37 @@ def profile_minimax_allreduce_rms_qk( ): """Profile the fused Q+K minimax allreduce RMS API (forward_qk).""" - def func(loop_num=inner_loop): - out_q, out_k = None, None - for _ in range(loop_num): - out_q, out_k = op.forward_qk(q_tensor, k_tensor, norm_weight_q, norm_weight_k, eps) - return (out_q, out_k) + def func(): + for _ in range(inner_loop): + op.forward_qk(q_tensor, k_tensor, norm_weight_q, norm_weight_k, eps) - start = [torch.cuda.Event(enable_timing=True) for _ in range(outer_loop)] - stop = [torch.cuda.Event(enable_timing=True) for _ in range(outer_loop)] - graph = torch.cuda.CUDAGraph() - - stream = torch.cuda.Stream() - n_tok, q_d, k_d = q_tensor.size(0), q_tensor.size(1), k_tensor.size(1) - with ( - torch.cuda.stream(stream), - nvtx_range(f"minimax_allreduce_rms_qk: shape={n_tok}x{q_d}+{n_tok}x{k_d}"), - ): - func(loop_num=1) - - if enable_cudagraph: - for i in range(2): - func(loop_num=1) - with torch.cuda.graph(graph, stream=stream): - _ = func() - - delay_kernel(20000, stream) - - torch.cuda.synchronize() - torch.cuda.profiler.start() - - for i in range(outer_loop): - start[i].record(stream) - if enable_cudagraph: - graph.replay() - else: - _ = func() - stop[i].record(stream) + for _ in range(warmup): + func() + torch.cuda.synchronize() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + func() + + graph.replay() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + graph.replay() + start.record() + for _ in range(iters): + graph.replay() + end.record() torch.cuda.synchronize() - torch.cuda.profiler.stop() - runtimes = [start[i].elapsed_time(stop[i]) for i in range(outer_loop)] - median_ms = sorted(runtimes)[len(runtimes) // 2] / inner_loop - return median_ms + return start.elapsed_time(end) * 1000.0 / (iters * inner_loop) def minimax_allreduce_benchmark( dtype: str = "bfloat16", test_range: str = "256,256000000,10", - enable_cudagraph: bool = False, explore_2d: bool = False, save_csv: str = None, + warmup: int = 10, + iters: int = 100, ): world_size = tllm.mpi_world_size() rank = tllm.mpi_rank() @@ -176,8 +138,7 @@ def minimax_allreduce_benchmark( torch_dtype = tllm._utils.str_dtype_to_torch(dtype) - inner_loop = 200 - outer_loop = 10 + inner_loop = 8 eps = 1e-5 shape_list = [] @@ -213,12 +174,13 @@ def minimax_allreduce_benchmark( input_tensor = torch.ones((num_tokens, hidden_size), dtype=torch_dtype, device="cuda") norm_weight = torch.randn((hidden_size,), dtype=torch_dtype, device="cuda") - median_ms = profile_minimax_allreduce_rms( + mpi_barrier() + median_us = profile_minimax_allreduce_rms( mapping=mapping, op=op, - enable_cudagraph=enable_cudagraph, + warmup=warmup, + iters=iters, inner_loop=inner_loop, - outer_loop=outer_loop, input_tensor=input_tensor, norm_weight=norm_weight, eps=eps, @@ -238,15 +200,12 @@ def minimax_allreduce_benchmark( "hidden_size": [hidden_size], "q_dim": [pd.NA], "k_dim": [pd.NA], - "time (us)": [median_ms * 1000], + "time (us)": [median_us], } ), ] ) - print( - f"num_tokens: {num_tokens}, hidden_size: {hidden_size}, " - f"time (us): {median_ms * 1000}" - ) + print(f"num_tokens: {num_tokens}, hidden_size: {hidden_size}, time (us): {median_us}") # Q+K fused API benchmark: q_dim=1536, k_dim=128 num_tokens_qk = sorted({n for n, _ in shape_list}) @@ -261,12 +220,13 @@ def minimax_allreduce_benchmark( if message_size_bytes_qk > max_workspace: continue - median_ms_qk = profile_minimax_allreduce_rms_qk( + mpi_barrier() + median_us_qk = profile_minimax_allreduce_rms_qk( mapping=mapping, op=op, - enable_cudagraph=enable_cudagraph, + warmup=warmup, + iters=iters, inner_loop=inner_loop, - outer_loop=outer_loop, q_tensor=q_tensor, k_tensor=k_tensor, norm_weight_q=norm_weight_q, @@ -288,14 +248,14 @@ def minimax_allreduce_benchmark( "hidden_size": [pd.NA], "q_dim": [QK_Q_DIM], "k_dim": [QK_K_DIM], - "time (us)": [median_ms_qk * 1000], + "time (us)": [median_us_qk], } ), ] ) print( f"qk: num_tokens: {num_tokens}, q_dim: {QK_Q_DIM}, k_dim: {QK_K_DIM}, " - f"time (us): {median_ms_qk * 1000}" + f"time (us): {median_us_qk}" ) if mapping.rank == 0: @@ -321,15 +281,17 @@ def minimax_allreduce_benchmark( help="min_size,max_size,multiplicative_ratio", ) parser.add_argument("--explore_2d", action="store_true", default=False) - parser.add_argument("--enable_cudagraph", action="store_true") parser.add_argument("--save_csv", type=str, default=None) + parser.add_argument("--warmup", type=int, default=10) + parser.add_argument("--iters", type=int, default=100) args = parser.parse_args() minimax_allreduce_benchmark( args.dtype, args.range, - args.enable_cudagraph, args.explore_2d, args.save_csv, + args.warmup, + args.iters, ) From 212c46ad54a81e49792a9e0051405c30b991e67f Mon Sep 17 00:00:00 2001 From: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> Date: Tue, 7 Apr 2026 15:20:47 +0800 Subject: [PATCH 15/20] chore: modify rms norm weight loading method Signed-off-by: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> --- tensorrt_llm/_torch/models/modeling_minimaxm2.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_minimaxm2.py b/tensorrt_llm/_torch/models/modeling_minimaxm2.py index b496febacdfb..ef107c1dafa4 100644 --- a/tensorrt_llm/_torch/models/modeling_minimaxm2.py +++ b/tensorrt_llm/_torch/models/modeling_minimaxm2.py @@ -30,7 +30,7 @@ from ..modules.decoder_layer import DecoderLayer from ..modules.embedding import Embedding from ..modules.fused_moe import MiniMaxM2MoeRoutingMethod, create_moe -from ..modules.linear import Linear +from ..modules.linear import Linear, TensorParallelMode, copy_weight, load_weight_shard from ..modules.rms_norm import RMSNorm from ..utils import AuxStreamType from .modeling_utils import DecoderModel, DecoderModelForCausalLM, register_auto_model @@ -132,13 +132,15 @@ def __init__( self.minimax_all_reduce_rms = MiniMaxAllReduceRMS(mapping=self.mapping) - # TODO: add load weights method - def load_weights(self, weights: Dict): + def load_weights(self, weights: List[Dict]): assert len(weights) == 1 - slice_width = self.hidden_size - slice_start = self.mapping.tp_rank * slice_width - slice_end = slice_start + slice_width - self.weight.copy_(weights[0]["weight"][slice_start:slice_end].to(self.weight.dtype)) + weight = load_weight_shard( + weights[0]["weight"], + tensor_parallel_size=self.mapping.tp_size, + tensor_parallel_rank=self.mapping.tp_rank, + tensor_parallel_mode=TensorParallelMode.COLUMN, + ) + copy_weight(self.weight, weight) def forward(self, hidden_states: torch.Tensor): """ From 76db2c33a2f1155431c5fb5a3c112ee28a3b8521 Mon Sep 17 00:00:00 2001 From: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> Date: Wed, 8 Apr 2026 18:31:22 +0800 Subject: [PATCH 16/20] fix: fix split tensor contiguous bug Signed-off-by: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> --- .../MiniMaxReduceRMSKernel.cu | 4 +- .../_torch/models/modeling_minimaxm2.py | 2 + .../_torch/multi_gpu/test_allreduce.py | 40 ++++++++++++++----- 3 files changed, 33 insertions(+), 13 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu index 78697f5cc74c..4f51ad6267ca 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu +++ b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu @@ -334,7 +334,7 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport(MiniMa { sum_variance += vals_all_ranks[r]; } - sum_variance = sqrtf(sum_variance / NRanks / static_cast(params.hidden_dim) + params.rms_eps); + sum_variance = rsqrtf(sum_variance / NRanks / static_cast(params.hidden_dim) + params.rms_eps); shared_vars_all_ranks = sum_variance; } @@ -364,7 +364,7 @@ __global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport(MiniMa // Clear comm buffer that previous kernel used reinterpret_cast(comm.clear_buf)[idx] = clear_vec; } - comm.update(params.size_q * NRanks); + comm.update(tot_tokens * NRanks * sizeof(float) / sizeof(DType)); #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) if constexpr (TriggerCompletionAtEnd) { diff --git a/tensorrt_llm/_torch/models/modeling_minimaxm2.py b/tensorrt_llm/_torch/models/modeling_minimaxm2.py index ef107c1dafa4..a51edce83291 100644 --- a/tensorrt_llm/_torch/models/modeling_minimaxm2.py +++ b/tensorrt_llm/_torch/models/modeling_minimaxm2.py @@ -220,6 +220,8 @@ def __init__( def apply_qk_norm(self, q, k): if self.qkv_proj.mapping.tp_size > 1: + q = q.contiguous() + k = k.contiguous() q, k = self.q_norm.minimax_all_reduce_rms.forward_qk( q, k, self.q_norm.weight, self.k_norm.weight, self.q_norm.eps ) diff --git a/tests/unittest/_torch/multi_gpu/test_allreduce.py b/tests/unittest/_torch/multi_gpu/test_allreduce.py index 91736d471f9a..7b6f72e13553 100644 --- a/tests/unittest/_torch/multi_gpu/test_allreduce.py +++ b/tests/unittest/_torch/multi_gpu/test_allreduce.py @@ -780,12 +780,10 @@ def test_minimax_allreduce_rms(mpi_pool_executor): @torch.inference_mode() -def run_minimax_allreduce_rms_qk_op(q_input: torch.Tensor, - k_input: torch.Tensor, - tensor_parallel_size: int, - tensor_parallel_rank: int, - rms_weights_q: torch.Tensor, - rms_weights_k: torch.Tensor, eps: float): +def run_minimax_allreduce_rms_qk_op( + q_input: torch.Tensor, k_input: torch.Tensor, tensor_parallel_size: int, + tensor_parallel_rank: int, rms_weights_q: torch.Tensor, + rms_weights_k: torch.Tensor, eps: float, non_contiguous_input: bool): torch.manual_seed(42) num_tokens = q_input.shape[0] @@ -828,6 +826,22 @@ def run_minimax_allreduce_rms_qk_op(q_input: torch.Tensor, k_input = k_input.reshape(num_tokens, tensor_parallel_size, -1) rank_q_input = q_input[:, tensor_parallel_rank, :].contiguous() rank_k_input = k_input[:, tensor_parallel_rank, :].contiguous() + if non_contiguous_input: + # Mimic the integration path where q and k are views split from + # the local qkv shard. + rank_v_input = torch.zeros_like(rank_k_input) + rank_qkv_input = torch.cat([rank_q_input, rank_k_input, rank_v_input], + dim=-1) + rank_q_input, rank_k_input, _ = rank_qkv_input.split( + [ + rank_q_input.shape[-1], + rank_k_input.shape[-1], + rank_v_input.shape[-1], + ], + dim=-1, + ) + assert not rank_q_input.is_contiguous() + assert not rank_k_input.is_contiguous() # rms weights should be sliced by rank rms_weights_q = rms_weights_q.reshape(tensor_parallel_size, -1) @@ -866,20 +880,24 @@ def run_minimax_allreduce_rms_qk_op(q_input: torch.Tensor, def run_minimax_allreduce_rms_qk_single_rank(tensor_parallel_size, single_rank_forward_func, q_input, k_input, rms_weights_q, - rms_weights_k, eps): + rms_weights_k, eps, + non_contiguous_input): rank = tensorrt_llm.mpi_rank() torch.cuda.set_device(rank) try: single_rank_forward_func(q_input, k_input, tensor_parallel_size, rank, - rms_weights_q, rms_weights_k, eps) + rms_weights_q, rms_weights_k, eps, + non_contiguous_input) except Exception: traceback.print_exc() raise return True +@pytest.mark.parametrize("non_contiguous_input", [False, True], + ids=["contiguous", "split_qkv_view"]) @pytest.mark.parametrize("mpi_pool_executor", [4], indirect=True) -def test_minimax_allreduce_rms_qk(mpi_pool_executor): +def test_minimax_allreduce_rms_qk(mpi_pool_executor, non_contiguous_input): torch.manual_seed(42) seq_len = 1024 @@ -897,8 +915,8 @@ def test_minimax_allreduce_rms_qk(mpi_pool_executor): results = mpi_pool_executor.map( run_minimax_allreduce_rms_qk_single_rank, *zip(*[(tensor_parallel_size, run_minimax_allreduce_rms_qk_op, q_input, - k_input, rms_weights_q, rms_weights_k, eps)] * - tensor_parallel_size), + k_input, rms_weights_q, rms_weights_k, eps, + non_contiguous_input)] * tensor_parallel_size), ) for r in results: assert r is True From 5049a8e91ec756a742b645b18ca33483c447a3ce Mon Sep 17 00:00:00 2001 From: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> Date: Thu, 9 Apr 2026 15:49:29 +0800 Subject: [PATCH 17/20] chore: add more torch check for input shape Signed-off-by: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> --- cpp/tensorrt_llm/thop/allreduceOp.cpp | 20 +++++++++++++++++++ .../_torch/models/modeling_minimaxm2.py | 1 + .../_torch/multi_gpu/test_allreduce.py | 3 +++ 3 files changed, 24 insertions(+) diff --git a/cpp/tensorrt_llm/thop/allreduceOp.cpp b/cpp/tensorrt_llm/thop/allreduceOp.cpp index 581656eb54c5..686d6da7441d 100644 --- a/cpp/tensorrt_llm/thop/allreduceOp.cpp +++ b/cpp/tensorrt_llm/thop/allreduceOp.cpp @@ -1827,6 +1827,13 @@ torch::Tensor minimax_allreduce_rms(torch::Tensor const& input, torch::Tensor co torch::Tensor workspace, int64_t const rank, int64_t const nranks, double const eps, bool const trigger_completion_at_end_) { + TORCH_CHECK(input.dim() == 2, "minimax_allreduce_rms: input must be 2D"); + TORCH_CHECK(norm_weight.dim() == 1, "minimax_allreduce_rms: norm_weight must be 1D"); + TORCH_CHECK( + input.size(-1) == norm_weight.size(0), "minimax_allreduce_rms: input hidden dim must match norm_weight"); + TORCH_CHECK(input.is_contiguous(), "minimax_allreduce_rms: input must be contiguous"); + TORCH_CHECK(norm_weight.is_contiguous(), "minimax_allreduce_rms: norm_weight must be contiguous"); + auto allreduce_params = tensorrt_llm::kernels::minimax_ar::MiniMaxReduceRMSParams(); allreduce_params.nranks = static_cast(nranks); @@ -1854,12 +1861,25 @@ std::vector minimax_allreduce_rms_qk(torch::Tensor const& q, torc torch::Tensor const& norm_weight_q, torch::Tensor const& norm_weight_k, torch::Tensor workspace, int64_t const rank, int64_t const nranks, double const eps, bool const trigger_completion_at_end_) { + int64_t constexpr kSupportedGlobalHeadDimQ = 6144; + int64_t constexpr kSupportedGlobalHeadDimK = 1024; + TORCH_CHECK(q.scalar_type() == k.scalar_type(), "minimax_allreduce_rms_qk: q and k must have same dtype"); TORCH_CHECK(q.dim() == 2 && k.dim() == 2, "minimax_allreduce_rms_qk: q and k must be 2D"); TORCH_CHECK(q.size(0) == k.size(0), "minimax_allreduce_rms_qk: q and k must have same num_token"); + TORCH_CHECK(q.is_contiguous(), "minimax_allreduce_rms_qk: q must be contiguous"); + TORCH_CHECK(k.is_contiguous(), "minimax_allreduce_rms_qk: k must be contiguous"); + TORCH_CHECK(norm_weight_q.dim() == 1, "minimax_allreduce_rms_qk: norm_weight_q must be 1D"); + TORCH_CHECK(norm_weight_k.dim() == 1, "minimax_allreduce_rms_qk: norm_weight_k must be 1D"); + TORCH_CHECK(norm_weight_q.is_contiguous(), "minimax_allreduce_rms_qk: norm_weight_q must be contiguous"); + TORCH_CHECK(norm_weight_k.is_contiguous(), "minimax_allreduce_rms_qk: norm_weight_k must be contiguous"); int64_t head_dim_q = q.size(-1); int64_t head_dim_k = k.size(-1); TORCH_CHECK(head_dim_q >= head_dim_k, "minimax_allreduce_rms_qk: head_dim_q must be >= head_dim_k"); + TORCH_CHECK(head_dim_q == norm_weight_q.size(0), "minimax_allreduce_rms_qk: q hidden dim must match norm_weight_q"); + TORCH_CHECK(head_dim_k == norm_weight_k.size(0), "minimax_allreduce_rms_qk: k hidden dim must match norm_weight_k"); + TORCH_CHECK((head_dim_q * nranks) == kSupportedGlobalHeadDimQ && (head_dim_k * nranks) == kSupportedGlobalHeadDimK, + "minimax_allreduce_rms_qk: only global q/k dims 6144/1024 are currently supported"); auto params = tensorrt_llm::kernels::minimax_ar::MiniMaxReduceRMSParams(); params.nranks = static_cast(nranks); diff --git a/tensorrt_llm/_torch/models/modeling_minimaxm2.py b/tensorrt_llm/_torch/models/modeling_minimaxm2.py index a51edce83291..a320cb1fda1e 100644 --- a/tensorrt_llm/_torch/models/modeling_minimaxm2.py +++ b/tensorrt_llm/_torch/models/modeling_minimaxm2.py @@ -155,6 +155,7 @@ def forward(self, hidden_states: torch.Tensor): """ # input_dtype = hidden_states.dtype # hidden_states = hidden_states.to(torch.float32) + hidden_states = hidden_states.contiguous() rms_norm_out = self.minimax_all_reduce_rms(hidden_states, self.weight, self.eps) return rms_norm_out diff --git a/tests/unittest/_torch/multi_gpu/test_allreduce.py b/tests/unittest/_torch/multi_gpu/test_allreduce.py index 7b6f72e13553..e2ae40c8e785 100644 --- a/tests/unittest/_torch/multi_gpu/test_allreduce.py +++ b/tests/unittest/_torch/multi_gpu/test_allreduce.py @@ -842,6 +842,9 @@ def run_minimax_allreduce_rms_qk_op( ) assert not rank_q_input.is_contiguous() assert not rank_k_input.is_contiguous() + # MiniMaxM2Attention materializes the split views before calling the custom op. + rank_q_input = rank_q_input.contiguous() + rank_k_input = rank_k_input.contiguous() # rms weights should be sliced by rank rms_weights_q = rms_weights_q.reshape(tensor_parallel_size, -1) From 87020f2c8d470651eea2bee16013e00f60ad9b7d Mon Sep 17 00:00:00 2001 From: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> Date: Fri, 10 Apr 2026 12:33:08 +0800 Subject: [PATCH 18/20] fix: fix for code rabbit review comments Signed-off-by: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> --- .../MiniMaxReduceRMSKernel.cu | 15 +++++++++++++++ .../MiniMaxReduceRMSKernel.h | 15 +++++++++++++++ cpp/tensorrt_llm/thop/allreduceOp.cpp | 7 ++++++- tensorrt_llm/_torch/distributed/__init__.py | 7 ++++--- .../_torch/multi_gpu/test_allreduce.py | 18 ++++++++++-------- 5 files changed, 50 insertions(+), 12 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu index 4f51ad6267ca..2affc247ac1d 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu +++ b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "tensorrt_llm/common/config.h" #include "tensorrt_llm/common/envUtils.h" #include "tensorrt_llm/common/reduceKernelUtils.cuh" diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.h b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.h index 1b0ddd7feef6..b0cfd0ca074c 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.h +++ b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.h @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once #include "tensorrt_llm/common/assert.h" #include diff --git a/cpp/tensorrt_llm/thop/allreduceOp.cpp b/cpp/tensorrt_llm/thop/allreduceOp.cpp index 686d6da7441d..40d74c0a6fc3 100644 --- a/cpp/tensorrt_llm/thop/allreduceOp.cpp +++ b/cpp/tensorrt_llm/thop/allreduceOp.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & + * SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -1833,6 +1833,7 @@ torch::Tensor minimax_allreduce_rms(torch::Tensor const& input, torch::Tensor co input.size(-1) == norm_weight.size(0), "minimax_allreduce_rms: input hidden dim must match norm_weight"); TORCH_CHECK(input.is_contiguous(), "minimax_allreduce_rms: input must be contiguous"); TORCH_CHECK(norm_weight.is_contiguous(), "minimax_allreduce_rms: norm_weight must be contiguous"); + TORCH_CHECK(norm_weight.scalar_type() == torch::kBFloat16, "minimax_allreduce_rms: norm_weight must be bfloat16"); auto allreduce_params = tensorrt_llm::kernels::minimax_ar::MiniMaxReduceRMSParams(); @@ -1873,6 +1874,10 @@ std::vector minimax_allreduce_rms_qk(torch::Tensor const& q, torc TORCH_CHECK(norm_weight_k.dim() == 1, "minimax_allreduce_rms_qk: norm_weight_k must be 1D"); TORCH_CHECK(norm_weight_q.is_contiguous(), "minimax_allreduce_rms_qk: norm_weight_q must be contiguous"); TORCH_CHECK(norm_weight_k.is_contiguous(), "minimax_allreduce_rms_qk: norm_weight_k must be contiguous"); + TORCH_CHECK( + norm_weight_q.scalar_type() == torch::kBFloat16, "minimax_allreduce_rms_qk: norm_weight_q must be bfloat16"); + TORCH_CHECK( + norm_weight_k.scalar_type() == torch::kBFloat16, "minimax_allreduce_rms_qk: norm_weight_k must be bfloat16"); int64_t head_dim_q = q.size(-1); int64_t head_dim_k = k.size(-1); TORCH_CHECK(head_dim_q >= head_dim_k, "minimax_allreduce_rms_qk: head_dim_q must be >= head_dim_k"); diff --git a/tensorrt_llm/_torch/distributed/__init__.py b/tensorrt_llm/_torch/distributed/__init__.py index 9f42f9dd676c..29564f81ed50 100644 --- a/tensorrt_llm/_torch/distributed/__init__.py +++ b/tensorrt_llm/_torch/distributed/__init__.py @@ -3,9 +3,10 @@ from .communicator import Distributed, MPIDist, TorchDist from .moe_alltoall import MoeAlltoAll from .ops import (AllReduce, AllReduceParams, AllReduceStrategy, - HelixAllToAllNative, MiniMaxAllReduceRMS, MoEAllReduce, MoEAllReduceParams, - all_to_all_4d, all_to_all_5d, allgather, alltoall_helix, - cp_allgather, reducescatter, userbuffers_allreduce_finalize) + HelixAllToAllNative, MiniMaxAllReduceRMS, MoEAllReduce, + MoEAllReduceParams, all_to_all_4d, all_to_all_5d, allgather, + alltoall_helix, cp_allgather, reducescatter, + userbuffers_allreduce_finalize) __all__ = [ "all_to_all_4d", diff --git a/tests/unittest/_torch/multi_gpu/test_allreduce.py b/tests/unittest/_torch/multi_gpu/test_allreduce.py index e2ae40c8e785..4893c15f76a0 100644 --- a/tests/unittest/_torch/multi_gpu/test_allreduce.py +++ b/tests/unittest/_torch/multi_gpu/test_allreduce.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -712,14 +712,14 @@ def run_minimax_allreduce_rms_op(input: torch.Tensor, tensor_parallel_size: int, input = input.cuda() rms_weights = rms_weights.cuda() - input = input.reshape(total_tokens, tensor_parallel_size, - -1).to(torch.float32) - rms_weights = rms_weights.reshape(tensor_parallel_size, - -1).to(torch.float32) - rank_input = input[:, tensor_parallel_rank, :].contiguous() - rank_rms_weights = rms_weights[tensor_parallel_rank, :].contiguous() + rank_input = input.reshape(total_tokens, tensor_parallel_size, + -1)[:, tensor_parallel_rank, :].contiguous() + rank_rms_weights = rms_weights.reshape( + tensor_parallel_size, -1)[tensor_parallel_rank, :].contiguous() + # firstly, calculate the reference output for each rank - ref_output = rms_norm(input, rms_weights, eps) + ref_output = rms_norm(input.to(torch.float32), + rms_weights.to(torch.float32), eps) ref_output = ref_output.reshape(total_tokens, tensor_parallel_size, -1).to(origin_dtype) ref_output = ref_output[:, tensor_parallel_rank, :] @@ -897,6 +897,8 @@ def run_minimax_allreduce_rms_qk_single_rank(tensor_parallel_size, return True +@pytest.mark.skipif(torch.cuda.device_count() != 4, + reason="Requires exactly 4 GPUs for this test") @pytest.mark.parametrize("non_contiguous_input", [False, True], ids=["contiguous", "split_qkv_view"]) @pytest.mark.parametrize("mpi_pool_executor", [4], indirect=True) From 76be6886de05fda1f4a4598708f7c75e5a7e9be4 Mon Sep 17 00:00:00 2001 From: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> Date: Mon, 13 Apr 2026 10:28:24 +0800 Subject: [PATCH 19/20] fix: add fake op impl Signed-off-by: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> --- tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py index f75ee36cd088..9cf817c63201 100644 --- a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py @@ -97,6 +97,16 @@ def _(residual, norm_weight, device_num_experts, scale_input, residual_out = torch.empty_like(residual) return [norm_out, residual_out] + @torch.library.register_fake("trtllm::minimax_allreduce_rms") + def _(input, norm_weight, workspace, rank, nranks, eps, + trigger_completion_at_end): + return torch.empty_like(input) + + @torch.library.register_fake("trtllm::minimax_allreduce_rms_qk") + def _(q, k, norm_weight_q, norm_weight_k, workspace, rank, nranks, eps, + trigger_completion_at_end): + return [torch.empty_like(q), torch.empty_like(k)] + @torch.library.register_fake("trtllm::allgather") def allgather(input, sizes, group): if sizes is None: From d912d1c7fa8d31a918ca74754b750f81c4b79776 Mon Sep 17 00:00:00 2001 From: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> Date: Thu, 16 Apr 2026 15:31:12 +0800 Subject: [PATCH 20/20] fix: clean code by hyukn comments Signed-off-by: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com> --- .../MiniMaxReduceRMSKernel.cu | 86 ++++++++----------- .../_torch/models/modeling_minimaxm2.py | 12 --- 2 files changed, 34 insertions(+), 64 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu index 2affc247ac1d..5be8b1c2ff78 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu +++ b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu @@ -27,9 +27,9 @@ namespace kernels::minimax_ar namespace { // anonymous namespace -template +constexpr int kMinimaxReduceRmsWarpSize = 32; -#define MINIMAX_REDUCE_RMS_WARP_SIZE 32 +template struct LamportComm { __device__ __forceinline__ LamportComm(void** workspace, int rank) @@ -200,11 +200,11 @@ __device__ __forceinline__ void blockReduceSumRange(T* val, int rangeStart, int template __device__ __forceinline__ void local_warp_reduce_sum(T& value, uint32_t active_mask = 0xffffffffu) { - static_assert(kNumThreads >= 1 && kNumThreads <= MINIMAX_REDUCE_RMS_WARP_SIZE); + static_assert(kNumThreads >= 1 && kNumThreads <= kMinimaxReduceRmsWarpSize); #pragma unroll for (int mask = kNumThreads / 2; mask > 0; mask >>= 1) { - value += __shfl_xor_sync(active_mask, value, mask, MINIMAX_REDUCE_RMS_WARP_SIZE); + value += __shfl_xor_sync(active_mask, value, mask, kMinimaxReduceRmsWarpSize); } } @@ -212,14 +212,14 @@ __device__ __forceinline__ void local_warp_reduce_sum(T& value, uint32_t active_ template __device__ __forceinline__ void local_warp_reduce_sum_array(T* value_ptr, uint32_t active_mask = 0xffffffffu) { - static_assert(kNumThreads >= 1 && kNumThreads <= MINIMAX_REDUCE_RMS_WARP_SIZE); + static_assert(kNumThreads >= 1 && kNumThreads <= kMinimaxReduceRmsWarpSize); #pragma unroll for (int i = 0; i < ArraySize; ++i) { #pragma unroll for (int mask = kNumThreads / 2; mask > 0; mask >>= 1) { - value_ptr[i] += __shfl_xor_sync(active_mask, value_ptr[i], mask, MINIMAX_REDUCE_RMS_WARP_SIZE); + value_ptr[i] += __shfl_xor_sync(active_mask, value_ptr[i], mask, kMinimaxReduceRmsWarpSize); } } } @@ -403,8 +403,8 @@ __global__ void __launch_bounds__(1024) minimax_reduce_qk_rms_kernel_lamport_flo constexpr int RankKDim = OriginKDim / NRanks; constexpr int ThreadsPerRowQ = RankQDim / kElemsPerAccess; constexpr int ThreadsPerRowK = RankKDim / kElemsPerAccess; - constexpr int NumWarpQ = (ThreadsPerRowQ + MINIMAX_REDUCE_RMS_WARP_SIZE - 1) / MINIMAX_REDUCE_RMS_WARP_SIZE; - constexpr int NumWarpK = (ThreadsPerRowK + MINIMAX_REDUCE_RMS_WARP_SIZE - 1) / MINIMAX_REDUCE_RMS_WARP_SIZE; + constexpr int NumWarpQ = (ThreadsPerRowQ + kMinimaxReduceRmsWarpSize - 1) / kMinimaxReduceRmsWarpSize; + constexpr int NumWarpK = (ThreadsPerRowK + kMinimaxReduceRmsWarpSize - 1) / kMinimaxReduceRmsWarpSize; int tot_tokens = params.size_q / RankQDim; int tot_groups = (tot_tokens + TokenPerBlock - 1) / TokenPerBlock; // ceiling: last group may have 1-3 valid rows @@ -422,13 +422,13 @@ __global__ void __launch_bounds__(1024) minimax_reduce_qk_rms_kernel_lamport_flo int access_id_in_token = threadIdx.x; int group_stride = gridDim.x; #endif - bool is_q = (access_id_in_token < NumWarpQ * MINIMAX_REDUCE_RMS_WARP_SIZE); + bool is_q = (access_id_in_token < NumWarpQ * kMinimaxReduceRmsWarpSize); int q_thread_idx = access_id_in_token; - int k_thread_idx = (access_id_in_token - (NumWarpQ * MINIMAX_REDUCE_RMS_WARP_SIZE)); + int k_thread_idx = (access_id_in_token - (NumWarpQ * kMinimaxReduceRmsWarpSize)); bool is_valid_token = is_q ? (access_id_in_token < ThreadsPerRowQ) : (k_thread_idx < ThreadsPerRowK); float4 clear_vec = get_neg_zero(); - __shared__ float block_reduce_sum[TokenPerBlock][MINIMAX_REDUCE_RMS_WARP_SIZE + 1]; // 33 > warpQ + warpK + __shared__ float block_reduce_sum[TokenPerBlock][kMinimaxReduceRmsWarpSize + 1]; // 33 > warpQ + warpK __shared__ float global_scale_q[TokenPerBlock]; __shared__ float global_scale_k[TokenPerBlock]; @@ -443,7 +443,7 @@ __global__ void __launch_bounds__(1024) minimax_reduce_qk_rms_kernel_lamport_flo // first step load rms params scale __nv_bfloat16 norm_weight[kElemsPerAccess]{}; - if (access_id_in_token < NumWarpQ * MINIMAX_REDUCE_RMS_WARP_SIZE) // Q branch + if (access_id_in_token < NumWarpQ * kMinimaxReduceRmsWarpSize) // Q branch { // load rms params scale if (is_valid_token) @@ -516,22 +516,22 @@ __global__ void __launch_bounds__(1024) minimax_reduce_qk_rms_kernel_lamport_flo // Local warp reduce: // here we use all threads to reduce warp_sum_variance - local_warp_reduce_sum_array(warp_sum_variance); + local_warp_reduce_sum_array(warp_sum_variance); // each warp write the warp reduce result to the shared memory - int line = threadIdx.x & (MINIMAX_REDUCE_RMS_WARP_SIZE - 1); + int line = threadIdx.x & (kMinimaxReduceRmsWarpSize - 1); if (line == 0) { #pragma unroll for (int _ = 0; _ < TokenPerBlock; ++_) { - block_reduce_sum[_][threadIdx.x / MINIMAX_REDUCE_RMS_WARP_SIZE] = warp_sum_variance[_]; + block_reduce_sum[_][threadIdx.x / kMinimaxReduceRmsWarpSize] = warp_sum_variance[_]; } } __syncthreads(); int tid = threadIdx.x; // then two warps process q block reduce and k block reduce respectively - if (tid < MINIMAX_REDUCE_RMS_WARP_SIZE) + if (tid < kMinimaxReduceRmsWarpSize) { constexpr int kNumWarpQPow2 = next_pow2(NumWarpQ) > NRanks ? next_pow2(NumWarpQ) : NRanks; float local_sum[TokenPerBlock]; @@ -578,8 +578,8 @@ __global__ void __launch_bounds__(1024) minimax_reduce_qk_rms_kernel_lamport_flo } } // k branch - else if (threadIdx.x >= MINIMAX_REDUCE_RMS_WARP_SIZE * NumWarpQ - && threadIdx.x < MINIMAX_REDUCE_RMS_WARP_SIZE * (NumWarpQ + 1)) + else if (threadIdx.x >= kMinimaxReduceRmsWarpSize * NumWarpQ + && threadIdx.x < kMinimaxReduceRmsWarpSize * (NumWarpQ + 1)) { constexpr int kNumWarpKPow2 = next_pow2(NumWarpK) > NRanks ? next_pow2(NumWarpK) : NRanks; float local_sum[TokenPerBlock]; @@ -700,16 +700,15 @@ __global__ void __launch_bounds__(1024) minimax_reduce_qk_rms_kernel_lamport_flo int get_sm_count() { - static int sm_count = 0; - if (sm_count == 0) - { - int device_id; - TLLM_CUDA_CHECK(cudaGetDevice(&device_id)); - cudaDeviceProp device_prop; - cudaGetDeviceProperties(&device_prop, device_id); - sm_count = device_prop.multiProcessorCount; - } - return sm_count; + static int const smCount = []() + { + int deviceId; + TLLM_CUDA_CHECK(cudaGetDevice(&deviceId)); + cudaDeviceProp deviceProp; + TLLM_CUDA_CHECK(cudaGetDeviceProperties(&deviceProp, deviceId)); + return deviceProp.multiProcessorCount; + }(); + return smCount; } template @@ -777,9 +776,9 @@ void minimax_reduce_rms_kernel_launcher_float4(MiniMaxReduceRMSParams const& par int cluster_num = tot_groups; int access_per_row_q = params.hidden_dim / kElemsPerAccess; int access_per_row_k = (params.allreduce_in_k != nullptr) ? (params.hidden_dim_k / kElemsPerAccess) : 0; - auto divUp = [](int a, int b) { return (a + b - 1) / b * b; }; // round up to the nearest multiple of b - int block_size = divUp(access_per_row_q, MINIMAX_REDUCE_RMS_WARP_SIZE) - + ((params.allreduce_in_k != nullptr) ? divUp(access_per_row_k, MINIMAX_REDUCE_RMS_WARP_SIZE) : 0); + auto const divUp = [](int a, int b) { return (a + b - 1) / b * b; }; // round up to the nearest multiple of b + int block_size = divUp(access_per_row_q, kMinimaxReduceRmsWarpSize) + + ((params.allreduce_in_k != nullptr) ? divUp(access_per_row_k, kMinimaxReduceRmsWarpSize) : 0); int grid_size = (std::min(sm_count, cluster_num * cluster_size) / cluster_size) * cluster_size; cudaLaunchConfig_t cfg; @@ -799,32 +798,15 @@ void minimax_reduce_rms_kernel_launcher_float4(MiniMaxReduceRMSParams const& par cfg.numAttrs = SM >= 90 ? 2 : 0; bool trigger_completion_at_end = params.trigger_completion_at_end; - bool is_qk = (params.allreduce_in_k != nullptr); if (trigger_completion_at_end) { - if (is_qk) - { - TLLM_CUDA_CHECK(cudaLaunchKernelEx(&cfg, - minimax_reduce_qk_rms_kernel_lamport_float4, params)); - } - else - { - TLLM_CUDA_CHECK(cudaLaunchKernelEx(&cfg, - minimax_reduce_qk_rms_kernel_lamport_float4, params)); - } + TLLM_CUDA_CHECK(cudaLaunchKernelEx( + &cfg, minimax_reduce_qk_rms_kernel_lamport_float4, params)); } else { - if (is_qk) - { - TLLM_CUDA_CHECK(cudaLaunchKernelEx(&cfg, - minimax_reduce_qk_rms_kernel_lamport_float4, params)); - } - else - { - TLLM_CUDA_CHECK(cudaLaunchKernelEx(&cfg, - minimax_reduce_qk_rms_kernel_lamport_float4, params)); - } + TLLM_CUDA_CHECK(cudaLaunchKernelEx(&cfg, + minimax_reduce_qk_rms_kernel_lamport_float4, params)); } } diff --git a/tensorrt_llm/_torch/models/modeling_minimaxm2.py b/tensorrt_llm/_torch/models/modeling_minimaxm2.py index a320cb1fda1e..944f20ec77f3 100644 --- a/tensorrt_llm/_torch/models/modeling_minimaxm2.py +++ b/tensorrt_llm/_torch/models/modeling_minimaxm2.py @@ -143,18 +143,6 @@ def load_weights(self, weights: List[Dict]): copy_weight(self.weight, weight) def forward(self, hidden_states: torch.Tensor): - """ - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - - variance = hidden_states.pow(2).mean(-1, keepdim=True) / self.mapping.tp_size - variance = self.all_reduce(variance) - - hidden_states = hidden_states * torch.rsqrt(variance + self.eps) - hidden_states = self.weight * hidden_states.to(input_dtype) - """ - # input_dtype = hidden_states.dtype - # hidden_states = hidden_states.to(torch.float32) hidden_states = hidden_states.contiguous() rms_norm_out = self.minimax_all_reduce_rms(hidden_states, self.weight, self.eps) return rms_norm_out