diff --git a/cpp/tensorrt_llm/kernels/mhcKernels/mhcFusedHcKernel.cu b/cpp/tensorrt_llm/kernels/mhcKernels/mhcFusedHcKernel.cu index 334961bdfced..7d5d2b6daf45 100644 --- a/cpp/tensorrt_llm/kernels/mhcKernels/mhcFusedHcKernel.cu +++ b/cpp/tensorrt_llm/kernels/mhcKernels/mhcFusedHcKernel.cu @@ -91,9 +91,11 @@ inline void fhcZeroWorkspaces(float* y_acc, uint32_t y_elems, float* r_acc, uint } // namespace // ---- mHC fused kernel shape constants (mirrors the Python module) ---- -static constexpr uint32_t FHC_SHAPE_N = 24; // HC_MULT * (2 + HC_MULT) = 4 * 6 = 24 -static constexpr uint32_t FHC_HIDDEN = 4096; // only this hidden size is currently wired up +// HC_MULT * (2 + HC_MULT) = 4 * 6 = 24. +static constexpr uint32_t FHC_SHAPE_N = 24; static constexpr uint32_t FHC_HC_MULT = 4; +static constexpr uint32_t FHC_HIDDEN_FLASH = 4096; +static constexpr uint32_t FHC_HIDDEN_PRO = 7168; static constexpr uint32_t FHC_BLOCK_M = 64; static constexpr uint32_t FHC_BLOCK_N = 32; static constexpr uint32_t FHC_BLOCK_K = 64; @@ -103,6 +105,38 @@ static constexpr uint32_t FHC_N_INPUT_STG = 2; static constexpr uint32_t FHC_NUM_MMA_TH = 128; static constexpr uint32_t FHC_NUM_PMAP_TH = 128; +template