From c43326ba42a00d148ace1a62e39a893e1af6ab73 Mon Sep 17 00:00:00 2001 From: Oseltamivir Date: Sun, 3 May 2026 10:25:38 -0700 Subject: [PATCH 1/3] fix: support 7168 fused MHC hidden size Signed-off-by: Oseltamivir --- .../kernels/mhcKernels/mhcFusedHcKernel.cu | 215 +++++++++++++++--- .../kernels/mhcKernels/mhcKernels.h | 10 +- tensorrt_llm/_torch/modules/mhc/mhc_cuda.py | 62 +++-- 3 files changed, 234 insertions(+), 53 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/mhcKernels/mhcFusedHcKernel.cu b/cpp/tensorrt_llm/kernels/mhcKernels/mhcFusedHcKernel.cu index 334961bdfced..23c15c72a9bc 100644 --- a/cpp/tensorrt_llm/kernels/mhcKernels/mhcFusedHcKernel.cu +++ b/cpp/tensorrt_llm/kernels/mhcKernels/mhcFusedHcKernel.cu @@ -91,9 +91,11 @@ inline void fhcZeroWorkspaces(float* y_acc, uint32_t y_elems, float* r_acc, uint } // namespace // ---- mHC fused kernel shape constants (mirrors the Python module) ---- -static constexpr uint32_t FHC_SHAPE_N = 24; // HC_MULT * (2 + HC_MULT) = 4 * 6 = 24 -static constexpr uint32_t FHC_HIDDEN = 4096; // only this hidden size is currently wired up +// HC_MULT * (2 + HC_MULT) = 4 * 6 = 24. +static constexpr uint32_t FHC_SHAPE_N = 24; static constexpr uint32_t FHC_HC_MULT = 4; +static constexpr uint32_t FHC_HIDDEN_FLASH = 4096; +static constexpr uint32_t FHC_HIDDEN_PRO = 7168; static constexpr uint32_t FHC_BLOCK_M = 64; static constexpr uint32_t FHC_BLOCK_N = 32; static constexpr uint32_t FHC_BLOCK_K = 64; @@ -103,6 +105,38 @@ static constexpr uint32_t FHC_N_INPUT_STG = 2; static constexpr uint32_t FHC_NUM_MMA_TH = 128; static constexpr uint32_t FHC_NUM_PMAP_TH = 128; +template