diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu new file mode 100644 index 000000000000..5be8b1c2ff78 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.cu @@ -0,0 +1,885 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tensorrt_llm/common/config.h" +#include "tensorrt_llm/common/envUtils.h" +#include "tensorrt_llm/common/reduceKernelUtils.cuh" +#include "tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.h" +#include "tensorrt_llm/kernels/quantization.cuh" +#include + +TRTLLM_NAMESPACE_BEGIN + +namespace kernels::minimax_ar +{ +namespace +{ // anonymous namespace + +constexpr int kMinimaxReduceRmsWarpSize = 32; + +template +struct LamportComm +{ + __device__ __forceinline__ LamportComm(void** workspace, int rank) + { + counter_ptr = &reinterpret_cast(workspace[NRanks * 3])[0]; + flag_ptr = &reinterpret_cast(workspace[NRanks * 3])[2]; + clear_ptr = &reinterpret_cast(workspace[NRanks * 3 + 1])[0]; + flag_value = *flag_ptr; + auto comm_size = reinterpret_cast(workspace[NRanks * 3 + 1])[1]; + clear_size = *clear_ptr; + int data_offset = flag_value % 3; + int clear_offset = (flag_value + 2) % 3; + for (int r = 0; r < NRanks; ++r) + { + data_bufs[r] = reinterpret_cast(workspace[2 * NRanks + r]) + data_offset * comm_size; + } + clear_buf = reinterpret_cast(workspace[2 * NRanks + rank]) + clear_offset * comm_size; + __syncthreads(); + if (threadIdx.x == 0) + { + atomicAdd(counter_ptr, 1); + } + } + + __device__ __forceinline__ void update(int64_t new_clear_size) + { + if (blockIdx.x == 0 && threadIdx.x == 0) + { + while (*reinterpret_cast(counter_ptr) != gridDim.x) + { + } + *flag_ptr = (flag_value + 1) % 3; + *clear_ptr = new_clear_size; + *counter_ptr = 0; + } + } + + int* counter_ptr; + int* flag_ptr; + int64_t* clear_ptr; + uint8_t* data_bufs[NRanks]; + uint8_t* clear_buf; + int64_t clear_size; + int flag_value; +}; + +__device__ __forceinline__ bool is_neg_zero(float v) +{ + return *reinterpret_cast(&v) == 0x80000000; +} + +__device__ __forceinline__ bool is_neg_zero(float4 v) +{ + return is_neg_zero(v.x) || is_neg_zero(v.y) || is_neg_zero(v.z) || is_neg_zero(v.w); +} + +__device__ __forceinline__ float4 get_neg_zero() +{ + float4 vec; +#pragma unroll + for (int i = 0; i < 4; ++i) + { + reinterpret_cast(&vec)[i] = 0x80000000; + } + return vec; +} + +template +__device__ __forceinline__ float rms_rsqrt(float& v, float eps) +{ + constexpr float kInvDim = 1.0F / static_cast(Dim); + v = rsqrtf((v * kInvDim) + eps); + return v; +} + +template +__device__ __forceinline__ float4 rms_rsqrt(float4& v, float eps) +{ + constexpr float kInvDim = 1.0F / static_cast(Dim); + v.x = rsqrtf((v.x * kInvDim) + eps); + v.y = rsqrtf((v.y * kInvDim) + eps); + v.z = rsqrtf((v.z * kInvDim) + eps); + v.w = rsqrtf((v.w * kInvDim) + eps); + return v; +} + +__device__ __forceinline__ float4 ld_global_volatile(float4* addr) +{ + float4 val; + asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];" + : "=f"(val.x), "=f"(val.y), "=f"(val.z), "=f"(val.w) + : "l"(addr)); + return val; +} + +__device__ __forceinline__ float ld_global_volatile(float* addr) +{ + float val; + asm volatile("ld.volatile.global.f32 %0, [%1];" : "=f"(val) : "l"(addr)); + return val; +} + +template +__device__ __forceinline__ void blockReduceSumRange(T* val, int rangeStart, int rangeEnd) +{ + constexpr int kWarpSize = 32; + constexpr unsigned kFullMask = 0xffffffffu; + static __shared__ T shared[NUM][33]; + + int const activeThreadCount = max(rangeEnd - rangeStart, 0); + bool const isActive = threadIdx.x >= rangeStart && threadIdx.x < rangeEnd; + int const lane = threadIdx.x & (kWarpSize - 1); + unsigned const activeMask = __ballot_sync(kFullMask, isActive); + + if (isActive) + { +#pragma unroll + for (int i = 0; i < NUM; ++i) + { + T sum = val[i]; +#pragma unroll + for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) + { + sum += __shfl_down_sync(activeMask, sum, offset, kWarpSize); + } + val[i] = sum; + } + } + + if (isActive && lane == 0) + { + int const localWarpId = (threadIdx.x - rangeStart) >> 5; +#pragma unroll + for (int i = 0; i < NUM; ++i) + { + shared[i][localWarpId] = val[i]; + } + } + + __syncthreads(); + + int const shiftedTid = threadIdx.x - rangeStart; + int const warpCount = (activeThreadCount + kWarpSize - 1) / kWarpSize; + bool const inLeaderWarp = shiftedTid >= 0 && shiftedTid < kWarpSize; + bool const leaderLaneIsValid = inLeaderWarp && shiftedTid < warpCount; + unsigned const leaderMask = __ballot_sync(kFullMask, leaderLaneIsValid); + + if (inLeaderWarp) + { +#pragma unroll + for (int i = 0; i < NUM; ++i) + { + T sum = leaderLaneIsValid ? shared[i][shiftedTid] : static_cast(0); +#pragma unroll + for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) + { + sum += __shfl_down_sync(leaderMask, sum, offset, kWarpSize); + } + if (threadIdx.x == rangeStart) + { + val[i] = sum; + } + } + } +} + +// from sglang python/sglang/jit_kernel/include/sgl_kernel/warp.cuh +template +__device__ __forceinline__ void local_warp_reduce_sum(T& value, uint32_t active_mask = 0xffffffffu) +{ + static_assert(kNumThreads >= 1 && kNumThreads <= kMinimaxReduceRmsWarpSize); +#pragma unroll + for (int mask = kNumThreads / 2; mask > 0; mask >>= 1) + { + value += __shfl_xor_sync(active_mask, value, mask, kMinimaxReduceRmsWarpSize); + } +} + +// for float4 version +template +__device__ __forceinline__ void local_warp_reduce_sum_array(T* value_ptr, uint32_t active_mask = 0xffffffffu) +{ + static_assert(kNumThreads >= 1 && kNumThreads <= kMinimaxReduceRmsWarpSize); +#pragma unroll + for (int i = 0; i < ArraySize; ++i) + { +#pragma unroll + for (int mask = kNumThreads / 2; mask > 0; mask >>= 1) + { + value_ptr[i] += __shfl_xor_sync(active_mask, value_ptr[i], mask, kMinimaxReduceRmsWarpSize); + } + } +} + +constexpr int next_pow2(int val) +{ + int result = 1; + while (result < val) + { + result <<= 1; + } + return result; +} + +template +class IndexHelper +{ +public: + __device__ __forceinline__ IndexHelper(MiniMaxReduceRMSParams const& params) + { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + namespace cg = cooperative_groups; + cg::cluster_group cluster = cg::this_cluster(); + cg::grid_group grid = cg::this_grid(); + token_id = grid.cluster_rank(); + access_id_in_token = cluster.thread_rank(); + token_stride = grid.num_clusters(); +#else + token_id = blockIdx.x; + access_id_in_token = threadIdx.x; + token_stride = gridDim.x; +#endif + access_id = token_id * params.hidden_dim / kElemsPerAccess + access_id_in_token; + access_stride = token_stride * params.hidden_dim / kElemsPerAccess; + tot_access = params.size_q / kElemsPerAccess; + } + + int token_id; + int access_id_in_token; + int token_stride; + int access_id; + int access_stride; + int tot_access; +}; + +/** +* this kernel is used to for minimax attention module +* input tensor [total_tokens, hidden_dim / tp_size], fp32 +* rms weight [hidden_dim / tp_size], bf16 +step 1: reduce from single rank to get the variance sum (reduce(input^2, dim=-1)) +step 2: reduce from all ranks to get the variance sum (all_reduce(variance_sum)) +step 3: calculate the rms norm (input * rsqrt(variance + eps)) +in this case, max hidden_dim is 6144 (float data), for each token, we only need 6144 / 4 / tp_size = (1536 / tp_size) +threads so we can assume cluster size is 1 (tp_size >= 2) + */ +template +__global__ void __launch_bounds__(1024) minimax_reduce_rms_kernel_lamport(MiniMaxReduceRMSParams params) +{ + IndexHelper index_helper(params); + int token_id = index_helper.token_id; + int access_id_in_token = index_helper.access_id_in_token; + int token_stride = index_helper.token_stride; + int access_id = index_helper.access_id; + int access_stride = index_helper.access_stride; + int tot_access = index_helper.tot_access; + int tot_tokens = params.size_q / params.hidden_dim; + float4 clear_vec = get_neg_zero(); + // FusedOp fused_op(params, access_id, access_id_in_token); + __shared__ float shared_vars_all_ranks; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); + if constexpr (!TriggerCompletionAtEnd) + { + cudaTriggerProgrammaticLaunchCompletion(); + } +#endif + LamportComm comm(params.workspace, params.rank); + int clear_access = comm.clear_size / kElemsPerAccess; + for (int idx = access_id; idx < tot_access; idx += access_stride, token_id += token_stride) + { + alignas(16) DType vals[kElemsPerAccess]; + // we use float to load and store variance sum + float sum_variance = 0.F; + *reinterpret_cast(vals) = reinterpret_cast(params.allreduce_in)[idx]; +#pragma unroll + for (int i = 0; i < kElemsPerAccess; ++i) + { + sum_variance += static_cast(vals[i]) * static_cast(vals[i]); + } + // step 1: reduce from single rank to get the variance sum + tensorrt_llm::common::blockReduceSumV2(&sum_variance); + if (is_neg_zero(sum_variance)) + { + sum_variance = 0.F; + } + // step 2: reduce from all ranks to get the variance sum + // be careful, we only use float to load and store variance sum + // but we use float4 to load input tensor + // Push data to other ranks + // we only need the first thread to push data to other ranks + if (threadIdx.x == 0) + { +#pragma unroll + for (int r = 0; r < NRanks; ++r) + { + // temp data buffer [nranks, total_tokens, 1] + reinterpret_cast(comm.data_bufs[r])[(params.rank * tot_tokens) + token_id] = (sum_variance); + } + // we only use the first thread to pull data from other ranks + bool done = false; + float vals_all_ranks[NRanks]; + while (!done) + { + done = true; +#pragma unroll + for (int r = 0; r < NRanks; ++r) + { + vals_all_ranks[r] = ld_global_volatile( + &reinterpret_cast(comm.data_bufs[params.rank])[(r * tot_tokens) + token_id]); + done &= !is_neg_zero(vals_all_ranks[r]); + } + } + + sum_variance = 0.F; +#pragma unroll + for (int r = 0; r < NRanks; ++r) + { + sum_variance += vals_all_ranks[r]; + } + sum_variance = rsqrtf(sum_variance / NRanks / static_cast(params.hidden_dim) + params.rms_eps); + shared_vars_all_ranks = sum_variance; + } + + __syncthreads(); + sum_variance = shared_vars_all_ranks; + + // step 3: calculate the rms norm (input * rsqrt(variance + eps)) + + // load norm weight + // TODO: correct the access_id_in_token + __nv_bfloat16 norm_weight[kElemsPerAccess]; + *reinterpret_cast::norm_weight_type*>(norm_weight) + = reinterpret_cast::norm_weight_type*>(params.rms_gamma)[access_id_in_token]; + +#pragma unroll + for (int i = 0; i < kElemsPerAccess; ++i) + { + vals[i] + = static_cast(static_cast(vals[i]) * sum_variance * static_cast(norm_weight[i])); + } + + // step 4: store the rms norm + reinterpret_cast(params.rms_norm_out)[idx] = *reinterpret_cast(vals); + } + for (int idx = access_id; idx < clear_access; idx += access_stride) + { + // Clear comm buffer that previous kernel used + reinterpret_cast(comm.clear_buf)[idx] = clear_vec; + } + comm.update(tot_tokens * NRanks * sizeof(float) / sizeof(DType)); +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + if constexpr (TriggerCompletionAtEnd) + { + cudaTriggerProgrammaticLaunchCompletion(); + } +#endif +} + +/** + * Float4 variant: process 4 rows at once, allreduce variance sums as float4 for better memory coalescing. + * sum_variance is always float; applies to all DTypes (half, bf16, float). + * When tot_tokens % 4 != 0, the last group pads rows with zeros; padded rows are not written to rms_norm_out. + * IsQK: when true, process Q+K in one loop with doubled comm buffer; when false, single-matrix (Q only). + */ +template +__global__ void __launch_bounds__(1024) minimax_reduce_qk_rms_kernel_lamport_float4(MiniMaxReduceRMSParams params) +{ + static_assert(TokenPerBlock == 1 || TokenPerBlock == 4, "TokenPerBlock must be 1 or 4"); + constexpr int RankQDim = OriginQDim / NRanks; + constexpr int RankKDim = OriginKDim / NRanks; + constexpr int ThreadsPerRowQ = RankQDim / kElemsPerAccess; + constexpr int ThreadsPerRowK = RankKDim / kElemsPerAccess; + constexpr int NumWarpQ = (ThreadsPerRowQ + kMinimaxReduceRmsWarpSize - 1) / kMinimaxReduceRmsWarpSize; + constexpr int NumWarpK = (ThreadsPerRowK + kMinimaxReduceRmsWarpSize - 1) / kMinimaxReduceRmsWarpSize; + int tot_tokens = params.size_q / RankQDim; + int tot_groups = (tot_tokens + TokenPerBlock - 1) / TokenPerBlock; // ceiling: last group may have 1-3 valid rows + + using AccumType = std::conditional_t; + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + namespace cg = cooperative_groups; + cg::cluster_group cluster = cg::this_cluster(); + cg::grid_group grid = cg::this_grid(); + int group_id = grid.cluster_rank(); + int access_id_in_token = cluster.thread_rank(); + int group_stride = grid.num_clusters(); +#else + int group_id = blockIdx.x; + int access_id_in_token = threadIdx.x; + int group_stride = gridDim.x; +#endif + bool is_q = (access_id_in_token < NumWarpQ * kMinimaxReduceRmsWarpSize); + int q_thread_idx = access_id_in_token; + int k_thread_idx = (access_id_in_token - (NumWarpQ * kMinimaxReduceRmsWarpSize)); + bool is_valid_token = is_q ? (access_id_in_token < ThreadsPerRowQ) : (k_thread_idx < ThreadsPerRowK); + float4 clear_vec = get_neg_zero(); + + __shared__ float block_reduce_sum[TokenPerBlock][kMinimaxReduceRmsWarpSize + 1]; // 33 > warpQ + warpK + __shared__ float global_scale_q[TokenPerBlock]; + __shared__ float global_scale_k[TokenPerBlock]; + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); + if constexpr (!TriggerCompletionAtEnd) + { + cudaTriggerProgrammaticLaunchCompletion(); + } +#endif + LamportComm comm(params.workspace, params.rank); + + // first step load rms params scale + __nv_bfloat16 norm_weight[kElemsPerAccess]{}; + if (access_id_in_token < NumWarpQ * kMinimaxReduceRmsWarpSize) // Q branch + { + // load rms params scale + if (is_valid_token) + { + *reinterpret_cast::norm_weight_type*>(norm_weight) + = reinterpret_cast::norm_weight_type const*>( + params.rms_gamma)[access_id_in_token]; + } + } + else // K branch + { + // load rms params scale + if (is_valid_token) + { + *reinterpret_cast::norm_weight_type*>(norm_weight) + = reinterpret_cast::norm_weight_type const*>( + params.rms_gamma_k)[k_thread_idx]; + } + } + + for (int g = group_id; g < tot_groups; g += group_stride) + { + alignas(16) DType vals[TokenPerBlock][kElemsPerAccess]{}; + float warp_sum_variance[TokenPerBlock]{0.F}; + + if (is_q) + { + // Q branch: each thread only covers 128bit * TokenPerBlock +#pragma unroll + for (int row = 0; row < TokenPerBlock; ++row) + { + int token_r = (g * TokenPerBlock) + row; + if (token_r >= tot_tokens || (!is_valid_token)) + { + continue; + } + int idx_r = (token_r * ThreadsPerRowQ) + access_id_in_token; + *reinterpret_cast(&vals[row][0]) = reinterpret_cast(params.allreduce_in)[idx_r]; +#pragma unroll + for (int i = 0; i < kElemsPerAccess; ++i) + { + auto x = static_cast(vals[row][i]); + warp_sum_variance[row] += x * x; + } + } + } + else // k branch + { +// K branch: k_thread_idx = threadIdx.x - q_warps, each thread covers 32 K columns +#pragma unroll + for (int row = 0; row < TokenPerBlock; ++row) + { + int token_r = (g * TokenPerBlock) + row; + if (token_r >= tot_tokens || (!is_valid_token)) + { + continue; + } + + int idx_r = (token_r * ThreadsPerRowK) + k_thread_idx; + *reinterpret_cast(&vals[row][0]) + = reinterpret_cast(params.allreduce_in_k)[idx_r]; +#pragma unroll + for (int i = 0; i < kElemsPerAccess; ++i) + { + auto x = static_cast(vals[row][i]); + warp_sum_variance[row] += x * x; + } + } + } + + // Local warp reduce: + // here we use all threads to reduce warp_sum_variance + local_warp_reduce_sum_array(warp_sum_variance); + // each warp write the warp reduce result to the shared memory + int line = threadIdx.x & (kMinimaxReduceRmsWarpSize - 1); + if (line == 0) + { +#pragma unroll + for (int _ = 0; _ < TokenPerBlock; ++_) + { + block_reduce_sum[_][threadIdx.x / kMinimaxReduceRmsWarpSize] = warp_sum_variance[_]; + } + } + __syncthreads(); + int tid = threadIdx.x; + // then two warps process q block reduce and k block reduce respectively + + if (tid < kMinimaxReduceRmsWarpSize) + { + constexpr int kNumWarpQPow2 = next_pow2(NumWarpQ) > NRanks ? next_pow2(NumWarpQ) : NRanks; + float local_sum[TokenPerBlock]; +#pragma unroll + for (int _ = 0; _ < TokenPerBlock; ++_) + { + local_sum[_] = tid < NumWarpQ ? block_reduce_sum[_][tid] : 0.F; + } + local_warp_reduce_sum_array(local_sum); + // for thread [0, NRanks), we need to push data to comm buffer + if (tid < NRanks) + { +#pragma unroll + for (int _ = 0; _ < TokenPerBlock; ++_) + { + if (is_neg_zero(local_sum[_])) + { + local_sum[_] = 0.F; + } + } + // push data to comm buffer, for each thread, we only need to push data to one rank + + reinterpret_cast(comm.data_bufs[tid])[(params.rank * tot_groups * 2) + (2 * g)] + = *reinterpret_cast(local_sum); + // pull data from other ranks + bool done = false; + AccumType var_all_ranks; + while (!done) + { + done = true; + var_all_ranks = ld_global_volatile( + &reinterpret_cast(comm.data_bufs[params.rank])[(tid * tot_groups * 2) + (2 * g)]); + done &= !is_neg_zero(var_all_ranks); + } + // local reduce + constexpr uint32_t kActiveMask = (1 << NRanks) - 1; + local_warp_reduce_sum_array( + reinterpret_cast(&var_all_ranks), kActiveMask); + if (tid == 0) + { + *reinterpret_cast(global_scale_q) + = rms_rsqrt(var_all_ranks, params.rms_eps); + } + } + } + // k branch + else if (threadIdx.x >= kMinimaxReduceRmsWarpSize * NumWarpQ + && threadIdx.x < kMinimaxReduceRmsWarpSize * (NumWarpQ + 1)) + { + constexpr int kNumWarpKPow2 = next_pow2(NumWarpK) > NRanks ? next_pow2(NumWarpK) : NRanks; + float local_sum[TokenPerBlock]; +#pragma unroll + for (int _ = 0; _ < TokenPerBlock; ++_) + { + local_sum[_] = k_thread_idx < NumWarpK ? block_reduce_sum[_][NumWarpQ + k_thread_idx] : 0.F; + } + local_warp_reduce_sum_array(local_sum); + // for thread [0, NRanks), we need to push data to comm buffer + if (k_thread_idx < NRanks) + { +#pragma unroll + for (int _ = 0; _ < TokenPerBlock; ++_) + { + if (is_neg_zero(local_sum[_])) + { + local_sum[_] = 0.F; + } + } + // push data to comm buffer, for each thread, we only need to push data to one rank + reinterpret_cast(comm.data_bufs[k_thread_idx])[(params.rank * tot_groups * 2) + (2 * g + 1)] + = *reinterpret_cast(local_sum); + // pull data from other ranks + bool done = false; + AccumType var_all_ranks; + while (!done) + { + done = true; + var_all_ranks = ld_global_volatile(&reinterpret_cast( + comm.data_bufs[params.rank])[(k_thread_idx * tot_groups * 2) + (2 * g + 1)]); + done &= !is_neg_zero(var_all_ranks); + } + // local reduce + constexpr uint32_t kActiveMask = (1 << NRanks) - 1; + local_warp_reduce_sum_array( + reinterpret_cast(&var_all_ranks), kActiveMask); + if (k_thread_idx == 0) + { + *reinterpret_cast(global_scale_k) + = rms_rsqrt(var_all_ranks, params.rms_eps); + } + } + } + __syncthreads(); + // final part + if (is_q) + { +#pragma unroll + for (int _ = 0; _ < TokenPerBlock; ++_) + { + warp_sum_variance[_] = global_scale_q[_]; + } +#pragma unroll + for (int r = 0; r < TokenPerBlock; ++r) + { +#pragma unroll + for (int i = 0; i < kElemsPerAccess; ++i) + { + vals[r][i] = static_cast( + static_cast(vals[r][i]) * warp_sum_variance[r] * static_cast(norm_weight[i])); + } + // store to rms_norm_out + int token_r = (g * TokenPerBlock) + r; + if (token_r >= tot_tokens || (!is_valid_token)) + { + continue; + } + int idx_r = (token_r * ThreadsPerRowQ) + access_id_in_token; + reinterpret_cast(params.rms_norm_out)[idx_r] = *reinterpret_cast(&vals[r][0]); + } + } + else + { +#pragma unroll + for (int _ = 0; _ < TokenPerBlock; ++_) + { + warp_sum_variance[_] = global_scale_k[_]; + } +#pragma unroll + for (int r = 0; r < TokenPerBlock; ++r) + { +#pragma unroll + for (int i = 0; i < kElemsPerAccess; ++i) + { + vals[r][i] = static_cast( + static_cast(vals[r][i]) * warp_sum_variance[r] * static_cast(norm_weight[i])); + } + // store to rms_norm_out + int token_r = (g * TokenPerBlock) + r; + if (token_r >= tot_tokens || (!is_valid_token)) + { + continue; + } + int idx_r = (token_r * ThreadsPerRowK) + k_thread_idx; + reinterpret_cast(params.rms_norm_out_k)[idx_r] = *reinterpret_cast(&vals[r][0]); + } + } + } + + // Clear comm buffer + int clear_access = static_cast(comm.clear_size / (sizeof(float4) / sizeof(DType))); + int clear_stride = group_stride * blockDim.x; + + for (int idx = group_id * blockDim.x + threadIdx.x; idx < clear_access; idx += clear_stride) + { + reinterpret_cast(comm.clear_buf)[idx] = clear_vec; + } + + comm.update((2 * tot_groups * TokenPerBlock * sizeof(float) / sizeof(DType) * NRanks)); +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + if constexpr (TriggerCompletionAtEnd) + { + cudaTriggerProgrammaticLaunchCompletion(); + } +#endif +} + +int get_sm_count() +{ + static int const smCount = []() + { + int deviceId; + TLLM_CUDA_CHECK(cudaGetDevice(&deviceId)); + cudaDeviceProp deviceProp; + TLLM_CUDA_CHECK(cudaGetDeviceProperties(&deviceProp, deviceId)); + return deviceProp.multiProcessorCount; + }(); + return smCount; +} + +template +void minimax_reduce_rms_kernel_launcher(MiniMaxReduceRMSParams const& params) +{ + TLLM_CHECK(params.size_q % params.hidden_dim == 0); + TLLM_CHECK(params.hidden_dim % kElemsPerAccess == 0); + static int SM = tensorrt_llm::common::getSMVersion(); + int token_num = params.size_q / params.hidden_dim; + // for current problem size, we only need one cluster + int sm_count = get_sm_count(); + int cluster_size = 1; + int cluster_num = token_num; + int threads_per_token = params.hidden_dim / kElemsPerAccess; + int block_size = threads_per_token; + int grid_size = (std::min(sm_count, cluster_num * cluster_size) / cluster_size) * cluster_size; + + cudaLaunchConfig_t cfg; + cfg.gridDim = grid_size; + cfg.blockDim = block_size; + cfg.dynamicSmemBytes = 0; + cfg.stream = params.stream; + + cudaLaunchAttribute attribute[2]; + attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attribute[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL() ? 1 : 0; + attribute[1].id = cudaLaunchAttributeClusterDimension; + attribute[1].val.clusterDim.x = cluster_size; + attribute[1].val.clusterDim.y = 1; + attribute[1].val.clusterDim.z = 1; + cfg.attrs = attribute; + cfg.numAttrs = SM >= 90 ? 2 : 0; + bool trigger_completion_at_end = params.trigger_completion_at_end; + if (trigger_completion_at_end) + { + TLLM_CUDA_CHECK(cudaLaunchKernelEx(&cfg, minimax_reduce_rms_kernel_lamport, params)); + } + else + { + TLLM_CUDA_CHECK(cudaLaunchKernelEx(&cfg, minimax_reduce_rms_kernel_lamport, params)); + } +} + +template +void minimax_reduce_rms_kernel_launcher_float4(MiniMaxReduceRMSParams const& params) +{ + TLLM_CHECK(params.size_q % params.hidden_dim == 0); + TLLM_CHECK(params.hidden_dim % kElemsPerAccess == 0); + if (params.allreduce_in_k != nullptr) + { + TLLM_CHECK(params.hidden_dim >= params.hidden_dim_k); + TLLM_CHECK(params.size_k % params.hidden_dim_k == 0); + TLLM_CHECK(params.hidden_dim_k % kElemsPerAccess == 0); + TLLM_CHECK(params.size_q / params.hidden_dim == params.size_k / params.hidden_dim_k); + } + int token_num = params.size_q / params.hidden_dim; + int tot_groups = (token_num + 3) / 4; // ceiling + if (tot_groups == 0) + { + return; + } + static int SM = tensorrt_llm::common::getSMVersion(); + int sm_count = get_sm_count(); + int cluster_size = 1; + int cluster_num = tot_groups; + int access_per_row_q = params.hidden_dim / kElemsPerAccess; + int access_per_row_k = (params.allreduce_in_k != nullptr) ? (params.hidden_dim_k / kElemsPerAccess) : 0; + auto const divUp = [](int a, int b) { return (a + b - 1) / b * b; }; // round up to the nearest multiple of b + int block_size = divUp(access_per_row_q, kMinimaxReduceRmsWarpSize) + + ((params.allreduce_in_k != nullptr) ? divUp(access_per_row_k, kMinimaxReduceRmsWarpSize) : 0); + int grid_size = (std::min(sm_count, cluster_num * cluster_size) / cluster_size) * cluster_size; + + cudaLaunchConfig_t cfg; + cfg.gridDim = grid_size; + cfg.blockDim = block_size; + cfg.dynamicSmemBytes = 0; + cfg.stream = params.stream; + + cudaLaunchAttribute attribute[2]; + attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attribute[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL() ? 1 : 0; + attribute[1].id = cudaLaunchAttributeClusterDimension; + attribute[1].val.clusterDim.x = cluster_size; + attribute[1].val.clusterDim.y = 1; + attribute[1].val.clusterDim.z = 1; + cfg.attrs = attribute; + cfg.numAttrs = SM >= 90 ? 2 : 0; + + bool trigger_completion_at_end = params.trigger_completion_at_end; + if (trigger_completion_at_end) + { + TLLM_CUDA_CHECK(cudaLaunchKernelEx( + &cfg, minimax_reduce_qk_rms_kernel_lamport_float4, params)); + } + else + { + TLLM_CUDA_CHECK(cudaLaunchKernelEx(&cfg, + minimax_reduce_qk_rms_kernel_lamport_float4, params)); + } +} + +template +void dispatch_dtype(MiniMaxReduceRMSParams const& params) +{ + bool use_float4 = (params.allreduce_in_k != nullptr) && (params.hidden_dim * params.nranks == 6144) + && (params.hidden_dim_k * params.nranks == 1024); + + if (params.dtype == nvinfer1::DataType::kHALF) + { + if (use_float4) + { + minimax_reduce_rms_kernel_launcher_float4(params); + } + else + { + minimax_reduce_rms_kernel_launcher(params); + } + } + else if (params.dtype == nvinfer1::DataType::kBF16) + { + if (use_float4) + { + minimax_reduce_rms_kernel_launcher_float4<__nv_bfloat16, NRanks, 6144, 1024>(params); + } + else + { + minimax_reduce_rms_kernel_launcher<__nv_bfloat16, NRanks>(params); + } + } + else if (params.dtype == nvinfer1::DataType::kFLOAT) + { + if (use_float4) + { + minimax_reduce_rms_kernel_launcher_float4(params); + } + else + { + minimax_reduce_rms_kernel_launcher(params); + } + } + else + { + TLLM_CHECK_WITH_INFO(false, "Unsupported data type for minimax_reduce_rms_op"); + } +} +} // namespace + +void minimax_reduce_rms_op(MiniMaxReduceRMSParams const& params) +{ + if (params.nranks == 2) + { + dispatch_dtype<2>(params); + } + else if (params.nranks == 4) + { + dispatch_dtype<4>(params); + } + else if (params.nranks == 8) + { + dispatch_dtype<8>(params); + } + else if (params.nranks == 16) + { + dispatch_dtype<16>(params); + } + else + { + TLLM_CHECK_WITH_INFO(false, "minimax_reduce_rms_op: unsupported ranks number!"); + } +} + +} // namespace kernels::minimax_ar + +TRTLLM_NAMESPACE_END diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.h b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.h new file mode 100644 index 000000000000..b0cfd0ca074c --- /dev/null +++ b/cpp/tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.h @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include "tensorrt_llm/common/assert.h" +#include +#include +#include + +#include "tensorrt_llm/common/config.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/kernels/quantization.h" +#include "tensorrt_llm/runtime/ipcUtils.h" + +TRTLLM_NAMESPACE_BEGIN + +namespace kernels::minimax_ar +{ +template +struct ElemsPerAccess; + +template <> +struct ElemsPerAccess +{ + static constexpr int value = 8; + using norm_weight_type = common::__nv_bfloat168; +}; + +template <> +struct ElemsPerAccess +{ + static constexpr int value = 8; + using norm_weight_type = common::__nv_bfloat168; +}; + +template <> +struct ElemsPerAccess +{ + static constexpr int value = 4; + using norm_weight_type = common::__nv_bfloat164; +}; + +template +static constexpr int kElemsPerAccess = ElemsPerAccess::value; + +struct MiniMaxReduceRMSParams +{ + int nranks{}; + int rank{}; + nvinfer1::DataType dtype; + int size_q{}; // numel of Q (num_token * head_dim_q) + int hidden_dim{}; // head_dim_q + int size_k{}; // numel of K (num_token * head_dim_k) + int hidden_dim_k{}; // head_dim_k; must have head_dim_q >= head_dim_k + void** workspace{}; + void* allreduce_in{}; // Q input + void* rms_norm_out{}; // Q output + void* rms_gamma{}; // Q norm weight + void* allreduce_in_k{}; // K input (nullptr for single-matrix path) + void* rms_norm_out_k{}; // K output + void* rms_gamma_k{}; // K norm weight + float rms_eps{}; + cudaStream_t stream{}; + bool trigger_completion_at_end = true; +}; + +void minimax_reduce_rms_op(MiniMaxReduceRMSParams const& params); + +} // namespace kernels::minimax_ar + +TRTLLM_NAMESPACE_END diff --git a/cpp/tensorrt_llm/thop/allreduceOp.cpp b/cpp/tensorrt_llm/thop/allreduceOp.cpp index e5d875fb52e5..40d74c0a6fc3 100644 --- a/cpp/tensorrt_llm/thop/allreduceOp.cpp +++ b/cpp/tensorrt_llm/thop/allreduceOp.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & + * SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -23,6 +23,7 @@ #include "tensorrt_llm/common/ncclUtils.h" #include "tensorrt_llm/common/nvmlWrapper.h" #include "tensorrt_llm/common/opUtils.h" +#include "tensorrt_llm/kernels/communicationKernels/MiniMaxReduceRMSKernel.h" #include "tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h" #include "tensorrt_llm/kernels/communicationKernels/customLowPrecisionAllReduceKernels.h" #include "tensorrt_llm/kernels/communicationKernels/mnnvlAllreduceKernels.h" @@ -1822,6 +1823,96 @@ std::vector mnnvlFusionAllReduce(torch::Tensor& input, torch::opt return {output, residualOut}; } +torch::Tensor minimax_allreduce_rms(torch::Tensor const& input, torch::Tensor const& norm_weight, + torch::Tensor workspace, int64_t const rank, int64_t const nranks, double const eps, + bool const trigger_completion_at_end_) +{ + TORCH_CHECK(input.dim() == 2, "minimax_allreduce_rms: input must be 2D"); + TORCH_CHECK(norm_weight.dim() == 1, "minimax_allreduce_rms: norm_weight must be 1D"); + TORCH_CHECK( + input.size(-1) == norm_weight.size(0), "minimax_allreduce_rms: input hidden dim must match norm_weight"); + TORCH_CHECK(input.is_contiguous(), "minimax_allreduce_rms: input must be contiguous"); + TORCH_CHECK(norm_weight.is_contiguous(), "minimax_allreduce_rms: norm_weight must be contiguous"); + TORCH_CHECK(norm_weight.scalar_type() == torch::kBFloat16, "minimax_allreduce_rms: norm_weight must be bfloat16"); + + auto allreduce_params = tensorrt_llm::kernels::minimax_ar::MiniMaxReduceRMSParams(); + + allreduce_params.nranks = static_cast(nranks); + allreduce_params.rank = static_cast(rank); + allreduce_params.dtype = tensorrt_llm::runtime::TorchUtils::dataType(input.scalar_type()); + allreduce_params.size_q = static_cast(input.numel()); + allreduce_params.hidden_dim = static_cast(input.size(-1)); + allreduce_params.workspace = reinterpret_cast(workspace.mutable_data_ptr()); + allreduce_params.allreduce_in = input.data_ptr(); + // allreduce_params.rms_norm_out = nullptr; + allreduce_params.rms_gamma = norm_weight.data_ptr(); + allreduce_params.rms_eps = static_cast(eps); + allreduce_params.stream = at::cuda::getCurrentCUDAStream(input.get_device()); + + torch::Tensor rms_norm_out = torch::empty_like(input); + allreduce_params.rms_norm_out = rms_norm_out.mutable_data_ptr(); + allreduce_params.trigger_completion_at_end = trigger_completion_at_end_; + + tensorrt_llm::kernels::minimax_ar::minimax_reduce_rms_op(allreduce_params); + + return rms_norm_out; +} + +std::vector minimax_allreduce_rms_qk(torch::Tensor const& q, torch::Tensor const& k, + torch::Tensor const& norm_weight_q, torch::Tensor const& norm_weight_k, torch::Tensor workspace, int64_t const rank, + int64_t const nranks, double const eps, bool const trigger_completion_at_end_) +{ + int64_t constexpr kSupportedGlobalHeadDimQ = 6144; + int64_t constexpr kSupportedGlobalHeadDimK = 1024; + + TORCH_CHECK(q.scalar_type() == k.scalar_type(), "minimax_allreduce_rms_qk: q and k must have same dtype"); + TORCH_CHECK(q.dim() == 2 && k.dim() == 2, "minimax_allreduce_rms_qk: q and k must be 2D"); + TORCH_CHECK(q.size(0) == k.size(0), "minimax_allreduce_rms_qk: q and k must have same num_token"); + TORCH_CHECK(q.is_contiguous(), "minimax_allreduce_rms_qk: q must be contiguous"); + TORCH_CHECK(k.is_contiguous(), "minimax_allreduce_rms_qk: k must be contiguous"); + TORCH_CHECK(norm_weight_q.dim() == 1, "minimax_allreduce_rms_qk: norm_weight_q must be 1D"); + TORCH_CHECK(norm_weight_k.dim() == 1, "minimax_allreduce_rms_qk: norm_weight_k must be 1D"); + TORCH_CHECK(norm_weight_q.is_contiguous(), "minimax_allreduce_rms_qk: norm_weight_q must be contiguous"); + TORCH_CHECK(norm_weight_k.is_contiguous(), "minimax_allreduce_rms_qk: norm_weight_k must be contiguous"); + TORCH_CHECK( + norm_weight_q.scalar_type() == torch::kBFloat16, "minimax_allreduce_rms_qk: norm_weight_q must be bfloat16"); + TORCH_CHECK( + norm_weight_k.scalar_type() == torch::kBFloat16, "minimax_allreduce_rms_qk: norm_weight_k must be bfloat16"); + int64_t head_dim_q = q.size(-1); + int64_t head_dim_k = k.size(-1); + TORCH_CHECK(head_dim_q >= head_dim_k, "minimax_allreduce_rms_qk: head_dim_q must be >= head_dim_k"); + TORCH_CHECK(head_dim_q == norm_weight_q.size(0), "minimax_allreduce_rms_qk: q hidden dim must match norm_weight_q"); + TORCH_CHECK(head_dim_k == norm_weight_k.size(0), "minimax_allreduce_rms_qk: k hidden dim must match norm_weight_k"); + TORCH_CHECK((head_dim_q * nranks) == kSupportedGlobalHeadDimQ && (head_dim_k * nranks) == kSupportedGlobalHeadDimK, + "minimax_allreduce_rms_qk: only global q/k dims 6144/1024 are currently supported"); + + auto params = tensorrt_llm::kernels::minimax_ar::MiniMaxReduceRMSParams(); + params.nranks = static_cast(nranks); + params.rank = static_cast(rank); + params.dtype = tensorrt_llm::runtime::TorchUtils::dataType(q.scalar_type()); + params.size_q = static_cast(q.numel()); + params.hidden_dim = static_cast(head_dim_q); + params.size_k = static_cast(k.numel()); + params.hidden_dim_k = static_cast(head_dim_k); + params.workspace = reinterpret_cast(workspace.mutable_data_ptr()); + params.allreduce_in = q.data_ptr(); + params.rms_gamma = norm_weight_q.data_ptr(); + params.allreduce_in_k = k.data_ptr(); + params.rms_gamma_k = norm_weight_k.data_ptr(); + params.rms_eps = static_cast(eps); + params.stream = at::cuda::getCurrentCUDAStream(q.get_device()); + params.trigger_completion_at_end = trigger_completion_at_end_; + + torch::Tensor rms_norm_out_q = torch::empty_like(q); + torch::Tensor rms_norm_out_k = torch::empty_like(k); + params.rms_norm_out = rms_norm_out_q.mutable_data_ptr(); + params.rms_norm_out_k = rms_norm_out_k.mutable_data_ptr(); + + tensorrt_llm::kernels::minimax_ar::minimax_reduce_rms_op(params); + + return {rms_norm_out_q, rms_norm_out_k}; +} + } // namespace torch_ext TRTLLM_NAMESPACE_END @@ -1886,6 +1977,26 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m) "int nranks," "float eps) -> Tensor[]"); m.def("preallocate_nccl_window_buffer(Tensor input, int[] group, int count) -> ()"); + m.def( + "minimax_allreduce_rms(" + "Tensor input," + "Tensor norm_weight," + "Tensor workspace," + "int rank," + "int nranks," + "float eps," + "bool trigger_completion_at_end) -> Tensor"); + m.def( + "minimax_allreduce_rms_qk(" + "Tensor q," + "Tensor k," + "Tensor norm_weight_q," + "Tensor norm_weight_k," + "Tensor workspace," + "int rank," + "int nranks," + "float eps," + "bool trigger_completion_at_end) -> Tensor[]"); } TORCH_LIBRARY_IMPL(trtllm, CUDA, m) @@ -1896,6 +2007,8 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m) m.impl("moe_allreduce", &tensorrt_llm::torch_ext::moe_allreduce); m.impl("moe_finalize_allreduce", &tensorrt_llm::torch_ext::moe_finalize_allreduce); m.impl("preallocate_nccl_window_buffer", &tensorrt_llm::torch_ext::preallocateNCCLWindowBuffer); + m.impl("minimax_allreduce_rms", &tensorrt_llm::torch_ext::minimax_allreduce_rms); + m.impl("minimax_allreduce_rms_qk", &tensorrt_llm::torch_ext::minimax_allreduce_rms_qk); } TORCH_LIBRARY_IMPL(trtllm, CPU, m) diff --git a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py index f75ee36cd088..9cf817c63201 100644 --- a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py @@ -97,6 +97,16 @@ def _(residual, norm_weight, device_num_experts, scale_input, residual_out = torch.empty_like(residual) return [norm_out, residual_out] + @torch.library.register_fake("trtllm::minimax_allreduce_rms") + def _(input, norm_weight, workspace, rank, nranks, eps, + trigger_completion_at_end): + return torch.empty_like(input) + + @torch.library.register_fake("trtllm::minimax_allreduce_rms_qk") + def _(q, k, norm_weight_q, norm_weight_k, workspace, rank, nranks, eps, + trigger_completion_at_end): + return [torch.empty_like(q), torch.empty_like(k)] + @torch.library.register_fake("trtllm::allgather") def allgather(input, sizes, group): if sizes is None: diff --git a/tensorrt_llm/_torch/distributed/__init__.py b/tensorrt_llm/_torch/distributed/__init__.py index 5e18d0d7b77a..29564f81ed50 100644 --- a/tensorrt_llm/_torch/distributed/__init__.py +++ b/tensorrt_llm/_torch/distributed/__init__.py @@ -3,9 +3,10 @@ from .communicator import Distributed, MPIDist, TorchDist from .moe_alltoall import MoeAlltoAll from .ops import (AllReduce, AllReduceParams, AllReduceStrategy, - HelixAllToAllNative, MoEAllReduce, MoEAllReduceParams, - all_to_all_4d, all_to_all_5d, allgather, alltoall_helix, - cp_allgather, reducescatter, userbuffers_allreduce_finalize) + HelixAllToAllNative, MiniMaxAllReduceRMS, MoEAllReduce, + MoEAllReduceParams, all_to_all_4d, all_to_all_5d, allgather, + alltoall_helix, cp_allgather, reducescatter, + userbuffers_allreduce_finalize) __all__ = [ "all_to_all_4d", @@ -22,6 +23,7 @@ "HelixAllToAllNative", "MoEAllReduce", "MoEAllReduceParams", + "MiniMaxAllReduceRMS", "MoeAlltoAll", "TorchDist", "MPIDist", diff --git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py index dd71bb00ec9e..892a916f7fb5 100644 --- a/tensorrt_llm/_torch/distributed/ops.py +++ b/tensorrt_llm/_torch/distributed/ops.py @@ -1178,3 +1178,36 @@ def all_to_all_5d( gathered_heads = heads * world_size return out.reshape(batch, sharded_seq, qkv_count, gathered_heads, head_dim) + + +class MiniMaxAllReduceRMS(nn.Module): + + def __init__(self, mapping: Mapping): + super().__init__() + self.mapping = mapping + self.workspace = get_allreduce_workspace(self.mapping) + + def forward(self, input: torch.Tensor, rms_weights: torch.Tensor, + eps: float): + return torch.ops.trtllm.minimax_allreduce_rms(input, rms_weights, + self.workspace, + self.mapping.tp_rank, + self.mapping.tp_size, eps, + True) + + def forward_qk(self, q: torch.Tensor, k: torch.Tensor, + rms_weights_q: torch.Tensor, rms_weights_k: torch.Tensor, + eps: float): + """Fused Q+K RMS norm with allreduce. Returns (q_out, k_out).""" + out_list = torch.ops.trtllm.minimax_allreduce_rms_qk( + q, + k, + rms_weights_q, + rms_weights_k, + self.workspace, + self.mapping.tp_rank, + self.mapping.tp_size, + eps, + True, + ) + return (out_list[0], out_list[1]) diff --git a/tensorrt_llm/_torch/models/modeling_minimaxm2.py b/tensorrt_llm/_torch/models/modeling_minimaxm2.py index 8e6b498e9d4f..944f20ec77f3 100644 --- a/tensorrt_llm/_torch/models/modeling_minimaxm2.py +++ b/tensorrt_llm/_torch/models/modeling_minimaxm2.py @@ -19,16 +19,18 @@ from torch import nn from transformers import PretrainedConfig -from tensorrt_llm.functional import PositionEmbeddingType +from tensorrt_llm.functional import AllReduceStrategy, PositionEmbeddingType +from tensorrt_llm.mapping import Mapping from ..attention_backend import AttentionMetadata from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams +from ..distributed import AllReduce, MiniMaxAllReduceRMS from ..models.modeling_utils import ModelConfig from ..modules.attention import Attention from ..modules.decoder_layer import DecoderLayer from ..modules.embedding import Embedding from ..modules.fused_moe import MiniMaxM2MoeRoutingMethod, create_moe -from ..modules.linear import Linear +from ..modules.linear import Linear, TensorParallelMode, copy_weight, load_weight_shard from ..modules.rms_norm import RMSNorm from ..utils import AuxStreamType from .modeling_utils import DecoderModel, DecoderModelForCausalLM, register_auto_model @@ -114,6 +116,38 @@ def forward( return final_hidden_states +# We use all_reduce across all tp gpus to get the rms norm variance sum +class MiniMaxRMSNorm(nn.Module): + def __init__( + self, *, hidden_size: int, eps: float, mapping: Mapping, dtype: torch.dtype = torch.bfloat16 + ): + super().__init__() + self.mapping = mapping + # for attention input, tp_size * hidden_size = head_num * head_size + self.weight = nn.Parameter(torch.empty(hidden_size, dtype=dtype), requires_grad=False) + self.hidden_size = hidden_size + self.eps = eps + self.dtype = dtype + self.all_reduce = AllReduce(mapping=self.mapping, strategy=AllReduceStrategy.NCCL) + + self.minimax_all_reduce_rms = MiniMaxAllReduceRMS(mapping=self.mapping) + + def load_weights(self, weights: List[Dict]): + assert len(weights) == 1 + weight = load_weight_shard( + weights[0]["weight"], + tensor_parallel_size=self.mapping.tp_size, + tensor_parallel_rank=self.mapping.tp_rank, + tensor_parallel_mode=TensorParallelMode.COLUMN, + ) + copy_weight(self.weight, weight) + + def forward(self, hidden_states: torch.Tensor): + hidden_states = hidden_states.contiguous() + rms_norm_out = self.minimax_all_reduce_rms(hidden_states, self.weight, self.eps) + return rms_norm_out + + # It's a little bit tricky to implement special qk norm # because rms dim is hidden_size * num_heads, not hidden_size, after qkv linear, # the result size is hidden_size * num_heads / tp_size. @@ -148,37 +182,41 @@ def __init__( dtype=config.torch_dtype, config=model_config, ) - - self.q_norm = RMSNorm( - hidden_size=self.q_size * self.tp_size, - eps=config.rms_norm_eps, - dtype=config.torch_dtype, - ) - self.k_norm = RMSNorm( - hidden_size=self.kv_size * self.tp_size, - eps=config.rms_norm_eps, - dtype=config.torch_dtype, - ) + if self.qkv_proj.mapping.tp_size > 1: + self.q_norm = MiniMaxRMSNorm( + hidden_size=self.q_size, + eps=config.rms_norm_eps, + mapping=self.qkv_proj.mapping, + dtype=config.torch_dtype, + ) + self.k_norm = MiniMaxRMSNorm( + hidden_size=self.kv_size, + eps=config.rms_norm_eps, + mapping=self.qkv_proj.mapping, + dtype=config.torch_dtype, + ) + else: + self.q_norm = RMSNorm( + hidden_size=self.q_size * self.tp_size, + eps=config.rms_norm_eps, + dtype=config.torch_dtype, + ) + self.k_norm = RMSNorm( + hidden_size=self.kv_size * self.tp_size, + eps=config.rms_norm_eps, + dtype=config.torch_dtype, + ) def apply_qk_norm(self, q, k): if self.qkv_proj.mapping.tp_size > 1: - # collect q and k from all gpus - from ..distributed import allgather - - temp_q = allgather(q, self.qkv_proj.mapping) - temp_k = allgather(k, self.qkv_proj.mapping) - temp_q = self.q_norm(temp_q) - temp_k = self.k_norm(temp_k) - q = temp_q.reshape(-1, self.tp_size, self.q_size)[:, self.tp_rank, :].reshape( - -1, self.q_size - ) - k = temp_k.reshape(-1, self.tp_size, self.kv_size)[:, self.tp_rank, :].reshape( - -1, self.kv_size + q = q.contiguous() + k = k.contiguous() + q, k = self.q_norm.minimax_all_reduce_rms.forward_qk( + q, k, self.q_norm.weight, self.k_norm.weight, self.q_norm.eps ) else: q = self.q_norm(q) k = self.k_norm(k) - return q, k def apply_rope( diff --git a/tests/microbenchmarks/minimax_all_reduce.py b/tests/microbenchmarks/minimax_all_reduce.py new file mode 100644 index 000000000000..21c4c8739a03 --- /dev/null +++ b/tests/microbenchmarks/minimax_all_reduce.py @@ -0,0 +1,297 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from argparse import ArgumentParser +from itertools import product + +# isort: off +import torch +import pandas as pd + +# isort: on +try: + from cuda.bindings import runtime as cudart +except ImportError: + from cuda import cudart + +import tensorrt_llm as tllm +from tensorrt_llm import Mapping +from tensorrt_llm._torch.distributed import MiniMaxAllReduceRMS +from tensorrt_llm._utils import local_mpi_rank, local_mpi_size, mpi_barrier +from tensorrt_llm.logger import logger +from tensorrt_llm.plugin.plugin import CustomAllReduceHelper + +# MiniMax all-reduce only uses D (hidden_size) 128 and 1536 in practice. +ALLOWED_HIDDEN_SIZES = (256, 1536) + +# Q+K fused API benchmark dimensions +QK_Q_DIM = 1536 +QK_K_DIM = 256 + + +def profile_minimax_allreduce_rms( + mapping: Mapping, + op: MiniMaxAllReduceRMS, + warmup: int = 10, + iters: int = 100, + inner_loop: int = 8, + input_tensor=None, + norm_weight=None, + eps: float = 1e-5, +): + def func(): + for _ in range(inner_loop): + op(input_tensor, norm_weight, eps) + + for _ in range(warmup): + for i in range(inner_loop): + op(input_tensor, norm_weight, eps) + torch.cuda.synchronize() + + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + func() + + graph.replay() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + graph.replay() + start.record() + for _ in range(iters): + graph.replay() + end.record() + torch.cuda.synchronize() + return start.elapsed_time(end) * 1000.0 / (iters * inner_loop) + + +def profile_minimax_allreduce_rms_qk( + mapping: Mapping, + op: MiniMaxAllReduceRMS, + warmup: int = 10, + iters: int = 100, + inner_loop: int = 8, + q_tensor=None, + k_tensor=None, + norm_weight_q=None, + norm_weight_k=None, + eps: float = 1e-5, +): + """Profile the fused Q+K minimax allreduce RMS API (forward_qk).""" + + def func(): + for _ in range(inner_loop): + op.forward_qk(q_tensor, k_tensor, norm_weight_q, norm_weight_k, eps) + + for _ in range(warmup): + func() + torch.cuda.synchronize() + + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + func() + + graph.replay() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + graph.replay() + start.record() + for _ in range(iters): + graph.replay() + end.record() + torch.cuda.synchronize() + return start.elapsed_time(end) * 1000.0 / (iters * inner_loop) + + +def minimax_allreduce_benchmark( + dtype: str = "bfloat16", + test_range: str = "256,256000000,10", + explore_2d: bool = False, + save_csv: str = None, + warmup: int = 10, + iters: int = 100, +): + world_size = tllm.mpi_world_size() + rank = tllm.mpi_rank() + local_rank = local_mpi_rank() + gpus_per_node = local_mpi_size() + + torch.cuda.set_device(local_rank) + cudart.cudaSetDevice(local_rank) + + mapping = Mapping(world_size, rank, gpus_per_node, tp_size=world_size) + logger.set_rank(mapping.rank) + + if world_size == 1: + raise RuntimeError("Benchmark must run with mpi_world_size > 1") + + torch_dtype = tllm._utils.str_dtype_to_torch(dtype) + + inner_loop = 8 + eps = 1e-5 + + shape_list = [] + if explore_2d: + num_tokens_list = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384] + hidden_size_list = list(ALLOWED_HIDDEN_SIZES) + for num_tokens, hidden_size in product(num_tokens_list, hidden_size_list): + shape_list.append((num_tokens, hidden_size)) + else: + min_size, max_size, ratio = [int(i) for i in test_range.split(",")] + for hidden_size in ALLOWED_HIDDEN_SIZES: + n_min = max(1, (min_size + hidden_size - 1) // hidden_size) + n_max = max_size // hidden_size + num_tokens = n_min + while num_tokens <= n_max: + shape_list.append((num_tokens, hidden_size)) + num_tokens *= ratio + num_tokens = max(num_tokens, 1) + # Only test D (hidden_size) = 128 and 1536 (no-op when explore_2d already uses them) + shape_list = [(n, d) for n, d in shape_list if d in ALLOWED_HIDDEN_SIZES] + + op = MiniMaxAllReduceRMS(mapping=mapping) + max_workspace = CustomAllReduceHelper.max_workspace_size_auto( + mapping.tp_size, support_deterministic=False + ) + + df = pd.DataFrame() + for num_tokens, hidden_size in shape_list: + message_size_bytes = num_tokens * hidden_size * torch.finfo(torch_dtype).bits // 8 + if message_size_bytes > max_workspace: + continue + + input_tensor = torch.ones((num_tokens, hidden_size), dtype=torch_dtype, device="cuda") + norm_weight = torch.randn((hidden_size,), dtype=torch_dtype, device="cuda") + + mpi_barrier() + median_us = profile_minimax_allreduce_rms( + mapping=mapping, + op=op, + warmup=warmup, + iters=iters, + inner_loop=inner_loop, + input_tensor=input_tensor, + norm_weight=norm_weight, + eps=eps, + ) + + if mapping.rank == 0: + df = pd.concat( + [ + df, + pd.DataFrame( + { + "world_size": [mapping.world_size], + "dtype": [dtype], + "api": ["single"], + "message_size_bytes": [message_size_bytes], + "num_tokens": [num_tokens], + "hidden_size": [hidden_size], + "q_dim": [pd.NA], + "k_dim": [pd.NA], + "time (us)": [median_us], + } + ), + ] + ) + print(f"num_tokens: {num_tokens}, hidden_size: {hidden_size}, time (us): {median_us}") + + # Q+K fused API benchmark: q_dim=1536, k_dim=128 + num_tokens_qk = sorted({n for n, _ in shape_list}) + for num_tokens in num_tokens_qk: + q_tensor = torch.ones((num_tokens, QK_Q_DIM), dtype=torch_dtype, device="cuda") + k_tensor = torch.ones((num_tokens, QK_K_DIM), dtype=torch_dtype, device="cuda") + norm_weight_q = torch.randn((QK_Q_DIM,), dtype=torch_dtype, device="cuda") + norm_weight_k = torch.randn((QK_K_DIM,), dtype=torch_dtype, device="cuda") + message_size_bytes_qk = ( + num_tokens * (QK_Q_DIM + QK_K_DIM) * torch.finfo(torch_dtype).bits // 8 + ) + if message_size_bytes_qk > max_workspace: + continue + + mpi_barrier() + median_us_qk = profile_minimax_allreduce_rms_qk( + mapping=mapping, + op=op, + warmup=warmup, + iters=iters, + inner_loop=inner_loop, + q_tensor=q_tensor, + k_tensor=k_tensor, + norm_weight_q=norm_weight_q, + norm_weight_k=norm_weight_k, + eps=eps, + ) + + if mapping.rank == 0: + df = pd.concat( + [ + df, + pd.DataFrame( + { + "world_size": [mapping.world_size], + "dtype": [dtype], + "api": ["qk"], + "message_size_bytes": [message_size_bytes_qk], + "num_tokens": [num_tokens], + "hidden_size": [pd.NA], + "q_dim": [QK_Q_DIM], + "k_dim": [QK_K_DIM], + "time (us)": [median_us_qk], + } + ), + ] + ) + print( + f"qk: num_tokens: {num_tokens}, q_dim: {QK_Q_DIM}, k_dim: {QK_K_DIM}, " + f"time (us): {median_us_qk}" + ) + + if mapping.rank == 0: + pd.set_option("display.max_rows", None) + pd.set_option("display.max_columns", None) + pd.set_option("display.width", None) + pd.set_option("display.max_colwidth", None) + print(df) + + if mapping.rank == 0 and save_csv is not None: + df.to_csv(save_csv, index=False) + + return df + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--dtype", "-t", default="bfloat16") + parser.add_argument( + "--range", + "-r", + default="256,256000000,10", + help="min_size,max_size,multiplicative_ratio", + ) + parser.add_argument("--explore_2d", action="store_true", default=False) + parser.add_argument("--save_csv", type=str, default=None) + parser.add_argument("--warmup", type=int, default=10) + parser.add_argument("--iters", type=int, default=100) + + args = parser.parse_args() + + minimax_allreduce_benchmark( + args.dtype, + args.range, + args.explore_2d, + args.save_csv, + args.warmup, + args.iters, + ) diff --git a/tests/unittest/_torch/multi_gpu/test_allreduce.py b/tests/unittest/_torch/multi_gpu/test_allreduce.py index a3ba58faaed4..4893c15f76a0 100644 --- a/tests/unittest/_torch/multi_gpu/test_allreduce.py +++ b/tests/unittest/_torch/multi_gpu/test_allreduce.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -27,7 +27,8 @@ from tensorrt_llm._torch.autotuner import AutoTuner, autotune from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp, AllReduceParams, AllReduceStrategy, - MoEAllReduce, MoEAllReduceParams) + MiniMaxAllReduceRMS, MoEAllReduce, + MoEAllReduceParams) from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode from tensorrt_llm._torch.modules.rms_norm import RMSNorm from tensorrt_llm.mapping import Mapping @@ -697,3 +698,230 @@ def test_moe_finalize_allreduce_no_residual(mpi_pool_executor): ) for r in results: assert r is True + + +@torch.inference_mode() +def run_minimax_allreduce_rms_op(input: torch.Tensor, tensor_parallel_size: int, + tensor_parallel_rank: int, + rms_weights: torch.Tensor, eps: float): + torch.manual_seed(42) + + total_tokens = input.shape[0] + origin_dtype = input.dtype + + input = input.cuda() + rms_weights = rms_weights.cuda() + + rank_input = input.reshape(total_tokens, tensor_parallel_size, + -1)[:, tensor_parallel_rank, :].contiguous() + rank_rms_weights = rms_weights.reshape( + tensor_parallel_size, -1)[tensor_parallel_rank, :].contiguous() + + # firstly, calculate the reference output for each rank + ref_output = rms_norm(input.to(torch.float32), + rms_weights.to(torch.float32), eps) + ref_output = ref_output.reshape(total_tokens, tensor_parallel_size, + -1).to(origin_dtype) + ref_output = ref_output[:, tensor_parallel_rank, :] + + # then, calculate the minimax allreduce output + minimax_allreduce_rms = MiniMaxAllReduceRMS(mapping=Mapping( + world_size=tensor_parallel_size, + tp_size=tensor_parallel_size, + rank=tensor_parallel_rank, + )) + minimax_output = minimax_allreduce_rms( + input=rank_input.to(torch.bfloat16), + rms_weights=rank_rms_weights.to(torch.bfloat16), + eps=eps, + ) + # finally, verify the results + torch.testing.assert_close(minimax_output, ref_output, rtol=0.2, atol=0.2) + + return rank_input + + +@torch.inference_mode() +def run_minimax_allreduce_rms_single_rank(tensor_parallel_size, + single_rank_forward_func, input, + rms_weights, eps): + rank = tensorrt_llm.mpi_rank() + torch.cuda.set_device(rank) + try: + single_rank_forward_func(input, tensor_parallel_size, rank, rms_weights, + eps) + except Exception: + traceback.print_exc() + raise + return True + + +@torch.inference_mode() +@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True) +def test_minimax_allreduce_rms(mpi_pool_executor): + torch.manual_seed(42) + + seq_len = 16 + hidden_size = 6144 + dtype = torch.bfloat16 + tensor_parallel_size = mpi_pool_executor.num_workers + + token_input = torch.randn((seq_len, hidden_size), dtype=dtype) + rms_weights = torch.randn((hidden_size, ), dtype=dtype, device="cuda") + eps = 1e-5 + + results = mpi_pool_executor.map( + run_minimax_allreduce_rms_single_rank, + *zip(*[(tensor_parallel_size, run_minimax_allreduce_rms_op, token_input, + rms_weights, eps)] * tensor_parallel_size), + ) + for r in results: + assert r is True + + +@torch.inference_mode() +def run_minimax_allreduce_rms_qk_op( + q_input: torch.Tensor, k_input: torch.Tensor, tensor_parallel_size: int, + tensor_parallel_rank: int, rms_weights_q: torch.Tensor, + rms_weights_k: torch.Tensor, eps: float, non_contiguous_input: bool): + torch.manual_seed(42) + + num_tokens = q_input.shape[0] + origin_dtype = q_input.dtype + + q_input = q_input.cuda() + k_input = k_input.cuda() + rms_weights_q = rms_weights_q.cuda() + rms_weights_k = rms_weights_k.cuda() + + # firstly, calculate the reference output for each rank + # Reference: each rank computes RMS norm on its local Q/K independently, + # then all-reduce is applied to the variance before normalization + # The all-reduce sum happens across TP ranks, so we need to simulate that + q_input_fp32 = q_input.to(torch.float32) + k_input_fp32 = k_input.to(torch.float32) + + # Compute reference: RMS norm with all-reduced variance + q_variance_sum = q_input_fp32.pow(2).mean(-1, keepdim=True) + k_variance_sum = k_input_fp32.pow(2).mean(-1, keepdim=True) + + ref_q_output = q_input_fp32 * torch.rsqrt(q_variance_sum + eps) + ref_k_output = k_input_fp32 * torch.rsqrt(k_variance_sum + eps) + + # Apply weights + ref_q_output = ref_q_output * rms_weights_q.to(torch.float32) + ref_k_output = ref_k_output * rms_weights_k.to(torch.float32) + + ref_q_output = ref_q_output.to(origin_dtype) + ref_k_output = ref_k_output.to(origin_dtype) + + # we only need to compare the reference output of the current rank + ref_q_output = ref_q_output.reshape(num_tokens, tensor_parallel_size, -1) + ref_k_output = ref_k_output.reshape(num_tokens, tensor_parallel_size, -1) + ref_q_output = ref_q_output[:, tensor_parallel_rank, :].contiguous() + ref_k_output = ref_k_output[:, tensor_parallel_rank, :].contiguous() + + # minimax input should be sliced by rank + q_input = q_input.reshape(num_tokens, tensor_parallel_size, -1) + k_input = k_input.reshape(num_tokens, tensor_parallel_size, -1) + rank_q_input = q_input[:, tensor_parallel_rank, :].contiguous() + rank_k_input = k_input[:, tensor_parallel_rank, :].contiguous() + if non_contiguous_input: + # Mimic the integration path where q and k are views split from + # the local qkv shard. + rank_v_input = torch.zeros_like(rank_k_input) + rank_qkv_input = torch.cat([rank_q_input, rank_k_input, rank_v_input], + dim=-1) + rank_q_input, rank_k_input, _ = rank_qkv_input.split( + [ + rank_q_input.shape[-1], + rank_k_input.shape[-1], + rank_v_input.shape[-1], + ], + dim=-1, + ) + assert not rank_q_input.is_contiguous() + assert not rank_k_input.is_contiguous() + # MiniMaxM2Attention materializes the split views before calling the custom op. + rank_q_input = rank_q_input.contiguous() + rank_k_input = rank_k_input.contiguous() + + # rms weights should be sliced by rank + rms_weights_q = rms_weights_q.reshape(tensor_parallel_size, -1) + rms_weights_k = rms_weights_k.reshape(tensor_parallel_size, -1) + rank_rms_weights_q = rms_weights_q[tensor_parallel_rank, :].contiguous() + rank_rms_weights_k = rms_weights_k[tensor_parallel_rank, :].contiguous() + + # then, calculate the minimax allreduce output + minimax_allreduce_rms = MiniMaxAllReduceRMS(mapping=Mapping( + world_size=tensor_parallel_size, + tp_size=tensor_parallel_size, + rank=tensor_parallel_rank, + )) + minimax_q_output, minimax_k_output = minimax_allreduce_rms.forward_qk( + q=rank_q_input, + k=rank_k_input, + rms_weights_q=rank_rms_weights_q, + rms_weights_k=rank_rms_weights_k, + eps=eps, + ) + + # finally, verify the results + torch.testing.assert_close(minimax_q_output, + ref_q_output, + rtol=0.2, + atol=0.2) + torch.testing.assert_close(minimax_k_output, + ref_k_output, + rtol=0.2, + atol=0.2) + + return q_input + + +@torch.inference_mode() +def run_minimax_allreduce_rms_qk_single_rank(tensor_parallel_size, + single_rank_forward_func, q_input, + k_input, rms_weights_q, + rms_weights_k, eps, + non_contiguous_input): + rank = tensorrt_llm.mpi_rank() + torch.cuda.set_device(rank) + try: + single_rank_forward_func(q_input, k_input, tensor_parallel_size, rank, + rms_weights_q, rms_weights_k, eps, + non_contiguous_input) + except Exception: + traceback.print_exc() + raise + return True + + +@pytest.mark.skipif(torch.cuda.device_count() != 4, + reason="Requires exactly 4 GPUs for this test") +@pytest.mark.parametrize("non_contiguous_input", [False, True], + ids=["contiguous", "split_qkv_view"]) +@pytest.mark.parametrize("mpi_pool_executor", [4], indirect=True) +def test_minimax_allreduce_rms_qk(mpi_pool_executor, non_contiguous_input): + torch.manual_seed(42) + + seq_len = 1024 + q_size = 6144 + k_size = 1024 + dtype = torch.bfloat16 + tensor_parallel_size = mpi_pool_executor.num_workers + + q_input = torch.randn((seq_len, q_size), dtype=dtype) + k_input = torch.randn((seq_len, k_size), dtype=dtype) + rms_weights_q = torch.randn((q_size, ), dtype=dtype, device="cuda") + rms_weights_k = torch.randn((k_size, ), dtype=dtype, device="cuda") + eps = 1e-5 + + results = mpi_pool_executor.map( + run_minimax_allreduce_rms_qk_single_rank, + *zip(*[(tensor_parallel_size, run_minimax_allreduce_rms_qk_op, q_input, + k_input, rms_weights_q, rms_weights_k, eps, + non_contiguous_input)] * tensor_parallel_size), + ) + for r in results: + assert r is True