From c5bc8e182f52384b1f21993beeb96f25a6bde5dc Mon Sep 17 00:00:00 2001 From: Jiang Shao <91270701+StudyingShao@users.noreply.github.com> Date: Mon, 9 Feb 2026 08:42:19 +0000 Subject: [PATCH 1/6] Add a copy of test_fused_moe.py Signed-off-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com> --- .../_torch/modules/test_fused_moe_jiangs.py | 781 ++++++++++++++++++ 1 file changed, 781 insertions(+) create mode 100644 tests/unittest/_torch/modules/test_fused_moe_jiangs.py diff --git a/tests/unittest/_torch/modules/test_fused_moe_jiangs.py b/tests/unittest/_torch/modules/test_fused_moe_jiangs.py new file mode 100644 index 000000000000..2db56644956d --- /dev/null +++ b/tests/unittest/_torch/modules/test_fused_moe_jiangs.py @@ -0,0 +1,781 @@ +import os +import pickle +import sys +from contextlib import contextmanager +from itertools import product +from typing import Dict, List, Optional +from unittest import mock + +import _torch.helpers +import cloudpickle +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F +from _torch.helpers import (calc_woq_tolerence, per_block_cast_to_fp8, + per_block_cast_to_fp8_e8m0, + per_token_cast_to_fp8_e8m0) +from mpi4py import MPI +from mpi4py.futures import MPIPoolExecutor +from transformers.configuration_utils import PretrainedConfig +from utils.util import (check_accuracy, skip_blackwell, skip_blackwell_geforce, + skip_neither_ada_nor_hopper_unittest, skip_no_hopper, + skip_pre_blackwell, skip_pre_hopper) + +from tensorrt_llm._torch.autotuner import AutoTuner, autotune +from tensorrt_llm._torch.model_config import ModelConfig +from tensorrt_llm._torch.modules.fused_moe.fused_moe_cute_dsl import \ + CuteDslFusedMoE +from tensorrt_llm._torch.modules.fused_moe.fused_moe_deepgemm import \ + DeepGemmFusedMoE +from tensorrt_llm._torch.modules.fused_moe.interface import ( + AlltoallMethodType, MoEWeightLoadingMode) + +# isort and yapf will fight against each other here, so we disable isort +# isort: off +from tensorrt_llm._torch.modules.fused_moe import ( + BaseMoeRoutingMethod, CutlassFusedMoE, TRTLLMGenFusedMoE, + DefaultMoeRoutingMethod, RenormalizeMoeRoutingMethod, TritonFusedMoE, + create_moe, WideEPMoE) +from tensorrt_llm._torch.modules.fused_moe.quantization import \ + NVFP4CutlassFusedMoEMethod +# isort: on +from tensorrt_llm._torch.modules.gated_mlp import GatedMLP +from tensorrt_llm._utils import get_sm_version, mpi_rank +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig + +cloudpickle.register_pickle_by_value(sys.modules[__name__]) +cloudpickle.register_pickle_by_value(_torch.helpers) +MPI.pickle.__init__( + cloudpickle.dumps, + cloudpickle.loads, + pickle.HIGHEST_PROTOCOL, +) + + +@skip_neither_ada_nor_hopper_unittest +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize( + "weight_loading_mode", + [MoEWeightLoadingMode.VANILLA, MoEWeightLoadingMode.W4A8_CUSTOM]) +def test_fused_moe_w4afp8(dtype, weight_loading_mode): + mapping = Mapping() + mapping.rank = mpi_rank() + + with torch.device(f'cuda:{mapping.rank}'): + SEQ_LEN = 4 + HIDDEN_SIZE = 768 + INTERMEDIATE_SIZE = 640 + SCALING_GROUP_SIZE = 128 + NUM_EXPERTS = 3 + TOP_K = 2 + routing_method = RenormalizeMoeRoutingMethod(top_k=TOP_K) + torch.manual_seed(0) + torch.cuda.manual_seed(0) + x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype, device="cuda") + router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), + dtype=dtype, + device="cuda") + + affine_coeff = 0.005 + + lut = { + "weight": + "weight", + "weight_scale": + ("weight_scale_inv" if weight_loading_mode + == MoEWeightLoadingMode.W4A8_CUSTOM else "weight_scale"), + "weight_scale_2": + "weight_scale_2", + "pre_quant_scale": + "pre_quant_scale", + "input_scale": + "input_scale", + } + + weights = {} + for expert_id in range(NUM_EXPERTS): + # ModelOpt W4A8 packs pairs of 4b weights in the output dimension into one 8b element. + if weight_loading_mode == MoEWeightLoadingMode.VANILLA: + w1_shape = (INTERMEDIATE_SIZE // 2, HIDDEN_SIZE) + w2_shape = (HIDDEN_SIZE // 2, INTERMEDIATE_SIZE) + w3_shape = (INTERMEDIATE_SIZE // 2, HIDDEN_SIZE) + # The custom W4A8 quantization script examples/quantization/quantize_mixed_precision_moe.py + # packs pairs of 4b weight in the input dimension into one 8b element. + if weight_loading_mode == MoEWeightLoadingMode.W4A8_CUSTOM: + w1_shape = (INTERMEDIATE_SIZE, HIDDEN_SIZE // 2) + w2_shape = (HIDDEN_SIZE, INTERMEDIATE_SIZE // 2) + w3_shape = (INTERMEDIATE_SIZE, HIDDEN_SIZE // 2) + + # The weights in int4 precision. + w1_weight = torch.randint(-128, 127, w1_shape, + dtype=torch.int8).cuda() + w2_weight = torch.randint(-128, 127, w2_shape, + dtype=torch.int8).cuda() + w3_weight = torch.randint(-128, 127, w3_shape, + dtype=torch.int8).cuda() + + # The pre-quant scale to be multiplied with the input activation. + # Use random pre-quant scales [0.95, 1.05] instead of fixed 1.0 to ensure the kernel handles + # non-uniform pre-quant scaling factors correctly + w1_pre_quant_scale = torch.rand( + HIDDEN_SIZE, dtype=dtype, device="cuda") * 0.1 + 0.95 + w2_pre_quant_scale = torch.rand( + INTERMEDIATE_SIZE, dtype=dtype, device="cuda") * 0.1 + 0.95 + w3_pre_quant_scale = torch.rand( + HIDDEN_SIZE, dtype=dtype, device="cuda") * 0.1 + 0.95 + + # The weight scale to dequantize int4 weights (by multiplication). + w1_scale = torch.randn( + (INTERMEDIATE_SIZE, HIDDEN_SIZE // SCALING_GROUP_SIZE), + dtype=dtype, + device="cuda") * affine_coeff + w2_scale = torch.randn( + (HIDDEN_SIZE, INTERMEDIATE_SIZE // SCALING_GROUP_SIZE), + dtype=dtype, + device="cuda") * affine_coeff + w3_scale = torch.randn( + (INTERMEDIATE_SIZE, HIDDEN_SIZE // SCALING_GROUP_SIZE), + dtype=dtype, + device="cuda") * affine_coeff + + # The input scale to quantize the input activation (by division). + w1_input_scale = torch.randn(1, dtype=torch.float32, + device="cuda") * 0.2 + w2_input_scale = w1_input_scale + w3_input_scale = w1_input_scale + + # The weight scale 2 to quantize the dequantized weights (by division). + w1_weight_scale_2 = torch.ones([1], + dtype=torch.float32, + device="cuda") + w2_weight_scale_2 = w1_weight_scale_2 + w3_weight_scale_2 = w1_weight_scale_2 + + # Prepare weights. + weights[f"{expert_id}.w1.{lut['weight']}"] = w1_weight + weights[f"{expert_id}.w2.{lut['weight']}"] = w2_weight + weights[f"{expert_id}.w3.{lut['weight']}"] = w3_weight + weights[f"{expert_id}.w1.{lut['input_scale']}"] = w1_input_scale + weights[f"{expert_id}.w2.{lut['input_scale']}"] = w2_input_scale + weights[f"{expert_id}.w3.{lut['input_scale']}"] = w3_input_scale + weights[f"{expert_id}.w1.{lut['weight_scale']}"] = w1_scale + weights[f"{expert_id}.w2.{lut['weight_scale']}"] = w2_scale + weights[f"{expert_id}.w3.{lut['weight_scale']}"] = w3_scale + weights[ + f"{expert_id}.w1.{lut['pre_quant_scale']}"] = w1_pre_quant_scale + weights[ + f"{expert_id}.w2.{lut['pre_quant_scale']}"] = w2_pre_quant_scale + weights[ + f"{expert_id}.w3.{lut['pre_quant_scale']}"] = w3_pre_quant_scale + weights[ + f"{expert_id}.w1.{lut['weight_scale_2']}"] = w1_weight_scale_2 + weights[ + f"{expert_id}.w2.{lut['weight_scale_2']}"] = w2_weight_scale_2 + weights[ + f"{expert_id}.w3.{lut['weight_scale_2']}"] = w3_weight_scale_2 + + quant_config = QuantConfig(quant_algo=QuantAlgo.W4A8_AWQ) + fused_moe = CutlassFusedMoE( + num_experts=NUM_EXPERTS, + routing_method=routing_method, + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + dtype=dtype, + reduce_results=False, + model_config=ModelConfig(quant_config=quant_config), + weight_loading_mode=weight_loading_mode) + fused_moe.load_weights([weights]) + fused_moe.cuda() + + def ref(): + results = torch.zeros_like(x) + selected_experts, final_scales = routing_method.apply(router_logits) + for e_idx in range(NUM_EXPERTS): + mask = selected_experts == e_idx + activated_tokens = mask.sum(1).bool() + act = x[activated_tokens, :] + if act.shape[0] == 0: + continue + final_scale = (final_scales * + mask).sum(1)[activated_tokens].unsqueeze(1) + + # weights + def unpack_weights(weight: torch.Tensor) -> torch.Tensor: + unpacker = torch.ops.trtllm.unpack_int4_packed_tensor_to_int8 + if weight_loading_mode == MoEWeightLoadingMode.VANILLA: + return unpacker(weight.cpu().T.contiguous()).cuda() + else: + return unpacker(weight.cpu()).T.contiguous().cuda() + + w1 = unpack_weights(weights[f"{e_idx}.w1.{lut['weight']}"]) + w2 = unpack_weights(weights[f"{e_idx}.w2.{lut['weight']}"]) + w3 = unpack_weights(weights[f"{e_idx}.w3.{lut['weight']}"]) + w3_w1 = torch.cat([w3, w1], dim=-1) + + # weight_scale + s1 = weights[f"{e_idx}.w1.{lut['weight_scale']}"].T.contiguous( + ).cuda() + s2 = weights[f"{e_idx}.w2.{lut['weight_scale']}"].T.contiguous( + ).cuda() + s3 = weights[f"{e_idx}.w3.{lut['weight_scale']}"].T.contiguous( + ).cuda() + s3_s1 = torch.cat([s3, s1], dim=-1) + + # input_scale + p1 = weights[f"{e_idx}.w1.{lut['input_scale']}"].cuda() + p2 = weights[f"{e_idx}.w2.{lut['input_scale']}"].cuda() + p3 = weights[f"{e_idx}.w3.{lut['input_scale']}"].cuda() + p3_p1 = torch.max(p1, p3) + + # pre_quant_scale + a1 = a2 = a3 = a1_a3 = None + if weight_loading_mode == MoEWeightLoadingMode.VANILLA: + a1 = weights[ + f"{e_idx}.w1.{lut['pre_quant_scale']}"].T.contiguous( + ).cuda() + a2 = weights[ + f"{e_idx}.w2.{lut['pre_quant_scale']}"].T.contiguous( + ).cuda() + a3 = weights[ + f"{e_idx}.w3.{lut['pre_quant_scale']}"].T.contiguous( + ).cuda() + a1_a3 = torch.max(a1, a3) + + # weight_scale_2 + q1 = q2 = q3 = q3_q1 = None + if weight_loading_mode == MoEWeightLoadingMode.VANILLA: + q1 = weights[f"{e_idx}.w1.{lut['weight_scale_2']}"].cuda() + q2 = weights[f"{e_idx}.w3.{lut['weight_scale_2']}"].cuda() + q3 = weights[f"{e_idx}.w2.{lut['weight_scale_2']}"].cuda() + q3_q1 = torch.max(q3, q1) + + # forward pass + def process_layer( + act, + weight, + weight_scale, + input_scale, + pre_quant_scale=None, + weight_scale_2=None, + ): + if pre_quant_scale is not None: + act = act * pre_quant_scale + act = (torch.clamp((act / input_scale), -448.0, + 448.0).to(torch.float8_e4m3fn).to(dtype)) + weight = (weight.float() * weight_scale.repeat_interleave( + 128, dim=0).float()).to(dtype) + if weight_scale_2 is not None: + weight /= weight_scale_2 + output = torch.matmul(act, weight) * input_scale + if weight_scale_2 is not None: + output *= weight_scale_2 + return output + + # fc13 + fc1 = process_layer( + act, + w3_w1, + s3_s1, + p3_p1, + pre_quant_scale=a1_a3, + weight_scale_2=q3_q1, + ) + fc1, gate = fc1.chunk(2, dim=-1) + fc1 = fc1 * torch.nn.functional.silu(gate) + + # fc2 + fc2 = process_layer(fc1, + w2, + s2, + p2, + pre_quant_scale=a2, + weight_scale_2=q2) + + results[activated_tokens, :] += (fc2 * final_scale).to( + results.dtype) + return results + + AutoTuner.get().clear_cache() + with torch.inference_mode(): + ref_output = ref() + + with torch.inference_mode(), autotune(): + fused_moe.forward(x, router_logits) + + # Explicitly capture context for kernel testing + with AutoTuner.get().capture() as all_tactics, torch.inference_mode(): + output = fused_moe.forward(x, router_logits) + + # Test all kernel tactics + for tactic in all_tactics: + with AutoTuner.get().replay(tactic), torch.inference_mode(): + output = fused_moe.forward(x, router_logits) + # assert that result does not contain NaN or is all 0s + assert not torch.isnan(output).any(), "output contains NaN" + assert torch.nonzero(output).numel() > 0, "output is empty" + torch.testing.assert_close(output, + ref_output, + rtol=1e-2, + atol=0.1) + + torch.cuda.synchronize() + assert not torch.isnan(ref_output).any(), "ref_output contains NaN" + assert not torch.isnan(output).any(), "output contains NaN" + assert torch.nonzero(output).numel() > 0, "output is empty" + assert torch.nonzero(ref_output).numel() > 0, "ref_output is empty" + # Final comparison + torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1) + + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("hidden_size", [768, 2880]) +@pytest.mark.parametrize( + "moe_backend", + [ + # smVersion + pytest.param("TRTLLM", + marks=[skip_blackwell_geforce, skip_pre_blackwell]), + pytest.param( + "CUTLASS", + marks=[skip_pre_hopper, skip_blackwell, skip_blackwell_geforce]), + ], +) +def test_fused_moe_wfp4a16(dtype, hidden_size, moe_backend): + mapping = Mapping() + mapping.rank = mpi_rank() + + with torch.device(f'cuda:{mapping.rank}'): + SEQ_LEN = 4 + HIDDEN_SIZE = hidden_size + INTERMEDIATE_SIZE = 640 + SCALING_GROUP_SIZE = 32 + NUM_EXPERTS = 4 + TOP_K = 2 + routing_method = RenormalizeMoeRoutingMethod(top_k=TOP_K) + torch.manual_seed(0) + torch.cuda.manual_seed(0) + x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype).cuda() + router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), dtype=dtype).cuda() + + weights = {} + for expert_id in range(NUM_EXPERTS): + w1_weight = torch.randint(0, + 256, + (INTERMEDIATE_SIZE, HIDDEN_SIZE // 2), + dtype=torch.uint8, + device='cuda') + w2_weight = torch.randint(0, + 256, + (HIDDEN_SIZE, INTERMEDIATE_SIZE // 2), + dtype=torch.uint8, + device='cuda') + w3_weight = torch.randint(0, + 256, + (INTERMEDIATE_SIZE, HIDDEN_SIZE // 2), + dtype=torch.uint8, + device='cuda') + + w1_scale = torch.randint( + 118, + 123, (INTERMEDIATE_SIZE, HIDDEN_SIZE // SCALING_GROUP_SIZE), + dtype=torch.uint8, + device='cuda') + w2_scale = torch.randint( + 118, + 123, (HIDDEN_SIZE, INTERMEDIATE_SIZE // SCALING_GROUP_SIZE), + dtype=torch.uint8, + device='cuda') + w3_scale = torch.randint( + 118, + 123, (INTERMEDIATE_SIZE, HIDDEN_SIZE // SCALING_GROUP_SIZE), + dtype=torch.uint8, + device='cuda') + + weights[f"{expert_id}.w1.weight"] = w1_weight + weights[f"{expert_id}.w2.weight"] = w2_weight + weights[f"{expert_id}.w3.weight"] = w3_weight + # WFP4A16FusedMoEMethod + weights[f"{expert_id}.w1.weight_scale_inv"] = w1_scale + weights[f"{expert_id}.w2.weight_scale_inv"] = w2_scale + weights[f"{expert_id}.w3.weight_scale_inv"] = w3_scale + # MXFP4WeightFusedMoEMethod + weights[f"{expert_id}.w1.weight_scale"] = w1_scale + weights[f"{expert_id}.w2.weight_scale"] = w2_scale + weights[f"{expert_id}.w3.weight_scale"] = w3_scale + + quant_config = QuantConfig(quant_algo=QuantAlgo.W4A16_MXFP4) + + # Create pretrained_config with necessary parameters + pretrained_config = PretrainedConfig() + pretrained_config.num_experts = NUM_EXPERTS + pretrained_config.hidden_size = HIDDEN_SIZE + pretrained_config.intermediate_size = INTERMEDIATE_SIZE + pretrained_config.torch_dtype = dtype + + fused_moe = create_moe(routing_method=routing_method, + reduce_results=False, + model_config=ModelConfig( + pretrained_config=pretrained_config, + quant_config=quant_config, + moe_backend=moe_backend)) + fused_moe.load_weights([weights]) + fused_moe.cuda() + + def ref(): + results = torch.zeros_like(x) + selected_experts, final_scales = routing_method.apply(router_logits) + unpacker = torch.ops.trtllm.mxfp4_dequantize_unswizzled + for e_idx in range(NUM_EXPERTS): + mask = selected_experts == e_idx + activated_tokens = mask.sum(1).bool() + act = x[activated_tokens, :] + if act.shape[0] == 0: + continue + final_scale = (final_scales * + mask).sum(1)[activated_tokens].unsqueeze(1) + + # weights and scales + w1 = weights[f"{e_idx}.w1.weight"] + s1 = weights[f"{e_idx}.w1.weight_scale_inv"] + w2 = weights[f"{e_idx}.w2.weight"] + s2 = weights[f"{e_idx}.w2.weight_scale_inv"] + w3 = weights[f"{e_idx}.w3.weight"] + s3 = weights[f"{e_idx}.w3.weight_scale_inv"] + + # converted weights + w1 = unpacker(w1.cpu(), s1.cpu(), SCALING_GROUP_SIZE).to( + dtype=x.dtype, device=x.device).T.contiguous() + w2 = unpacker(w2.cpu(), s2.cpu(), SCALING_GROUP_SIZE).to( + dtype=x.dtype, device=x.device).T.contiguous() + w3 = unpacker(w3.cpu(), s3.cpu(), SCALING_GROUP_SIZE).to( + dtype=x.dtype, device=x.device).T.contiguous() + w3_w1 = torch.cat([w3, w1], dim=-1) + + fc1 = torch.matmul(act, w3_w1) + fc1, gate = fc1.chunk(2, dim=-1) + fc1 = fc1 * torch.nn.functional.silu(gate) + fc2 = torch.matmul(fc1, w2) + results[activated_tokens, :] += (fc2 * final_scale).to( + results.dtype) + return results + + AutoTuner.get().clear_cache() + with torch.inference_mode(): + ref_output = ref() + + with torch.inference_mode(), autotune(): + fused_moe.forward(x, router_logits) + + # Explicitly capture context for kernel testing + with AutoTuner.get().capture() as all_tactics, torch.inference_mode(): + output = fused_moe.forward(x, router_logits) + + # Test all kernel tactics + for tactic in all_tactics: + with AutoTuner.get().replay(tactic), torch.inference_mode(): + output = fused_moe.forward(x, router_logits) + check_accuracy(output, + ref_output, + rtol=1e-2, + atol=0.1, + percent=0.99) + + # compare + torch.cuda.synchronize() + check_accuracy(output, ref_output, rtol=1e-2, atol=0.1, percent=0.99) + + +@skip_no_hopper +@pytest.mark.parametrize("experts", [8, 128]) +@pytest.mark.parametrize( + "hidden_size, intermediate_size", + [ + (2880, 2880), + (2880, 1440), + (2880, 720), + (2880, 360), + ], +) +@pytest.mark.parametrize("fp8_activation", [True, False]) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("dynamic_quant", [True, False]) +def test_fused_moe_triton_mxfp4(experts, hidden_size, intermediate_size, + fp8_activation, bias, dynamic_quant): + if fp8_activation: + pytest.skip("Latest Triton requires BF16 activation on Hopper") + + mapping = Mapping() + mapping.rank = mpi_rank() + + with torch.device(f'cuda:{mapping.rank}'): + dtype = torch.bfloat16 + SEQ_LEN = 8 + HIDDEN_SIZE = hidden_size + INTERMEDIATE_SIZE = intermediate_size + NUM_EXPERTS = experts + TOP_K = 4 + routing_method = RenormalizeMoeRoutingMethod(top_k=TOP_K) + torch.manual_seed(0) + torch.cuda.manual_seed(0) + x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype).cuda() + router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), dtype=dtype).cuda() + + w1_weight = torch.randn((NUM_EXPERTS, INTERMEDIATE_SIZE, HIDDEN_SIZE), + dtype=dtype).cuda() + w2_weight = torch.randn((NUM_EXPERTS, HIDDEN_SIZE, INTERMEDIATE_SIZE), + dtype=dtype).cuda() + w3_weight = torch.randn((NUM_EXPERTS, INTERMEDIATE_SIZE, HIDDEN_SIZE), + dtype=dtype).cuda() + w1_bias = torch.randn((NUM_EXPERTS, INTERMEDIATE_SIZE), + dtype=dtype).cuda() + w2_bias = torch.randn((NUM_EXPERTS, HIDDEN_SIZE), dtype=dtype).cuda() + w3_bias = torch.randn((NUM_EXPERTS, INTERMEDIATE_SIZE), + dtype=dtype).cuda() + + from triton_kernels.numerics_details.mxfp import ( + downcast_to_mxfp_torch, upcast_from_mxfp_torch) + + def fp32_to_mxfp4(tensor): + tensor = tensor.transpose(1, 2).contiguous() + tensor_fp4, tensor_scales = downcast_to_mxfp_torch(tensor, + torch.uint8, + axis=1) + tensor_fp4 = tensor_fp4.transpose(1, 2).contiguous() + tensor_scales = tensor_scales.transpose(1, 2).contiguous() + return tensor_fp4, tensor_scales + + def mxfp4_to_fp32(tensor, scales): + tensor = tensor.transpose(1, 2).contiguous() + scales = scales.transpose(1, 2).contiguous() + tensor = upcast_from_mxfp_torch(tensor, + scales, + torch.float32, + axis=1) + return tensor.transpose(1, 2).contiguous() + + w1_weight_fp4, w1_weight_scale = fp32_to_mxfp4(w1_weight) + w2_weight_fp4, w2_weight_scale = fp32_to_mxfp4(w2_weight) + w3_weight_fp4, w3_weight_scale = fp32_to_mxfp4(w3_weight) + w1_weight_qdq = mxfp4_to_fp32(w1_weight_fp4, w1_weight_scale) + w2_weight_qdq = mxfp4_to_fp32(w2_weight_fp4, w2_weight_scale) + w3_weight_qdq = mxfp4_to_fp32(w3_weight_fp4, w3_weight_scale) + + # Since we don't have mxfp4 reference, we run the ref in bf16 after q-dq + weights = {} + for expert_id in range(NUM_EXPERTS): + weights[f"{expert_id}.w1.weight"] = w1_weight_qdq[expert_id] + weights[f"{expert_id}.w2.weight"] = w2_weight_qdq[expert_id] + weights[f"{expert_id}.w3.weight"] = w3_weight_qdq[expert_id] + if bias: + weights[f"{expert_id}.w1.bias"] = w1_bias[expert_id] + weights[f"{expert_id}.w2.bias"] = w2_bias[expert_id] + weights[f"{expert_id}.w3.bias"] = w3_bias[expert_id] + + ref_fused_moe = RefGatedMLPFusedMoE(num_experts=NUM_EXPERTS, + routing_method=routing_method, + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + dtype=dtype, + model_config=ModelConfig(), + bias=bias) + ref_fused_moe.load_weights([weights]) + ref_fused_moe.cuda() + + with torch.inference_mode(): + ref_output = ref_fused_moe.forward(x, router_logits) + torch.cuda.synchronize() + + # Now we run the TritonFusedMoE with MXFP4 weights + weights = {} + + for expert_id in range(NUM_EXPERTS): + if dynamic_quant: + weights[f"{expert_id}.w1.weight"] = w1_weight_qdq[expert_id] + weights[f"{expert_id}.w2.weight"] = w2_weight_qdq[expert_id] + weights[f"{expert_id}.w3.weight"] = w3_weight_qdq[expert_id] + else: + weights[f"{expert_id}.w1.weight"] = w1_weight_fp4[expert_id] + weights[f"{expert_id}.w2.weight"] = w2_weight_fp4[expert_id] + weights[f"{expert_id}.w3.weight"] = w3_weight_fp4[expert_id] + weights[f"{expert_id}.w1.weight_scale"] = w1_weight_scale[ + expert_id] + weights[f"{expert_id}.w2.weight_scale"] = w2_weight_scale[ + expert_id] + weights[f"{expert_id}.w3.weight_scale"] = w3_weight_scale[ + expert_id] + if bias: + weights[f"{expert_id}.w1.bias"] = w1_bias[expert_id] + weights[f"{expert_id}.w2.bias"] = w2_bias[expert_id] + weights[f"{expert_id}.w3.bias"] = w3_bias[expert_id] + + quant_algo = QuantAlgo.W4A8_MXFP4_FP8 if fp8_activation else QuantAlgo.W4A16_MXFP4 + quant_config = QuantConfig(quant_algo=quant_algo) + fused_moe = TritonFusedMoE(num_experts=NUM_EXPERTS, + routing_method=routing_method, + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + dtype=dtype, + reduce_results=True, + bias=bias, + model_config=ModelConfig( + quant_config=quant_config, + mapping=mapping)) + fused_moe.load_weights([weights]) + fused_moe.cuda() + + with torch.inference_mode(): + output = fused_moe.forward(x, router_logits) + torch.cuda.synchronize() + + # Evaluate outputs + + # There can be one off mismatch in the outputs due to different kernel implementations + # Here we check certain percent of the outputs are within the tolerance + check_accuracy(output, ref_output, rtol=0.6, atol=0.6, percent=0.945) + + +class RefGatedMLPFusedMoE(nn.Module): + + def __init__(self, + num_experts: int, + routing_method: BaseMoeRoutingMethod, + hidden_size: int, + intermediate_size: int, + dtype: Optional[torch.dtype] = None, + model_config: ModelConfig = ModelConfig(), + use_cute_dsl_blockscaling_mm: bool = False, + bias=False, + swiglu_alpha: Optional[float] = None, + swiglu_beta: Optional[float] = None, + swiglu_limit: Optional[float] = None): + super().__init__() + self.num_experts = num_experts + self.routing_method = routing_method + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.bias = bias + + self.dtype = dtype + self.quant_config = model_config.quant_config + + def custom_swiglu(x): + gate, value = x.chunk(2, dim=-1) + if swiglu_limit is not None and swiglu_limit != float("inf"): + gate = gate.clamp(max=swiglu_limit) + value = value.clamp(min=-swiglu_limit, max=swiglu_limit) + + alpha = swiglu_alpha if swiglu_alpha is not None else 1.0 + gate_act = gate * torch.sigmoid(gate * alpha) + + beta = swiglu_beta if swiglu_beta is not None else 0.0 + + return gate_act * (value + beta) + + self.experts = nn.ModuleList([ + GatedMLP( + hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size, + bias=bias, + dtype=self.dtype, + config=model_config, + use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm, + activation=custom_swiglu + if swiglu_alpha is not None else F.silu, + ) for _ in range(self.num_experts) + ]) + + def forward(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor) -> torch.Tensor: + assert hidden_states.shape[-1] == self.hidden_size + hidden_states = hidden_states.view(-1, self.hidden_size) + + selected_experts, routing_weights = self.routing_method.apply( + router_logits) + + final_hidden_states = torch.zeros(hidden_states.shape, + dtype=hidden_states.dtype, + device=hidden_states.device) + + for expert_id in range(self.num_experts): + if not torch.any(selected_experts == expert_id): + continue + batch_idx, nth_expert = torch.where(selected_experts == expert_id) + expert_inputs = hidden_states[batch_idx] + + output = self.experts[expert_id](expert_inputs) + final_hidden_states[batch_idx] += routing_weights[ + batch_idx, nth_expert, None] * output.float() + + final_hidden_states = final_hidden_states.reshape(hidden_states.shape) + return final_hidden_states + + def load_weights(self, weights: List[Dict]): + assert len(weights) == 1 + weights = weights[0] + + for expert in range(self.num_experts): + gate_up_proj_weights = [{}, {}] + down_proj_weights = [{}] + + gate_up_proj_weights[0]['weight'] = weights[f"{expert}.w1.weight"] + gate_up_proj_weights[1]['weight'] = weights[f"{expert}.w3.weight"] + down_proj_weights[0]['weight'] = weights[f"{expert}.w2.weight"] + if self.bias: + gate_up_proj_weights[0]['bias'] = weights[f"{expert}.w1.bias"] + gate_up_proj_weights[1]['bias'] = weights[f"{expert}.w3.bias"] + down_proj_weights[0]['bias'] = weights[f"{expert}.w2.bias"] + + if self.quant_config and self.quant_config.quant_algo == QuantAlgo.FP8: + gate_up_proj_weights[0]['weight_scale'] = weights[ + f"{expert}.w1.weight_scale"] + gate_up_proj_weights[1]['weight_scale'] = weights[ + f"{expert}.w3.weight_scale"] + down_proj_weights[0]['weight_scale'] = weights[ + f"{expert}.w2.weight_scale"] + gate_up_proj_weights[0]['input_scale'] = weights[ + f"{expert}.w1.input_scale"] + gate_up_proj_weights[1]['input_scale'] = weights[ + f"{expert}.w3.input_scale"] + down_proj_weights[0]['input_scale'] = weights[ + f"{expert}.w2.input_scale"] + elif self.quant_config and self.quant_config.quant_algo in ( + QuantAlgo.NVFP4, QuantAlgo.W4A8_NVFP4_FP8): + gate_up_proj_weights[0]['weight_scale'] = weights[ + f"{expert}.w1.weight_scale"] + gate_up_proj_weights[1]['weight_scale'] = weights[ + f"{expert}.w3.weight_scale"] + down_proj_weights[0]['weight_scale'] = weights[ + f"{expert}.w2.weight_scale"] + gate_up_proj_weights[0]['input_scale'] = weights[ + f"{expert}.w1.input_scale"] + gate_up_proj_weights[1]['input_scale'] = weights[ + f"{expert}.w3.input_scale"] + down_proj_weights[0]['input_scale'] = weights[ + f"{expert}.w2.input_scale"] + gate_up_proj_weights[0]['weight_scale_2'] = weights[ + f"{expert}.w1.weight_scale_2"] + gate_up_proj_weights[1]['weight_scale_2'] = weights[ + f"{expert}.w3.weight_scale_2"] + down_proj_weights[0]['weight_scale_2'] = weights[ + f"{expert}.w2.weight_scale_2"] + elif (self.quant_config and self.quant_config.quant_algo + == QuantAlgo.FP8_BLOCK_SCALES): + gate_up_proj_weights[0]["weight_scale"] = weights[ + f"{expert}.w1.weight_scale"] + gate_up_proj_weights[1]["weight_scale"] = weights[ + f"{expert}.w3.weight_scale"] + down_proj_weights[0]["weight_scale"] = weights[ + f"{expert}.w2.weight_scale"] + elif self.quant_config and self.quant_config.quant_algo == QuantAlgo.W4A8_MXFP4_MXFP8: + gate_up_proj_weights[0]['weight_scale'] = weights[ + f"{expert}.w1.weight_scale"] + gate_up_proj_weights[1]['weight_scale'] = weights[ + f"{expert}.w3.weight_scale"] + down_proj_weights[0]['weight_scale'] = weights[ + f"{expert}.w2.weight_scale"] + + self.experts[expert].gate_up_proj.load_weights(gate_up_proj_weights) + self.experts[expert].down_proj.load_weights(down_proj_weights) + From 4e97d6b53a1ab4a94c7c9afd7cbe0fd774a29de4 Mon Sep 17 00:00:00 2001 From: Jiang Shao <91270701+StudyingShao@users.noreply.github.com> Date: Mon, 9 Feb 2026 18:54:24 +0000 Subject: [PATCH 2/6] Check the best tactic of Hopper fp4 x bf16 grouped gemm Signed-off-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com> --- cpp/tensorrt_llm/thop/moeOp.cpp | 11 +++++++ .../_torch/modules/test_fused_moe_jiangs.py | 32 +++++++++++++++++-- 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/cpp/tensorrt_llm/thop/moeOp.cpp b/cpp/tensorrt_llm/thop/moeOp.cpp index 5e744804aae1..ed3725fdfcc7 100644 --- a/cpp/tensorrt_llm/thop/moeOp.cpp +++ b/cpp/tensorrt_llm/thop/moeOp.cpp @@ -665,6 +665,16 @@ class FusedMoeRunner : public torch::CustomClassHolder return (gemm_idx == 1) ? mGemm1Profiles.size() : mGemm2Profiles.size(); } + std::string getTacticDesc(int64_t const gemm_idx, int64_t const tactic_id) + { + std::lock_guard lock(mMutex); + TORCH_CHECK(gemm_idx == 1 || gemm_idx == 2, "gemm_idx must be 1 or 2"); + auto const& profiles = (gemm_idx == 1) ? mGemm1Profiles : mGemm2Profiles; + TORCH_CHECK(tactic_id >= 0 && tactic_id < static_cast(profiles.size()), + "tactic_id out of range: ", tactic_id, " >= ", profiles.size()); + return profiles[tactic_id].toString(); + } + // TODO Update this to be able to tell if we are profiling swiglu bias void runGemmProfile(torch::Tensor const& input, torch::Tensor const& fc1_expert_weights, torch::optional const& fc1_expert_biases, torch::Tensor const& fc2_expert_weights, @@ -1209,6 +1219,7 @@ TORCH_LIBRARY(trtllm, m) .def(torch::init()) .def("run_gemm_profile", &tensorrt_llm::torch_ext::FusedMoeRunner::runGemmProfile) .def("get_tactic_num", &tensorrt_llm::torch_ext::FusedMoeRunner::getTacticNum) + .def("get_tactic_desc", &tensorrt_llm::torch_ext::FusedMoeRunner::getTacticDesc) .def("run_moe", &tensorrt_llm::torch_ext::FusedMoeRunner::runMoe) .def("run_moe_min_latency", &tensorrt_llm::torch_ext::FusedMoeRunner::runMoeMinLantency); } diff --git a/tests/unittest/_torch/modules/test_fused_moe_jiangs.py b/tests/unittest/_torch/modules/test_fused_moe_jiangs.py index 2db56644956d..b94916a493cf 100644 --- a/tests/unittest/_torch/modules/test_fused_moe_jiangs.py +++ b/tests/unittest/_torch/modules/test_fused_moe_jiangs.py @@ -10,6 +10,7 @@ import cloudpickle import pytest import torch +import torch.cuda.nvtx as nvtx import torch.nn as nn import torch.nn.functional as F from _torch.helpers import (calc_woq_tolerence, per_block_cast_to_fp8, @@ -347,12 +348,21 @@ def test_fused_moe_wfp4a16(dtype, hidden_size, moe_backend): mapping.rank = mpi_rank() with torch.device(f'cuda:{mapping.rank}'): - SEQ_LEN = 4 + ###################################################################### + # SEQ_LEN = 4 + # HIDDEN_SIZE = hidden_size + # INTERMEDIATE_SIZE = 640 + # SCALING_GROUP_SIZE = 32 + # NUM_EXPERTS = 4 + # TOP_K = 2 + ###################################################################### + SEQ_LEN = 16 HIDDEN_SIZE = hidden_size INTERMEDIATE_SIZE = 640 SCALING_GROUP_SIZE = 32 - NUM_EXPERTS = 4 - TOP_K = 2 + NUM_EXPERTS = 8 + TOP_K = 8 + ###################################################################### routing_method = RenormalizeMoeRoutingMethod(top_k=TOP_K) torch.manual_seed(0) torch.cuda.manual_seed(0) @@ -465,8 +475,24 @@ def ref(): with torch.inference_mode(): ref_output = ref() + nvtx.range_push("autotune") with torch.inference_mode(), autotune(): fused_moe.forward(x, router_logits) + nvtx.range_pop() + + from tensorrt_llm._torch.custom_ops.torch_custom_ops import MoERunner + # Get the C++ FusedMoeRunner to query tactic descriptions + cpp_runner = next(iter(MoERunner.runner_dict.values())) + + cache = AutoTuner.get().profiling_cache.cache + for key, value in cache.items(): + custom_op, runner_cls, runner_id, shape_profile = key + runner_id_val, tactic, min_time = value + gemm_idx = 1 if "gemm1" in custom_op else 2 + desc = cpp_runner.get_tactic_desc(gemm_idx, tactic) + print(f"Op: {custom_op}, Runner: {runner_cls}, Shape: {shape_profile}") + print(f" -> Best tactic: {tactic}, Time: {min_time:.6f}ms") + print(f" -> {desc}") # Explicitly capture context for kernel testing with AutoTuner.get().capture() as all_tactics, torch.inference_mode(): From cd541bad3cf8df949be16b4b153664624f12be0c Mon Sep 17 00:00:00 2001 From: Jiang Shao <91270701+StudyingShao@users.noreply.github.com> Date: Tue, 24 Feb 2026 10:08:15 +0000 Subject: [PATCH 3/6] MoE CUTLASS backend INT4 x FP8 & MXFP4 x BF16 paths perf optimization Signed-off-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com> --- .../detail/collective/mixed_input_utils.hpp | 184 +++++-- ...a_gmma_rs_warpspecialized_mixed_input_.hpp | 471 ++++++++++++------ .../cutlass_kernels/cutlass_heuristic.cpp | 11 +- .../moe_gemm_tma_ws_mixed_input_launcher.inl | 5 + cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp | 41 ++ .../_torch/modules/fused_moe/quantization.py | 71 ++- 6 files changed, 563 insertions(+), 220 deletions(-) diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/detail/collective/mixed_input_utils.hpp b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/detail/collective/mixed_input_utils.hpp index bc3591ad3aed..204af0a01e9e 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/detail/collective/mixed_input_utils.hpp +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/detail/collective/mixed_input_utils.hpp @@ -31,6 +31,8 @@ using namespace cute; typedef uint32_t __nv_fp4x8_storage_t; typedef uint32_t __nv_bf16x2_storage_t; +typedef uint32_t __nv_int4x8_storage_t; +typedef uint64_t __nv_fp8x8_storage_t; typedef cutlass::uint128_t __nv_bf16x8_storage_t; constexpr int int4_group_size = 128; @@ -50,53 +52,97 @@ inline __device__ unsigned prmt(unsigned hi, unsigned lo, unsigned select_code) return res; } -__device__ __inline__ __nv_fp8x4_storage_t cvt_lut_bf16(unsigned const index) +__constant__ static __nv_fp8x4_storage_t HIGH_E4M3s_LUT_[2] = {0x03020100U, 0x03020100U}; +__constant__ static __nv_fp8x4_storage_t LOW_E4M3s_LUT_[2] = {0xFFFEFC00U, 0xFFFEFC00U}; + +__device__ __inline__ __nv_fp8x4_storage_t cvt_lut_fp4_to_bf16(unsigned const index) { - const __nv_fp8x4_storage_t h4b_lut = 0x03020100U; // 7654 - const __nv_fp8x4_storage_t l4b_lut = 0xFFFEFC00U; // 3210 + + auto lane_id = threadIdx.x & 0x1; + __nv_fp8x4_storage_t h4b_lut = HIGH_E4M3s_LUT_[lane_id]; + __nv_fp8x4_storage_t l4b_lut = LOW_E4M3s_LUT_[lane_id]; __nv_fp8x4_storage_t lut_res = prmt(h4b_lut, l4b_lut, index); return lut_res; } -__device__ __inline__ __nv_bf16x8_storage_t psx_cvt_lut_prmt_fp4x8_to_bf16x8(const __nv_fp4x8_storage_t fp4x8) +__device__ __inline__ __nv_bf16x8_storage_t psx_cvt_lut_prmt_fp4x8_to_bf16x8_interleaved( + const __nv_fp4x8_storage_t fp4x8) { - __nv_bf16x8_storage_t bf16x8_raw = {0, 0}; + __nv_bf16x8_storage_t bf16x8_raw; __nv_bf16x2_storage_t* bf16x2_raw = reinterpret_cast<__nv_bf16x2_storage_t*>(&bf16x8_raw); - unsigned zero_padding = 0x00000000U; + __nv_fp8x4_storage_t h_fp8x4_0to1_bits = (fp4x8 & 0xC0C0C0C0U) >> 6; // 7632 + __nv_fp8x4_storage_t l_fp8x4_0to1_bits = (fp4x8 & 0x0C0C0C0CU) >> 2; // 5410 unsigned h4b_em_fp4x4 = (fp4x8 & 0x77770000U) >> 16U; unsigned l4b_em_fp4x4 = (fp4x8 & 0x00007777U); - __nv_fp8x4_storage_t h4b_2to9_bits = cvt_lut_bf16(h4b_em_fp4x4); // 7654 - __nv_fp8x4_storage_t l4b_2to9_bits = cvt_lut_bf16(l4b_em_fp4x4); // 3210 + __nv_fp8x4_storage_t h4b_2to9_bits = cvt_lut_fp4_to_bf16(h4b_em_fp4x4); // 7564 + __nv_fp8x4_storage_t l4b_2to9_bits = cvt_lut_fp4_to_bf16(l4b_em_fp4x4); // 3120 - bf16x2_raw[0] = prmt(zero_padding, l4b_2to9_bits, 0x1707U) >> 2U; // 1 0 - bf16x2_raw[1] = prmt(zero_padding, l4b_2to9_bits, 0x3727U) >> 2U; // 3 2 - bf16x2_raw[2] = prmt(h4b_2to9_bits, zero_padding, 0x5040U) >> 2U; // 5 4 - bf16x2_raw[3] = prmt(h4b_2to9_bits, zero_padding, 0x7060U) >> 2U; // 7 6 + bf16x2_raw[0] = prmt(l_fp8x4_0to1_bits, l4b_2to9_bits, 0x5240U) << 6U; // 1 0 + bf16x2_raw[1] = prmt(h_fp8x4_0to1_bits, l4b_2to9_bits, 0x5341U) << 6U; // 3 2 - __nv_bf16x2_storage_t bf16x2_0to1_bits; + bf16x2_raw[2] = prmt(l_fp8x4_0to1_bits, h4b_2to9_bits, 0x7260U) << 6U; // 5 4 + bf16x2_raw[3] = prmt(h_fp8x4_0to1_bits, h4b_2to9_bits, 0x7361U) << 6U; // 7 6 - __nv_fp8x4_storage_t h_fp8x2_0to1_bits = (fp4x8 & 0x0000C0C0U); // 3 1 - __nv_fp8x4_storage_t l_fp8x2_0to1_bits = (fp4x8 & 0x00000C0CU) << 4U; // 2 0 + return bf16x8_raw; +} - bf16x2_0to1_bits = prmt(h_fp8x2_0to1_bits, l_fp8x2_0to1_bits, 0x4707U); // 1 0 - bf16x2_raw[0] = bf16x2_raw[0] | bf16x2_0to1_bits; - bf16x2_0to1_bits = prmt(h_fp8x2_0to1_bits, l_fp8x2_0to1_bits, 0x5717U); // 3 2 - bf16x2_raw[1] = bf16x2_raw[1] | bf16x2_0to1_bits; +// [ 0, 1, 2, 3] encoded as FP8 +__constant__ static uint32_t POS_E4M3s_REG1_[2] = {0x44403800, 0x44403800}; +// [ 4, 5, 6, 7] encoded as FP8 +__constant__ static uint32_t POS_E4M3s_REG2_[2] = {0x4E4C4A48, 0x4E4C4A48}; +// [-8, -7, -6, -5] encoded as FP8 +__constant__ static uint32_t NEG_E4M3s_REG1_[2] = {0xCACCCED0, 0xCACCCED0}; +// [-4, -3, -2, -1] encoded as FP8 +__constant__ static uint32_t NEG_E4M3s_REG2_[2] = {0xB8C0C4C8, 0xB8C0C4C8}; - h_fp8x2_0to1_bits = (fp4x8 & 0xC0C00000U); // 7 5 - l_fp8x2_0to1_bits = (fp4x8 & 0x0C0C0000U) << 4U; // 6 4 +__device__ __inline__ __nv_fp8x8_storage_t psx_cvt_lut_prmt_int4x8_to_fp8x8(const __nv_int4x8_storage_t int4x8) +{ + __nv_fp8x8_storage_t fp8x8_raw; + __nv_fp8x4_storage_t* fp8x4_raw = reinterpret_cast<__nv_fp8x4_storage_t*>(&fp8x8_raw); - bf16x2_0to1_bits = prmt(h_fp8x2_0to1_bits, l_fp8x2_0to1_bits, 0x6020U); // 5 4 - bf16x2_raw[2] = bf16x2_raw[2] | bf16x2_0to1_bits; - bf16x2_0to1_bits = prmt(h_fp8x2_0to1_bits, l_fp8x2_0to1_bits, 0x7030U); // 7 6 - bf16x2_raw[3] = bf16x2_raw[3] | bf16x2_0to1_bits; + // View the input as reg + uint32_t reg = reinterpret_cast(int4x8); - return bf16x8_raw; + // Determines if to get from the signed or unsigned candidates + uint32_t sign = (reg & 0x88888888) >> 1; + + // Ignore sign bit when indexing into LUT + uint32_t lut_idx = (reg & 0x77777777); + + // Signed is OR'd with 0x32103210 to find the correct value in the LUT + const uint32_t final_prmt_base = 0x32103210; + + auto lane_id = threadIdx.x & 0x1; + uint32_t POS_E4M3s_REG1 = POS_E4M3s_REG1_[lane_id]; + uint32_t POS_E4M3s_REG2 = POS_E4M3s_REG2_[lane_id]; + uint32_t NEG_E4M3s_REG1 = NEG_E4M3s_REG1_[lane_id]; + uint32_t NEG_E4M3s_REG2 = NEG_E4M3s_REG2_[lane_id]; + + asm volatile( + "{\n" + " .reg .b32 pos_f8s, neg_f8s;\n" + " .reg .b32 lut1, sign1, prmt0, prmt1;\n" + " or.b32 prmt0, %4, %3;\n" + " prmt.b32 pos_f8s, %5, %6, %2;\n" + " prmt.b32 neg_f8s, %7, %8, %2;\n" + " prmt.b32 %0, pos_f8s, neg_f8s, prmt0;\n" + " shr.u32 lut1, %2, 16;\n" + " shr.u32 sign1, %3, 16;\n" + " or.b32 prmt1, %4, sign1;\n" + " prmt.b32 pos_f8s, %5, %6, lut1;\n" + " prmt.b32 neg_f8s, %7, %8, lut1;\n" + " prmt.b32 %1, pos_f8s, neg_f8s, prmt1;\n" + "}\n" + : "=r"(fp8x4_raw[0]), "=r"(fp8x4_raw[1]) + : "r"(lut_idx), "r"(sign), "r"(final_prmt_base), "r"(POS_E4M3s_REG1), "r"(POS_E4M3s_REG2), "r"(NEG_E4M3s_REG1), + "r"(NEG_E4M3s_REG2)); + + return fp8x8_raw; } template @@ -119,6 +165,7 @@ struct MixedGroupedGemmInputUtils static constexpr auto ModeHasScales = Collective::ModeHasScales; static constexpr auto UseScaleLookupTable = Collective::UseScaleLookupTable; static constexpr auto UseFP4ToBF16LookupTable = Collective::UseFP4ToBF16LookupTable; + static constexpr auto UseInt4ToFP8LookupTable = Collective::UseInt4ToFP8LookupTable; public: static constexpr auto elements_per_smem_scale() @@ -205,6 +252,59 @@ struct MixedGroupedGemmInputUtils } } + /// Utilities to copy A from smem to RF + template + CUTLASS_DEVICE static void copy_tensors_A(SmemTiledCopyA const& smem_tiled_copy_A, TensorASmemView const& tCsA, + TensorACopyView& tCrA_copy_view, int k_block, int read_stage) + { + + if (k_block < size<2>(tCsA.shape())) + { + copy(smem_tiled_copy_A, tCsA(_, _, k_block, read_stage), tCrA_copy_view(_, _, k_block)); + } + } + + /// Utilities to copy Scales for A from smem to RF + template + CUTLASS_DEVICE static void copy_tensors_SFA(cute::tuple const& partitioned_mma_extra_info, + cute::tuple const& tiled_copy_and_views, int k_block, int read_stage) + { + + // We are starting a new k-tile so copy the scale + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) + { + // nothing to do + } + else if constexpr (ModeHasScales) + { + auto smem_tiled_copy_S = cute::get<0>(tiled_copy_and_views); + auto tCrS_copy_view = cute::get<1>(tiled_copy_and_views); + auto tCsS = cute::get<0>(partitioned_mma_extra_info); + + copy(smem_tiled_copy_S, tCsS(_, _, k_block, read_stage), tCrS_copy_view(_, _, k_block)); + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) + { + // Nothing extra to do + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) + { + auto tCsZ = cute::get<2>(partitioned_mma_extra_info); + auto tCrZ_copy_view = cute::get<2>(tiled_copy_and_views); + copy(smem_tiled_copy_S, tCsZ(_, _, k_block, read_stage), tCrZ_copy_view(_, _, k_block)); + } + else + { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + else + { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + /// Utilities to copy A and extra inputs from smem to RF template CUTLASS_DEVICE static void copy_tensors_MK(SmemTiledCopyA const& smem_tiled_copy_A, TensorASmemView const& tCsA, @@ -212,7 +312,10 @@ struct MixedGroupedGemmInputUtils cute::tuple const& tiled_copy_and_views, int k_block, int read_stage) { - copy(smem_tiled_copy_A, tCsA(_, _, k_block, read_stage), tCrA_copy_view(_, _, k_block)); + if (k_block < size<2>(tCsA.shape())) + { + copy(smem_tiled_copy_A, tCsA(_, _, k_block, read_stage), tCrA_copy_view(_, _, k_block)); + } if (k_block == 0) { @@ -312,7 +415,6 @@ struct MixedGroupedGemmInputUtils } } - // The core converter uses a lookup table to converts i4 -> 8 bit value. template CUTLASS_DEVICE static void fp4tobf16_lookup_table_convert( // Accept mutable temporaries @@ -330,7 +432,27 @@ struct MixedGroupedGemmInputUtils auto&& src_ = cute::recast<__nv_fp4x8_storage_t>(src)(0); auto&& dst_ = cute::recast<__nv_bf16x8_storage_t>(dst)(0); - dst_ = psx_cvt_lut_prmt_fp4x8_to_bf16x8(src_); + dst_ = psx_cvt_lut_prmt_fp4x8_to_bf16x8_interleaved(src_); + } + + template + CUTLASS_DEVICE static void int4tofp8_lookup_table_convert( // Accept mutable temporaries + Tensor const& src, Tensor&& dst) + { + int4tofp8_lookup_table_convert(src, dst); + } + + template + CUTLASS_DEVICE static void int4tofp8_lookup_table_convert( + Tensor const& src, Tensor& dst) + { + + // View the input as reg + auto&& src_ = cute::recast<__nv_int4x8_storage_t>(src)(0); + auto&& dst_ = cute::recast<__nv_fp8x8_storage_t>(dst)(0); + + dst_ = psx_cvt_lut_prmt_int4x8_to_fp8x8(src_); } /// Utilities to dequantize A. @@ -535,6 +657,10 @@ struct MixedGroupedGemmInputUtils { fp4tobf16_lookup_table_convert(src_vm(_, i), dst_vm(_, i)); } + else if constexpr (UseInt4ToFP8LookupTable) + { + int4tofp8_lookup_table_convert(src_vm(_, i), dst_vm(_, i)); + } else { LayoutAwareConvert(src_vm(_, i), dst_vm(_, i)); diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp index 0ce601d5b086..c8f456b9ea1b 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp @@ -242,6 +242,8 @@ struct CollectiveMmaArrayMixedInput< } } + bool TensormapUpdateShapesStridesForAandScale = true; + public: static constexpr ConversionMode KernelConversionMode = get_conversion_mode(); static constexpr bool ModeHasScales = KernelConversionMode == ConversionMode::ConvertAndScale @@ -250,6 +252,8 @@ struct CollectiveMmaArrayMixedInput< = KernelConversionMode == ConversionMode::ConvertAndScale && cutlass::detail::is_Array_v; static constexpr bool UseFP4ToBF16LookupTable = KernelConversionMode == ConversionMode::ConvertAndScale && cute::is_same_v && cute::is_same_v; + static constexpr bool UseInt4ToFP8LookupTable = KernelConversionMode == ConversionMode::ConvertAndScale + && cute::is_same_v && cute::is_same_v; static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{}); static constexpr size_t SmemAlignmentScale = cute::max(SmemAlignmentA, SmemAlignmentB); @@ -899,6 +903,24 @@ struct CollectiveMmaArrayMixedInput< } ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + template + CUTLASS_DEVICE float scale_convertor(T scale) + { + if constexpr (cute::is_same_v) + { + + cutlass::float_ue8m0_t scale_ue8m0 = scale; + + uint32_t temp = 0; + temp = (temp | *reinterpret_cast(&scale_ue8m0)) << 23; + return *reinterpret_cast(&temp); + } + else + { + return static_cast(scale); + } + } + /// Perform a collective-scoped matrix multiply-accumulate /// Consumer Perspective template @@ -986,6 +1008,28 @@ struct CollectiveMmaArrayMixedInput< CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + + using SmemCopyAtomA_LDSM = Copy_Atom; + + auto smem_tiled_copy_A_LDSM = make_tiled_copy_A(SmemCopyAtomA_LDSM{}, tiled_mma); + auto smem_thr_copy_A_LDSM = smem_tiled_copy_A_LDSM.get_thread_slice(thread_idx); + + Tensor sA_LDSM = recast(sA); + auto tCsA_LDSM = smem_thr_copy_A_LDSM.partition_S(sA_LDSM); + + using ABBitWidthRatio = Int / sizeof_bits_v>; + auto tCrA_load_LDSM_shape = replace<2>(tCrA_mma.shape(), size(get<2>(tCrA_mma.shape())) / ABBitWidthRatio{}); + Tensor tCrA_load_LDSM = make_fragment_like(tCrA_load_LDSM_shape); + Tensor tCrA_copy_view_LDSM = smem_thr_copy_A_LDSM.retile_D(tCrA_load_LDSM); // (CPY,CPY_M,CPY_K) + + auto ptr = recast_ptr(tCrA_load_LDSM.data()); + auto old_shape = tCrA_load_LDSM.shape(); + auto new_shape = make_shape(size<0>(old_shape), get<1>(old_shape), size<2>(old_shape) * ABBitWidthRatio{}); + Tensor tCrA_load_4b_packed = make_tensor(ptr, make_layout(new_shape)); + + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // // PIPELINED MAIN LOOP // @@ -1014,18 +1058,14 @@ struct CollectiveMmaArrayMixedInput< ++smem_pipe_read; barrier_token = pipeline.consumer_try_wait(smem_pipe_read); - // copy smem->rmem for A operand - - Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, - copy_partitions_extra_info, 0, read_stage); + Utils::copy_tensors_A(smem_tiled_copy_A_LDSM, tCsA_LDSM, tCrA_copy_view_LDSM, 0, read_stage); if (K_BLOCK_MAX > 1) { - Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, - copy_partitions_extra_info, 1, read_stage); + Utils::copy_tensors_A(smem_tiled_copy_A_LDSM, tCsA_LDSM, tCrA_copy_view_LDSM, 1, read_stage); } // src: tCrA_load, dst: tCrA_mma - Utils::convert_A_kblock(tCrA_load, tCrA_mma, 0); + Utils::convert_A_kblock(tCrA_load_4b_packed, tCrA_mma, 0); // Unroll the K mode manually to set scale D to 1 CUTLASS_PRAGMA_UNROLL @@ -1045,51 +1085,56 @@ struct CollectiveMmaArrayMixedInput< intermediate_array[chunk_id]); tiled_mma.accumulate_ = GMMA::ScaleOut::One; - warpgroup_commit_batch(); + if (k_block == 0) + { + Utils::copy_tensors_SFA(partitioned_extra_info, copy_partitions_extra_info, 0, read_stage); + } if (k_block < K_BLOCK_MAX - 2) { - Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, - copy_partitions_extra_info, k_block + 2, read_stage); + Utils::copy_tensors_A( + smem_tiled_copy_A_LDSM, tCsA_LDSM, tCrA_copy_view_LDSM, k_block + 2, read_stage); } if (k_block < K_BLOCK_MAX - 1) { - Utils::convert_A_kblock(tCrA_load, tCrA_mma, k_block + 1); + Utils::convert_A_kblock(tCrA_load_4b_packed, tCrA_mma, k_block + 1); } } - } - - warpgroup_wait<0>(); - CUTLASS_PRAGMA_UNROLL - for (int chunk_id_ = 0; chunk_id_ < NumChunksPerTileK; ++chunk_id_) - { - warpgroup_fence_operand(intermediate_array[chunk_id_]); + warpgroup_commit_batch(); - // Apply the group-wise scaling - // tCrS ((4, _2, _2), MMA_M, _1) - // accum ((2, _2, _2), MMA_M, _1) - auto tCrS = cute::get<1>(partitioned_extra_info); - for (int mma_m = 0; mma_m < size<1>(accum); mma_m++) + if (chunk_id > 0) { - for (int m = 0; m < size<0, 1>(accum); m++) + warpgroup_wait<1>(); + + int chunk_id_ = chunk_id - 1; + warpgroup_fence_operand(intermediate_array[chunk_id_]); + + // Apply the group-wise scaling + // tCrS ((4, _2, _2), MMA_M, _1) + // accum ((2, _2, _2), MMA_M, _1) + auto tCrS = cute::get<1>(partitioned_extra_info); + for (int mma_m = 0; mma_m < size<1>(accum); mma_m++) { - for (int n = 0; n < size<0, 2>(accum); n++) + for (int m = 0; m < size<0, 1>(accum); m++) { - for (int e = 0; e < size<0, 0>(accum); e++) + auto scale_coord = make_coord(make_tuple(0, m, 0), mma_m, 0); + for (int n = 0; n < size<0, 2>(accum); n++) { - auto accum_coord = make_coord(make_tuple(e, m, n), mma_m, 0); - auto scale_coord = make_coord(make_tuple(0, m, 0), mma_m, 0); - - if (chunk_id_ == 0) - { - accum(accum_coord) = intermediate_array[chunk_id_](accum_coord) - * static_cast(tCrS(scale_coord)[0]); - } - else + for (int e = 0; e < size<0, 0>(accum); e++) { - accum(accum_coord) = fma(intermediate_array[chunk_id_](accum_coord), - static_cast(tCrS(scale_coord)[chunk_id_]), accum(accum_coord)); + auto accum_coord = make_coord(make_tuple(e, m, n), mma_m, 0); + + if (chunk_id_ == 0) + { + accum(accum_coord) = intermediate_array[chunk_id_](accum_coord) + * scale_convertor(tCrS(scale_coord)[0]); + } + else + { + accum(accum_coord) = fma(intermediate_array[chunk_id_](accum_coord), + scale_convertor(tCrS(scale_coord)[chunk_id_]), accum(accum_coord)); + } } } } @@ -1097,6 +1142,33 @@ struct CollectiveMmaArrayMixedInput< } } + warpgroup_wait<0>(); + + int chunk_id_ = NumChunksPerTileK - 1; + warpgroup_fence_operand(intermediate_array[chunk_id_]); + + // Apply the group-wise scaling + // tCrS ((4, _2, _2), MMA_M, _1) + // accum ((2, _2, _2), MMA_M, _1) + auto tCrS = cute::get<1>(partitioned_extra_info); + for (int mma_m = 0; mma_m < size<1>(accum); mma_m++) + { + for (int m = 0; m < size<0, 1>(accum); m++) + { + auto scale_coord = make_coord(make_tuple(0, m, 0), mma_m, 0); + for (int n = 0; n < size<0, 2>(accum); n++) + { + for (int e = 0; e < size<0, 0>(accum); e++) + { + auto accum_coord = make_coord(make_tuple(e, m, n), mma_m, 0); + + accum(accum_coord) = fma(intermediate_array[chunk_id_](accum_coord), + scale_convertor(tCrS(scale_coord)[chunk_id_]), accum(accum_coord)); + } + } + } + } + --k_tile_count; if (k_tile_count > 0) { @@ -1104,13 +1176,12 @@ struct CollectiveMmaArrayMixedInput< // the first mma. pipeline.consumer_wait(smem_pipe_read, barrier_token); - Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, - copy_partitions_extra_info, 0, smem_pipe_read.index()); + Utils::copy_tensors_A( + smem_tiled_copy_A_LDSM, tCsA_LDSM, tCrA_copy_view_LDSM, 0, smem_pipe_read.index()); + Utils::copy_tensors_A( + smem_tiled_copy_A_LDSM, tCsA_LDSM, tCrA_copy_view_LDSM, 1, smem_pipe_read.index()); - Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, - copy_partitions_extra_info, 1, smem_pipe_read.index()); - - Utils::convert_A_kblock(tCrA_load, tCrA_mma, 0); + Utils::convert_A_kblock(tCrA_load_4b_packed, tCrA_mma, 0); } } @@ -1147,7 +1218,6 @@ struct CollectiveMmaArrayMixedInput< cute::gemm(tiled_mma, tCrA_mma(_, _, k_block), tCrB(_, _, k_block, read_stage), intermediate_array[chunk_id]); tiled_mma.accumulate_ = GMMA::ScaleOut::One; - warpgroup_commit_batch(); if (k_block == K_BLOCK_MAX - 1) { @@ -1159,57 +1229,84 @@ struct CollectiveMmaArrayMixedInput< if (k_block == 0) { barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + Utils::copy_tensors_SFA(partitioned_extra_info, copy_partitions_extra_info, 0, read_stage); } if (k_block == K_BLOCK_MAX - 1) { // The last k_block + pipeline.consumer_wait(smem_pipe_read, barrier_token); + Utils::copy_tensors_A( + smem_tiled_copy_A_LDSM, tCsA_LDSM, tCrA_copy_view_LDSM, 0, smem_pipe_read.index()); + Utils::copy_tensors_A( + smem_tiled_copy_A_LDSM, tCsA_LDSM, tCrA_copy_view_LDSM, 1, smem_pipe_read.index()); + + warpgroup_commit_batch(); warpgroup_wait<0>(); - CUTLASS_PRAGMA_UNROLL - for (int chunk_id_ = 0; chunk_id_ < NumChunksPerTileK; ++chunk_id_) - { - warpgroup_fence_operand(intermediate_array[chunk_id_]); + warpgroup_fence_operand(intermediate_array[chunk_id]); - // Apply the group-wise scaling - auto tCrS = cute::get<1>(partitioned_extra_info); - for (int mma_m = 0; mma_m < size<1>(accum); mma_m++) + // Apply the group-wise scaling + auto tCrS = cute::get<1>(partitioned_extra_info); + for (int mma_m = 0; mma_m < size<1>(accum); mma_m++) + { + for (int m = 0; m < size<0, 1>(accum); m++) { - for (int m = 0; m < size<0, 1>(accum); m++) + auto scale_coord = make_coord(make_tuple(0, m, 0), mma_m, 0); + for (int n = 0; n < size<0, 2>(accum); n++) { - for (int n = 0; n < size<0, 2>(accum); n++) + for (int e = 0; e < size<0, 0>(accum); e++) { - for (int e = 0; e < size<0, 0>(accum); e++) - { - auto accum_coord = make_coord(make_tuple(e, m, n), mma_m, 0); - auto scale_coord = make_coord(make_tuple(0, m, 0), mma_m, 0); - - accum(accum_coord) = fma(intermediate_array[chunk_id_](accum_coord), - static_cast(tCrS(scale_coord)[chunk_id_]), accum(accum_coord)); - } + auto accum_coord = make_coord(make_tuple(e, m, n), mma_m, 0); + + accum(accum_coord) = fma(intermediate_array[chunk_id](accum_coord), + scale_convertor(tCrS(scale_coord)[chunk_id]), accum(accum_coord)); } } } } - pipeline.consumer_wait(smem_pipe_read, barrier_token); - - // copy scales when passing k_block=0 - Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, - copy_partitions_extra_info, 0, smem_pipe_read.index()); - Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, - copy_partitions_extra_info, 1, smem_pipe_read.index()); - Utils::convert_A_kblock(tCrA_load, tCrA_mma, 0); + Utils::convert_A_kblock(tCrA_load_4b_packed, tCrA_mma, 0); } else { if (k_block < K_BLOCK_MAX - 2) { - Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, - copy_partitions_extra_info, k_block + 2, read_stage); + Utils::copy_tensors_A( + smem_tiled_copy_A_LDSM, tCsA_LDSM, tCrA_copy_view_LDSM, k_block + 2, read_stage); + } + Utils::convert_A_kblock(tCrA_load_4b_packed, tCrA_mma, k_block + 1); + } + } + + warpgroup_commit_batch(); + + if (chunk_id > 0) + { + warpgroup_wait<1>(); + + int chunk_id_ = chunk_id - 1; + warpgroup_fence_operand(intermediate_array[chunk_id_]); + + // Apply the group-wise scaling + auto tCrS = cute::get<1>(partitioned_extra_info); + for (int mma_m = 0; mma_m < size<1>(accum); mma_m++) + { + for (int m = 0; m < size<0, 1>(accum); m++) + { + auto scale_coord = make_coord(make_tuple(0, m, 0), mma_m, 0); + for (int n = 0; n < size<0, 2>(accum); n++) + { + for (int e = 0; e < size<0, 0>(accum); e++) + { + auto accum_coord = make_coord(make_tuple(e, m, n), mma_m, 0); + + accum(accum_coord) = fma(intermediate_array[chunk_id_](accum_coord), + scale_convertor(tCrS(scale_coord)[chunk_id_]), accum(accum_coord)); + } + } } - Utils::convert_A_kblock(tCrA_load, tCrA_mma, k_block + 1); } } } @@ -1234,7 +1331,11 @@ struct CollectiveMmaArrayMixedInput< // (V,M) x (V,N) => (V,M,N) cute::gemm(tiled_mma, tCrA_mma(_, _, k_block), tCrB(_, _, k_block, read_stage), intermediate); tiled_mma.accumulate_ = GMMA::ScaleOut::One; - warpgroup_commit_batch(); + + if (k_block == 0) + { + Utils::copy_tensors_SFA(partitioned_extra_info, copy_partitions_extra_info, 0, read_stage); + } if (k_block == K_BLOCK_MAX - 1) { @@ -1245,18 +1346,19 @@ struct CollectiveMmaArrayMixedInput< if (k_block < K_BLOCK_MAX - 2) { - Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, - copy_partitions_extra_info, k_block + 2, read_stage); + Utils::copy_tensors_A( + smem_tiled_copy_A_LDSM, tCsA_LDSM, tCrA_copy_view_LDSM, k_block + 2, read_stage); } if (k_block < K_BLOCK_MAX - 1) { - Utils::convert_A_kblock(tCrA_load, tCrA_mma, k_block + 1); + Utils::convert_A_kblock(tCrA_load_4b_packed, tCrA_mma, k_block + 1); } if ((k_block + 1) % NumMMAsPerChunk == 0) { tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + warpgroup_commit_batch(); warpgroup_wait<0>(); warpgroup_fence_operand(intermediate); @@ -1266,16 +1368,16 @@ struct CollectiveMmaArrayMixedInput< { for (int m = 0; m < size<0, 1>(accum); m++) { + auto scale_coord = make_coord(make_tuple(0, m, 0), mma_m, 0); for (int n = 0; n < size<0, 2>(accum); n++) { for (int e = 0; e < size<0, 0>(accum); e++) { auto accum_coord = make_coord(make_tuple(e, m, n), mma_m, 0); - auto scale_coord = make_coord(make_tuple(0, m, 0), mma_m, 0); int scale_idx = k_block / NumMMAsPerChunk; accum(accum_coord) = fma(intermediate(accum_coord), - static_cast(tCrS(scale_coord)[scale_idx]), accum(accum_coord)); + scale_convertor(tCrS(scale_coord)[scale_idx]), accum(accum_coord)); } } } @@ -1295,9 +1397,6 @@ struct CollectiveMmaArrayMixedInput< smem_pipe_release.advance(k_tile_count); - // Wait on all GMMAs to complete - // warpgroup_wait<0>(); - for (int count = 0; count < prologue_mma_count; ++count) { pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it @@ -1379,29 +1478,53 @@ struct CollectiveMmaArrayMixedInput< } // Replace address for the global tensor (to be done by single thread) - CUTLASS_DEVICE - void tensormaps_replace_global_address( - TensorMapStorage& shared_tensormaps, Params const& mainloop_params, int32_t next_batch) + template + CUTLASS_DEVICE void tensormaps_replace_global_address(TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, cute::tuple const& input_tensormaps, int32_t next_batch) { // Replacing global_address for the next batch - cute::tma_descriptor_replace_addr_in_shared_mem( - shared_tensormaps.smem_tensormap_A, mainloop_params.ptr_A[next_batch]); cute::tma_descriptor_replace_addr_in_shared_mem( shared_tensormaps.smem_tensormap_B, mainloop_params.ptr_B[next_batch]); - if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) - { - cute::tma_descriptor_replace_addr_in_shared_mem( - shared_tensormaps.smem_tensormap_scale, mainloop_params.ptr_S[next_batch]); - } - else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) + + if (TensormapUpdateShapesStridesForAandScale) { cute::tma_descriptor_replace_addr_in_shared_mem( - shared_tensormaps.smem_tensormap_zero, mainloop_params.ptr_Z[next_batch]); + shared_tensormaps.smem_tensormap_A, mainloop_params.ptr_A[next_batch]); + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) + { + cute::tma_descriptor_replace_addr_in_shared_mem( + shared_tensormaps.smem_tensormap_scale, mainloop_params.ptr_S[next_batch]); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) + { + cute::tma_descriptor_replace_addr_in_shared_mem( + shared_tensormaps.smem_tensormap_zero, mainloop_params.ptr_Z[next_batch]); + } + else if constexpr (KernelConversionMode != ConversionMode::DirectConvert) + { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in tensormaps_replace_global_address."); + } } - else if constexpr (KernelConversionMode != ConversionMode::DirectConvert) + else { - static_assert(cutlass::detail::dependent_false, - "Conversion mode not handled in tensormaps_replace_global_address."); + cute::tma_descriptor_replace_addr_in_global_mem( + get<0>(input_tensormaps), mainloop_params.ptr_A[next_batch]); + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) + { + cute::tma_descriptor_replace_addr_in_global_mem( + get<2>(input_tensormaps), mainloop_params.ptr_S[next_batch]); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) + { + cute::tma_descriptor_replace_addr_in_global_mem( + get<3>(input_tensormaps), mainloop_params.ptr_Z[next_batch]); + } + else if constexpr (KernelConversionMode != ConversionMode::DirectConvert) + { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in tensormaps_replace_global_address."); + } } } @@ -1425,80 +1548,82 @@ struct CollectiveMmaArrayMixedInput< cute::array prob_shape_zero = {1, 1, 1, 1, 1}; cute::array prob_stride_zero = {0, 0, 0, 0, 0}; - SwappedElementA const* ptr_A = nullptr; - Tensor tensor_a = make_tensor( - ptr_A, detail::get_gmem_layout(make_shape(M, K, Int<1>{}), mainloop_params.ptr_dA[next_group])); - SwappedElementB const* ptr_B = nullptr; Tensor tensor_b = make_tensor( ptr_B, detail::get_gmem_layout(make_shape(N, K, Int<1>{}), mainloop_params.ptr_dB[next_group])); - - cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_a, tensor_a, prob_shape_A, prob_stride_A); cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_b, tensor_b, prob_shape_B, prob_stride_B); - if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) - { - NonVoidElementScale const* ptr_S = nullptr; - // auto scale_k = K / mainloop_params.chunk_size; - auto scale_k = K / ScalingGroupSize; - Tensor tensor_scale = make_tensor( - detail::get_logical_ptr(ptr_S), make_shape(M, scale_k, Int<1>{}), mainloop_params.dS[next_group]); - cute::detail::fill_tma_gmem_shape_stride( - mainloop_params.tma_load_scale, tensor_scale, prob_shape_scale, prob_stride_scale); - } - else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) - { - ElementZero const* ptr_Z = nullptr; - // auto scale_k = K / mainloop_params.chunk_size; - auto scale_k = K / ScalingGroupSize; - Tensor tensor_zero = make_tensor( - detail::get_logical_ptr(ptr_Z), make_shape(M, scale_k, Int<1>{}), mainloop_params.dS[next_group]); - cute::detail::fill_tma_gmem_shape_stride( - mainloop_params.tma_load_zero, tensor_zero, prob_shape_zero, prob_stride_zero); - } - else if constexpr (KernelConversionMode != ConversionMode::DirectConvert) - { - static_assert(cutlass::detail::dependent_false, - "Conversion mode not handled in tensormaps_replace_global_tensor_properties."); - } - - // Convert strides to byte strides - for (uint64_t& stride : prob_stride_A) - { - stride = (stride * sizeof_bits_v) / 8; - } for (uint64_t& stride : prob_stride_B) { stride = (stride * sizeof_bits_v) / 8; } - for (uint64_t& stride : prob_stride_scale) - { - stride = (stride * sizeof_bits_v) / 8; - } - for (uint64_t& stride : prob_stride_zero) - { - stride = (stride * sizeof_bits_v) / 8; - } - cute::tma_descriptor_replace_dims_strides_in_shared_mem( - shared_tensormaps.smem_tensormap_A, prob_shape_A, prob_stride_A); cute::tma_descriptor_replace_dims_strides_in_shared_mem( shared_tensormaps.smem_tensormap_B, prob_shape_B, prob_stride_B); - if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) - { - cute::tma_descriptor_replace_dims_strides_in_shared_mem( - shared_tensormaps.smem_tensormap_scale, prob_shape_scale, prob_stride_scale); - } - else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) + if (TensormapUpdateShapesStridesForAandScale) { + + SwappedElementA const* ptr_A = nullptr; + Tensor tensor_a = make_tensor( + ptr_A, detail::get_gmem_layout(make_shape(M, K, Int<1>{}), mainloop_params.ptr_dA[next_group])); + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_a, tensor_a, prob_shape_A, prob_stride_A); + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) + { + NonVoidElementScale const* ptr_S = nullptr; + // auto scale_k = K / mainloop_params.chunk_size; + auto scale_k = K / ScalingGroupSize; + Tensor tensor_scale = make_tensor( + detail::get_logical_ptr(ptr_S), make_shape(M, scale_k, Int<1>{}), mainloop_params.dS[next_group]); + cute::detail::fill_tma_gmem_shape_stride( + mainloop_params.tma_load_scale, tensor_scale, prob_shape_scale, prob_stride_scale); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) + { + ElementZero const* ptr_Z = nullptr; + // auto scale_k = K / mainloop_params.chunk_size; + auto scale_k = K / ScalingGroupSize; + Tensor tensor_zero = make_tensor( + detail::get_logical_ptr(ptr_Z), make_shape(M, scale_k, Int<1>{}), mainloop_params.dS[next_group]); + cute::detail::fill_tma_gmem_shape_stride( + mainloop_params.tma_load_zero, tensor_zero, prob_shape_zero, prob_stride_zero); + } + else if constexpr (KernelConversionMode != ConversionMode::DirectConvert) + { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in tensormaps_replace_global_tensor_properties."); + } + + // Convert strides to byte strides + for (uint64_t& stride : prob_stride_A) + { + stride = (stride * sizeof_bits_v) / 8; + } cute::tma_descriptor_replace_dims_strides_in_shared_mem( - shared_tensormaps.smem_tensormap_zero, prob_shape_zero, prob_stride_zero); - } - else if constexpr (KernelConversionMode != ConversionMode::DirectConvert) - { - static_assert(cutlass::detail::dependent_false, - "Conversion mode not handled in tensormaps_replace_global_tensor_properties."); + shared_tensormaps.smem_tensormap_A, prob_shape_A, prob_stride_A); + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) + { + for (uint64_t& stride : prob_stride_scale) + { + stride = (stride * sizeof_bits_v) / 8; + } + cute::tma_descriptor_replace_dims_strides_in_shared_mem( + shared_tensormaps.smem_tensormap_scale, prob_shape_scale, prob_stride_scale); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) + { + for (uint64_t& stride : prob_stride_zero) + { + stride = (stride * sizeof_bits_v) / 8; + } + cute::tma_descriptor_replace_dims_strides_in_shared_mem( + shared_tensormaps.smem_tensormap_zero, prob_shape_zero, prob_stride_zero); + } + else if constexpr (KernelConversionMode != ConversionMode::DirectConvert) + { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in tensormaps_replace_global_tensor_properties."); + } } } @@ -1509,7 +1634,7 @@ struct CollectiveMmaArrayMixedInput< if (cute::elect_one_sync()) { // Replacing global_address for the next batch - tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch); + tensormaps_replace_global_address(shared_tensormaps, mainloop_params, input_tensormaps, next_batch); if constexpr (IsGroupedGemmKernel) { @@ -1524,26 +1649,40 @@ struct CollectiveMmaArrayMixedInput< CUTLASS_DEVICE void tensormaps_cp_fence_release( TensorMapStorage& shared_tensormaps, cute::tuple const& input_tensormaps) { + + // [None][fix] Fix W4A8 MoE kernel issue + // https://github.com/NVIDIA/TensorRT-LLM/pull/7072 if (cute::elect_one_sync()) { cute::tma_desc_commit_group(); cute::tma_desc_wait_group(); } + // Entire warp must do this (i.e. it's aligned) - tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_A); tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormaps.smem_tensormap_B); - if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) - { - tma_descriptor_cp_fence_release(get<2>(input_tensormaps), shared_tensormaps.smem_tensormap_scale); - } - else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) + + if (TensormapUpdateShapesStridesForAandScale) { - tma_descriptor_cp_fence_release(get<3>(input_tensormaps), shared_tensormaps.smem_tensormap_zero); + TensormapUpdateShapesStridesForAandScale = false; + + tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_A); + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) + { + tma_descriptor_cp_fence_release(get<2>(input_tensormaps), shared_tensormaps.smem_tensormap_scale); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) + { + tma_descriptor_cp_fence_release(get<3>(input_tensormaps), shared_tensormaps.smem_tensormap_zero); + } + else if constexpr (KernelConversionMode != ConversionMode::DirectConvert) + { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in tensormaps_cp_fence_release."); + } } - else if constexpr (KernelConversionMode != ConversionMode::DirectConvert) + else { - static_assert(cutlass::detail::dependent_false, - "Conversion mode not handled in tensormaps_cp_fence_release."); + tma_descriptor_fence_release(); } } diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp index 50794769b5ec..e804f170ac6b 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp @@ -305,11 +305,20 @@ std::vector get_candidate_configs_sm90(CutlassGemmConfig::Can if (has_w4afp8) { bool const has_coop_supported = sm90_supports_coop(tile_config); - std::set mainloop_schedules{MainloopScheduleType::PINGPONG}; + std::set mainloop_schedules; if (has_coop_supported) { + // Due to the limitation on the number of registers on SM, + // cooperative scheduler does not support CtaShape128x128x128B. + if (tile_config == CutlassTileConfigSM90::CtaShape128x128x128B) + continue; + mainloop_schedules.insert(MainloopScheduleType::COOPERATIVE); } + else + { + mainloop_schedules.insert(MainloopScheduleType::PINGPONG); + } auto const epilogue_schedule = EpilogueScheduleType::AUTO; for (auto const& mainloop_schedule : mainloop_schedules) { diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl index f37920dcf73c..d411cfd96795 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl @@ -27,6 +27,7 @@ #include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" #include "cutlass/util/packed_stride.hpp" @@ -144,6 +145,7 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput>::RasterOrderOptions; using EpilogueFusionOp = cutlass::epilogue::fusion::ScaledAccPerRowBiasPerColScaleScatter< typename cutlass::layout::LayoutTranspose::type, ElementFinalOutput, ElementAccumulator, ElementBias, @@ -257,6 +259,9 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput(hopper_inputs.int4_groupwise_params.stride_s_a), group_size}, epilogue_args, hw_info}; + arguments.scheduler.max_swizzle_size = 2; + arguments.scheduler.raster_order = RasterOrderOptions::Heuristic; + assert(group_size == int(inputs.groupwise_quant_group_size)); if (workspace_size != nullptr) { diff --git a/cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp b/cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp index 89c3312b9be3..a0cf0ee45500 100644 --- a/cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp +++ b/cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp @@ -17,6 +17,7 @@ #include "tensorrt_llm/common/config.h" #include "tensorrt_llm/common/cudaBf16Wrapper.h" #include "tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.h" +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.h" #include "tensorrt_llm/thop/thUtils.h" #if defined(TORCH_VERSION_MAJOR) \ @@ -401,6 +402,42 @@ Tensor mxfp4_dequantize_unswizzled(Tensor weight, Tensor scale, int64_t group_si return dequant_weight; } +Tensor interleave_4bit_weights_for_Hopper_mixed_gemm(Tensor weight, int64_t weight_quant_type) +{ + // weight_quant_type: + // 0 for int4 + // 1 for fp4 + TORCH_CHECK(weight_quant_type == 0 || weight_quant_type == 1, "Invalid weight quant type"); + + // weight (n, k / 2) + int const n = weight.size(0); + int const k = weight.size(1) * 2; + + CHECK_TH_CUDA(weight); + CHECK_CONTIGUOUS(weight); + + TORCH_CHECK(weight.numel() != 0, "weight should not be empty tensor"); + TORCH_CHECK( + weight.dtype() == torch::kInt8 || weight.dtype() == torch::kUInt8, "Weight must be a packed int8/uint8 tensor"); + + Tensor weight_interleaved + = torch::empty({n, k / 2}, torch::dtype(torch::kUInt8).device(torch::kCUDA).requires_grad(false)); + + uint8_t* weight_ptr = get_ptr(weight); + uint8_t* weight_interleaved_ptr = get_ptr(weight_interleaved); + + if (weight_quant_type == 0) + { + interleave_int4_weights_for_Hopper_mixed_gemm(weight_ptr, weight_interleaved_ptr, n, k); + } + else if (weight_quant_type == 1) + { + interleave_fp4_weights_for_Hopper_mixed_gemm(weight_ptr, weight_interleaved_ptr, n, k); + } + + return weight_interleaved; +} + } // namespace torch_ext TRTLLM_NAMESPACE_END @@ -438,3 +475,7 @@ static auto subbyte_transpose static auto mxfp4_dequantize_unswizzled = torch::RegisterOperators( "trtllm::mxfp4_dequantize_unswizzled", &tensorrt_llm::torch_ext::mxfp4_dequantize_unswizzled); + +static auto interleave_4bit_weights_for_Hopper_mixed_gemm + = torch::RegisterOperators("trtllm::interleave_4bit_weights_for_Hopper_mixed_gemm", + &tensorrt_llm::torch_ext::interleave_4bit_weights_for_Hopper_mixed_gemm); diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py index 2e9b43550b6c..aa2c479db360 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py +++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py @@ -1324,18 +1324,24 @@ def load_expert_w3_w1_weight(self, module: torch.nn.Module, torch.float8_e4m3fn, 89).view(dst_w3_w1_weight.shape) # SM90 ModelOpt quantized weights - elif module.sm_version == 90 and module.weight_loading_mode == MoEWeightLoadingMode.VANILLA: - # Original: [(N//2)*I4x2, K] which is two int4 elts in output dim packed into one - # Transpose: [K, (N//2)*I4x2] - transposed = w31_weight_shard.cpu().T.contiguous() - # Unpack: [K, N*I8] - unpacked = unpacker(transposed.view(torch.int8)) - # Transpose: [N, K*I8] - transposed = unpacked.T.contiguous() - # Pack: [N, (K//2)*I4x2] - w31_weight_shard = packer(transposed) - elif module.sm_version == 90 and module.weight_loading_mode == MoEWeightLoadingMode.W4A8_CUSTOM: - pass + elif module.sm_version == 90: + if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA: + # Original: [(N//2)*I4x2, K] which is two int4 elts in output dim packed into one + # Transpose: [K, (N//2)*I4x2] + transposed = w31_weight_shard.cpu().T.contiguous() + # Unpack: [K, N*I8] + unpacked = unpacker(transposed.view(torch.int8)) + # Transpose: [N, K*I8] + transposed = unpacked.T.contiguous() + # Pack: [N, (K//2)*I4x2] + w31_weight_shard = packer(transposed) + elif module.sm_version == 90 and module.weight_loading_mode == MoEWeightLoadingMode.W4A8_CUSTOM: + pass + + if w31_weight_shard.ndim == 2: + w31_weight_shard = w31_weight_shard.cuda() + w31_weight_shard = torch.ops.trtllm.interleave_4bit_weights_for_Hopper_mixed_gemm( + w31_weight_shard, 0) else: raise NotImplementedError( f"Unsupported configuration: SM{module.sm_version} and {module.weight_loading_mode}." @@ -1370,18 +1376,24 @@ def load_expert_w2_weight(self, module: torch.nn.Module, torch.float8_e4m3fn, 89).view(dst_w2_weight.shape) # SM90 ModelOpt quantized weights - elif module.sm_version == 90 and module.weight_loading_mode == MoEWeightLoadingMode.VANILLA: - # Original: [(N//2)*I4x2, K] which is two int4 elts in output dim packed into one - # Transpose: [K, (N//2)*I4x2] - transposed = w2_weight_shard.cpu().T.contiguous() - # Unpack: [K, N*I8] - unpacked = unpacker(transposed.view(torch.int8)) - # Transpose: [N, K*I8] - transposed = unpacked.T.contiguous() - # Pack: [N, (K//2)*I4x2] - w2_weight_shard = packer(transposed) - elif module.sm_version == 90 and module.weight_loading_mode == MoEWeightLoadingMode.W4A8_CUSTOM: - pass + elif module.sm_version == 90: + if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA: + # Original: [(N//2)*I4x2, K] which is two int4 elts in output dim packed into one + # Transpose: [K, (N//2)*I4x2] + transposed = w2_weight_shard.cpu().T.contiguous() + # Unpack: [K, N*I8] + unpacked = unpacker(transposed.view(torch.int8)) + # Transpose: [N, K*I8] + transposed = unpacked.T.contiguous() + # Pack: [N, (K//2)*I4x2] + w2_weight_shard = packer(transposed) + elif module.sm_version == 90 and module.weight_loading_mode == MoEWeightLoadingMode.W4A8_CUSTOM: + pass + + if w2_weight_shard.ndim == 2: + w2_weight_shard = w2_weight_shard.cuda() + w2_weight_shard = torch.ops.trtllm.interleave_4bit_weights_for_Hopper_mixed_gemm( + w2_weight_shard, 0) else: raise NotImplementedError( f"Unsupported configuration: SM{module.sm_version} and {module.weight_loading_mode}." @@ -1704,9 +1716,11 @@ def load_expert_w3_w1_weight(self, module: torch.nn.Module, pad_size_inter = module.intermediate_size_per_partition - w3_weight_shard.shape[ 0] if w3_weight_shard.ndim == 2: + # [intermediate_size, hidden_size] pad_size_hidden = module.hidden_size // 2 - w3_weight_shard.shape[1] pad_shape = (0, pad_size_hidden, 0, pad_size_inter) elif w3_weight_shard.ndim == 1: + # [intermediate_size] pad_shape = (0, pad_size_inter) else: raise NotImplementedError( @@ -1718,6 +1732,10 @@ def load_expert_w3_w1_weight(self, module: torch.nn.Module, w31_weight_shard = torch.cat([w3_weight_shard, w1_weight_shard], dim=0) + if w31_weight_shard.ndim == 2: + w31_weight_shard = torch.ops.trtllm.interleave_4bit_weights_for_Hopper_mixed_gemm( + w31_weight_shard, 1) + dst_w3_w1_weight.copy_(w31_weight_shard.view(dst_w3_w1_weight.dtype), non_blocking=True) @@ -1747,6 +1765,11 @@ def load_expert_w2_weight(self, module: torch.nn.Module, f"Invalid shape of w2_weight_shard {w2_weight_shard.shape}") w2_weight_shard = torch.nn.functional.pad(w2_weight_shard, pad_shape) + + if w2_weight_shard.ndim == 2: + w2_weight_shard = torch.ops.trtllm.interleave_4bit_weights_for_Hopper_mixed_gemm( + w2_weight_shard, 1) + dst_w2_weight.copy_(w2_weight_shard.view(dst_w2_weight.dtype), non_blocking=True) From c7917f177e738516ffe20431301faf6731721c86 Mon Sep 17 00:00:00 2001 From: Jiang Shao <91270701+StudyingShao@users.noreply.github.com> Date: Tue, 24 Feb 2026 10:09:13 +0000 Subject: [PATCH 4/6] Remove some compiling warnings Signed-off-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com> --- .../kernels/cutlass_kernels/moe_gemm/moe_kernels.cu | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu index 2447273752f0..7460e9279b07 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu @@ -2154,8 +2154,8 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, float const quant_scale = fp8_quant ? fp8_quant[act_scale_idx] : 1.f; // Some globals for FP4 - float global_scale_val = fc2_act_global_scale ? fc2_act_global_scale[act_scale_idx] : 1.0f; - int64_t num_tokens_before_expert = (IsNVFP4 || IsMXFP8) ? expert_first_token_offset[expert] : 0; + [[maybe_unused]] float global_scale_val = fc2_act_global_scale ? fc2_act_global_scale[act_scale_idx] : 1.0f; + [[maybe_unused]] int64_t num_tokens_before_expert = (IsNVFP4 || IsMXFP8) ? expert_first_token_offset[expert] : 0; size_t bias_offset = 0; if (bias_ptr) @@ -2399,11 +2399,11 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8 default: TLLM_CHECK_WITH_INFO(false, "Invalid activation type"); return nullptr; } }; - auto NVFP4 = tensorrt_llm::common::ConstExprWrapper{}; - auto MXFPX = tensorrt_llm::common::ConstExprWrapper{}; - auto NONE = tensorrt_llm::common::ConstExprWrapper{}; #ifdef ENABLE_FP4 if constexpr (std::is_same_v) @@ -4613,7 +4613,7 @@ void GemmProfilerBackend::prepareTmaWsInputs(int num_tokens, char* workspace_ptr bool use_w4afp8 = (mDType == nvinfer1::DataType::kFP8 && mWType == nvinfer1::DataType::kINT4); bool use_wfp4a16 = ((mDType == nvinfer1::DataType::kHALF || mDType == nvinfer1::DataType::kBF16) && mWType == nvinfer1::DataType::kUINT8); - bool use_w4_groupwise = use_w4afp8 || use_wfp4a16; + [[maybe_unused]] bool use_w4_groupwise = use_w4afp8 || use_wfp4a16; bool const use_finalize_fusion = fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE; bool const finalize_fusion_not_supported = !mInterface->use_fused_finalize_ || mMinLatencyMode || use_wfp4a16 || mGemmToProfile != GemmToProfile::GEMM_2; From 62d9327d16bb6522afdb7c44e7da0c9e549baf3e Mon Sep 17 00:00:00 2001 From: Jiang Shao <91270701+StudyingShao@users.noreply.github.com> Date: Tue, 24 Feb 2026 13:56:00 +0000 Subject: [PATCH 5/6] Update the copy of test_fused_moe.py Signed-off-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com> --- .../_torch/modules/test_fused_moe_jiangs.py | 592 +++++++++--------- 1 file changed, 287 insertions(+), 305 deletions(-) diff --git a/tests/unittest/_torch/modules/test_fused_moe_jiangs.py b/tests/unittest/_torch/modules/test_fused_moe_jiangs.py index b94916a493cf..383f8da0e585 100644 --- a/tests/unittest/_torch/modules/test_fused_moe_jiangs.py +++ b/tests/unittest/_torch/modules/test_fused_moe_jiangs.py @@ -1,10 +1,6 @@ -import os import pickle import sys -from contextlib import contextmanager -from itertools import product from typing import Dict, List, Optional -from unittest import mock import _torch.helpers import cloudpickle @@ -13,36 +9,35 @@ import torch.cuda.nvtx as nvtx import torch.nn as nn import torch.nn.functional as F -from _torch.helpers import (calc_woq_tolerence, per_block_cast_to_fp8, - per_block_cast_to_fp8_e8m0, - per_token_cast_to_fp8_e8m0) from mpi4py import MPI -from mpi4py.futures import MPIPoolExecutor from transformers.configuration_utils import PretrainedConfig -from utils.util import (check_accuracy, skip_blackwell, skip_blackwell_geforce, - skip_neither_ada_nor_hopper_unittest, skip_no_hopper, - skip_pre_blackwell, skip_pre_hopper) +from utils.util import ( + check_accuracy, + skip_blackwell, + skip_blackwell_geforce, + skip_neither_ada_nor_hopper_unittest, + skip_no_hopper, + skip_pre_blackwell, + skip_pre_hopper, +) from tensorrt_llm._torch.autotuner import AutoTuner, autotune from tensorrt_llm._torch.model_config import ModelConfig -from tensorrt_llm._torch.modules.fused_moe.fused_moe_cute_dsl import \ - CuteDslFusedMoE -from tensorrt_llm._torch.modules.fused_moe.fused_moe_deepgemm import \ - DeepGemmFusedMoE -from tensorrt_llm._torch.modules.fused_moe.interface import ( - AlltoallMethodType, MoEWeightLoadingMode) +from tensorrt_llm._torch.modules.fused_moe.interface import MoEWeightLoadingMode # isort and yapf will fight against each other here, so we disable isort # isort: off from tensorrt_llm._torch.modules.fused_moe import ( - BaseMoeRoutingMethod, CutlassFusedMoE, TRTLLMGenFusedMoE, - DefaultMoeRoutingMethod, RenormalizeMoeRoutingMethod, TritonFusedMoE, - create_moe, WideEPMoE) -from tensorrt_llm._torch.modules.fused_moe.quantization import \ - NVFP4CutlassFusedMoEMethod + BaseMoeRoutingMethod, + CutlassFusedMoE, + RenormalizeMoeRoutingMethod, + TritonFusedMoE, + create_moe, +) + # isort: on from tensorrt_llm._torch.modules.gated_mlp import GatedMLP -from tensorrt_llm._utils import get_sm_version, mpi_rank +from tensorrt_llm._utils import mpi_rank from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig @@ -58,13 +53,13 @@ @skip_neither_ada_nor_hopper_unittest @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize( - "weight_loading_mode", - [MoEWeightLoadingMode.VANILLA, MoEWeightLoadingMode.W4A8_CUSTOM]) + "weight_loading_mode", [MoEWeightLoadingMode.VANILLA, MoEWeightLoadingMode.W4A8_CUSTOM] +) def test_fused_moe_w4afp8(dtype, weight_loading_mode): mapping = Mapping() mapping.rank = mpi_rank() - with torch.device(f'cuda:{mapping.rank}'): + with torch.device(f"cuda:{mapping.rank}"): SEQ_LEN = 4 HIDDEN_SIZE = 768 INTERMEDIATE_SIZE = 640 @@ -75,24 +70,20 @@ def test_fused_moe_w4afp8(dtype, weight_loading_mode): torch.manual_seed(0) torch.cuda.manual_seed(0) x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype, device="cuda") - router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), - dtype=dtype, - device="cuda") + router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), dtype=dtype, device="cuda") affine_coeff = 0.005 lut = { - "weight": - "weight", - "weight_scale": - ("weight_scale_inv" if weight_loading_mode - == MoEWeightLoadingMode.W4A8_CUSTOM else "weight_scale"), - "weight_scale_2": - "weight_scale_2", - "pre_quant_scale": - "pre_quant_scale", - "input_scale": - "input_scale", + "weight": "weight", + "weight_scale": ( + "weight_scale_inv" + if weight_loading_mode == MoEWeightLoadingMode.W4A8_CUSTOM + else "weight_scale" + ), + "weight_scale_2": "weight_scale_2", + "pre_quant_scale": "pre_quant_scale", + "input_scale": "input_scale", } weights = {} @@ -110,47 +101,52 @@ def test_fused_moe_w4afp8(dtype, weight_loading_mode): w3_shape = (INTERMEDIATE_SIZE, HIDDEN_SIZE // 2) # The weights in int4 precision. - w1_weight = torch.randint(-128, 127, w1_shape, - dtype=torch.int8).cuda() - w2_weight = torch.randint(-128, 127, w2_shape, - dtype=torch.int8).cuda() - w3_weight = torch.randint(-128, 127, w3_shape, - dtype=torch.int8).cuda() + w1_weight = torch.randint(-128, 127, w1_shape, dtype=torch.int8).cuda() + w2_weight = torch.randint(-128, 127, w2_shape, dtype=torch.int8).cuda() + w3_weight = torch.randint(-128, 127, w3_shape, dtype=torch.int8).cuda() # The pre-quant scale to be multiplied with the input activation. # Use random pre-quant scales [0.95, 1.05] instead of fixed 1.0 to ensure the kernel handles # non-uniform pre-quant scaling factors correctly - w1_pre_quant_scale = torch.rand( - HIDDEN_SIZE, dtype=dtype, device="cuda") * 0.1 + 0.95 - w2_pre_quant_scale = torch.rand( - INTERMEDIATE_SIZE, dtype=dtype, device="cuda") * 0.1 + 0.95 - w3_pre_quant_scale = torch.rand( - HIDDEN_SIZE, dtype=dtype, device="cuda") * 0.1 + 0.95 + w1_pre_quant_scale = torch.rand(HIDDEN_SIZE, dtype=dtype, device="cuda") * 0.1 + 0.95 + w2_pre_quant_scale = ( + torch.rand(INTERMEDIATE_SIZE, dtype=dtype, device="cuda") * 0.1 + 0.95 + ) + w3_pre_quant_scale = torch.rand(HIDDEN_SIZE, dtype=dtype, device="cuda") * 0.1 + 0.95 # The weight scale to dequantize int4 weights (by multiplication). - w1_scale = torch.randn( - (INTERMEDIATE_SIZE, HIDDEN_SIZE // SCALING_GROUP_SIZE), - dtype=dtype, - device="cuda") * affine_coeff - w2_scale = torch.randn( - (HIDDEN_SIZE, INTERMEDIATE_SIZE // SCALING_GROUP_SIZE), - dtype=dtype, - device="cuda") * affine_coeff - w3_scale = torch.randn( - (INTERMEDIATE_SIZE, HIDDEN_SIZE // SCALING_GROUP_SIZE), - dtype=dtype, - device="cuda") * affine_coeff + w1_scale = ( + torch.randn( + (INTERMEDIATE_SIZE, HIDDEN_SIZE // SCALING_GROUP_SIZE), + dtype=dtype, + device="cuda", + ) + * affine_coeff + ) + w2_scale = ( + torch.randn( + (HIDDEN_SIZE, INTERMEDIATE_SIZE // SCALING_GROUP_SIZE), + dtype=dtype, + device="cuda", + ) + * affine_coeff + ) + w3_scale = ( + torch.randn( + (INTERMEDIATE_SIZE, HIDDEN_SIZE // SCALING_GROUP_SIZE), + dtype=dtype, + device="cuda", + ) + * affine_coeff + ) # The input scale to quantize the input activation (by division). - w1_input_scale = torch.randn(1, dtype=torch.float32, - device="cuda") * 0.2 + w1_input_scale = torch.randn(1, dtype=torch.float32, device="cuda") * 0.2 w2_input_scale = w1_input_scale w3_input_scale = w1_input_scale # The weight scale 2 to quantize the dequantized weights (by division). - w1_weight_scale_2 = torch.ones([1], - dtype=torch.float32, - device="cuda") + w1_weight_scale_2 = torch.ones([1], dtype=torch.float32, device="cuda") w2_weight_scale_2 = w1_weight_scale_2 w3_weight_scale_2 = w1_weight_scale_2 @@ -164,18 +160,12 @@ def test_fused_moe_w4afp8(dtype, weight_loading_mode): weights[f"{expert_id}.w1.{lut['weight_scale']}"] = w1_scale weights[f"{expert_id}.w2.{lut['weight_scale']}"] = w2_scale weights[f"{expert_id}.w3.{lut['weight_scale']}"] = w3_scale - weights[ - f"{expert_id}.w1.{lut['pre_quant_scale']}"] = w1_pre_quant_scale - weights[ - f"{expert_id}.w2.{lut['pre_quant_scale']}"] = w2_pre_quant_scale - weights[ - f"{expert_id}.w3.{lut['pre_quant_scale']}"] = w3_pre_quant_scale - weights[ - f"{expert_id}.w1.{lut['weight_scale_2']}"] = w1_weight_scale_2 - weights[ - f"{expert_id}.w2.{lut['weight_scale_2']}"] = w2_weight_scale_2 - weights[ - f"{expert_id}.w3.{lut['weight_scale_2']}"] = w3_weight_scale_2 + weights[f"{expert_id}.w1.{lut['pre_quant_scale']}"] = w1_pre_quant_scale + weights[f"{expert_id}.w2.{lut['pre_quant_scale']}"] = w2_pre_quant_scale + weights[f"{expert_id}.w3.{lut['pre_quant_scale']}"] = w3_pre_quant_scale + weights[f"{expert_id}.w1.{lut['weight_scale_2']}"] = w1_weight_scale_2 + weights[f"{expert_id}.w2.{lut['weight_scale_2']}"] = w2_weight_scale_2 + weights[f"{expert_id}.w3.{lut['weight_scale_2']}"] = w3_weight_scale_2 quant_config = QuantConfig(quant_algo=QuantAlgo.W4A8_AWQ) fused_moe = CutlassFusedMoE( @@ -186,7 +176,8 @@ def test_fused_moe_w4afp8(dtype, weight_loading_mode): dtype=dtype, reduce_results=False, model_config=ModelConfig(quant_config=quant_config), - weight_loading_mode=weight_loading_mode) + weight_loading_mode=weight_loading_mode, + ) fused_moe.load_weights([weights]) fused_moe.cuda() @@ -199,8 +190,7 @@ def ref(): act = x[activated_tokens, :] if act.shape[0] == 0: continue - final_scale = (final_scales * - mask).sum(1)[activated_tokens].unsqueeze(1) + final_scale = (final_scales * mask).sum(1)[activated_tokens].unsqueeze(1) # weights def unpack_weights(weight: torch.Tensor) -> torch.Tensor: @@ -216,12 +206,9 @@ def unpack_weights(weight: torch.Tensor) -> torch.Tensor: w3_w1 = torch.cat([w3, w1], dim=-1) # weight_scale - s1 = weights[f"{e_idx}.w1.{lut['weight_scale']}"].T.contiguous( - ).cuda() - s2 = weights[f"{e_idx}.w2.{lut['weight_scale']}"].T.contiguous( - ).cuda() - s3 = weights[f"{e_idx}.w3.{lut['weight_scale']}"].T.contiguous( - ).cuda() + s1 = weights[f"{e_idx}.w1.{lut['weight_scale']}"].T.contiguous().cuda() + s2 = weights[f"{e_idx}.w2.{lut['weight_scale']}"].T.contiguous().cuda() + s3 = weights[f"{e_idx}.w3.{lut['weight_scale']}"].T.contiguous().cuda() s3_s1 = torch.cat([s3, s1], dim=-1) # input_scale @@ -233,15 +220,9 @@ def unpack_weights(weight: torch.Tensor) -> torch.Tensor: # pre_quant_scale a1 = a2 = a3 = a1_a3 = None if weight_loading_mode == MoEWeightLoadingMode.VANILLA: - a1 = weights[ - f"{e_idx}.w1.{lut['pre_quant_scale']}"].T.contiguous( - ).cuda() - a2 = weights[ - f"{e_idx}.w2.{lut['pre_quant_scale']}"].T.contiguous( - ).cuda() - a3 = weights[ - f"{e_idx}.w3.{lut['pre_quant_scale']}"].T.contiguous( - ).cuda() + a1 = weights[f"{e_idx}.w1.{lut['pre_quant_scale']}"].T.contiguous().cuda() + a2 = weights[f"{e_idx}.w2.{lut['pre_quant_scale']}"].T.contiguous().cuda() + a3 = weights[f"{e_idx}.w3.{lut['pre_quant_scale']}"].T.contiguous().cuda() a1_a3 = torch.max(a1, a3) # weight_scale_2 @@ -263,10 +244,14 @@ def process_layer( ): if pre_quant_scale is not None: act = act * pre_quant_scale - act = (torch.clamp((act / input_scale), -448.0, - 448.0).to(torch.float8_e4m3fn).to(dtype)) - weight = (weight.float() * weight_scale.repeat_interleave( - 128, dim=0).float()).to(dtype) + act = ( + torch.clamp((act / input_scale), -448.0, 448.0) + .to(torch.float8_e4m3fn) + .to(dtype) + ) + weight = ( + weight.float() * weight_scale.repeat_interleave(128, dim=0).float() + ).to(dtype) if weight_scale_2 is not None: weight /= weight_scale_2 output = torch.matmul(act, weight) * input_scale @@ -287,15 +272,9 @@ def process_layer( fc1 = fc1 * torch.nn.functional.silu(gate) # fc2 - fc2 = process_layer(fc1, - w2, - s2, - p2, - pre_quant_scale=a2, - weight_scale_2=q2) - - results[activated_tokens, :] += (fc2 * final_scale).to( - results.dtype) + fc2 = process_layer(fc1, w2, s2, p2, pre_quant_scale=a2, weight_scale_2=q2) + + results[activated_tokens, :] += (fc2 * final_scale).to(results.dtype) return results AutoTuner.get().clear_cache() @@ -316,10 +295,7 @@ def process_layer( # assert that result does not contain NaN or is all 0s assert not torch.isnan(output).any(), "output contains NaN" assert torch.nonzero(output).numel() > 0, "output is empty" - torch.testing.assert_close(output, - ref_output, - rtol=1e-2, - atol=0.1) + torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1) torch.cuda.synchronize() assert not torch.isnan(ref_output).any(), "ref_output contains NaN" @@ -336,18 +312,15 @@ def process_layer( "moe_backend", [ # smVersion - pytest.param("TRTLLM", - marks=[skip_blackwell_geforce, skip_pre_blackwell]), - pytest.param( - "CUTLASS", - marks=[skip_pre_hopper, skip_blackwell, skip_blackwell_geforce]), + pytest.param("TRTLLM", marks=[skip_blackwell_geforce, skip_pre_blackwell]), + pytest.param("CUTLASS", marks=[skip_pre_hopper, skip_blackwell, skip_blackwell_geforce]), ], ) def test_fused_moe_wfp4a16(dtype, hidden_size, moe_backend): mapping = Mapping() mapping.rank = mpi_rank() - with torch.device(f'cuda:{mapping.rank}'): + with torch.device(f"cuda:{mapping.rank}"): ###################################################################### # SEQ_LEN = 4 # HIDDEN_SIZE = hidden_size @@ -358,7 +331,7 @@ def test_fused_moe_wfp4a16(dtype, hidden_size, moe_backend): ###################################################################### SEQ_LEN = 16 HIDDEN_SIZE = hidden_size - INTERMEDIATE_SIZE = 640 + INTERMEDIATE_SIZE = 1024 SCALING_GROUP_SIZE = 32 NUM_EXPERTS = 8 TOP_K = 8 @@ -371,37 +344,37 @@ def test_fused_moe_wfp4a16(dtype, hidden_size, moe_backend): weights = {} for expert_id in range(NUM_EXPERTS): - w1_weight = torch.randint(0, - 256, - (INTERMEDIATE_SIZE, HIDDEN_SIZE // 2), - dtype=torch.uint8, - device='cuda') - w2_weight = torch.randint(0, - 256, - (HIDDEN_SIZE, INTERMEDIATE_SIZE // 2), - dtype=torch.uint8, - device='cuda') - w3_weight = torch.randint(0, - 256, - (INTERMEDIATE_SIZE, HIDDEN_SIZE // 2), - dtype=torch.uint8, - device='cuda') + w1_weight = torch.randint( + 0, 256, (INTERMEDIATE_SIZE, HIDDEN_SIZE // 2), dtype=torch.uint8, device="cuda" + ) + w2_weight = torch.randint( + 0, 256, (HIDDEN_SIZE, INTERMEDIATE_SIZE // 2), dtype=torch.uint8, device="cuda" + ) + w3_weight = torch.randint( + 0, 256, (INTERMEDIATE_SIZE, HIDDEN_SIZE // 2), dtype=torch.uint8, device="cuda" + ) w1_scale = torch.randint( 118, - 123, (INTERMEDIATE_SIZE, HIDDEN_SIZE // SCALING_GROUP_SIZE), + 123, + (INTERMEDIATE_SIZE, HIDDEN_SIZE // SCALING_GROUP_SIZE), dtype=torch.uint8, - device='cuda') + device="cuda", + ) w2_scale = torch.randint( 118, - 123, (HIDDEN_SIZE, INTERMEDIATE_SIZE // SCALING_GROUP_SIZE), + 123, + (HIDDEN_SIZE, INTERMEDIATE_SIZE // SCALING_GROUP_SIZE), dtype=torch.uint8, - device='cuda') + device="cuda", + ) w3_scale = torch.randint( 118, - 123, (INTERMEDIATE_SIZE, HIDDEN_SIZE // SCALING_GROUP_SIZE), + 123, + (INTERMEDIATE_SIZE, HIDDEN_SIZE // SCALING_GROUP_SIZE), dtype=torch.uint8, - device='cuda') + device="cuda", + ) weights[f"{expert_id}.w1.weight"] = w1_weight weights[f"{expert_id}.w2.weight"] = w2_weight @@ -424,12 +397,15 @@ def test_fused_moe_wfp4a16(dtype, hidden_size, moe_backend): pretrained_config.intermediate_size = INTERMEDIATE_SIZE pretrained_config.torch_dtype = dtype - fused_moe = create_moe(routing_method=routing_method, - reduce_results=False, - model_config=ModelConfig( - pretrained_config=pretrained_config, - quant_config=quant_config, - moe_backend=moe_backend)) + fused_moe = create_moe( + routing_method=routing_method, + reduce_results=False, + model_config=ModelConfig( + pretrained_config=pretrained_config, + quant_config=quant_config, + moe_backend=moe_backend, + ), + ) fused_moe.load_weights([weights]) fused_moe.cuda() @@ -443,8 +419,7 @@ def ref(): act = x[activated_tokens, :] if act.shape[0] == 0: continue - final_scale = (final_scales * - mask).sum(1)[activated_tokens].unsqueeze(1) + final_scale = (final_scales * mask).sum(1)[activated_tokens].unsqueeze(1) # weights and scales w1 = weights[f"{e_idx}.w1.weight"] @@ -455,20 +430,28 @@ def ref(): s3 = weights[f"{e_idx}.w3.weight_scale_inv"] # converted weights - w1 = unpacker(w1.cpu(), s1.cpu(), SCALING_GROUP_SIZE).to( - dtype=x.dtype, device=x.device).T.contiguous() - w2 = unpacker(w2.cpu(), s2.cpu(), SCALING_GROUP_SIZE).to( - dtype=x.dtype, device=x.device).T.contiguous() - w3 = unpacker(w3.cpu(), s3.cpu(), SCALING_GROUP_SIZE).to( - dtype=x.dtype, device=x.device).T.contiguous() + w1 = ( + unpacker(w1.cpu(), s1.cpu(), SCALING_GROUP_SIZE) + .to(dtype=x.dtype, device=x.device) + .T.contiguous() + ) + w2 = ( + unpacker(w2.cpu(), s2.cpu(), SCALING_GROUP_SIZE) + .to(dtype=x.dtype, device=x.device) + .T.contiguous() + ) + w3 = ( + unpacker(w3.cpu(), s3.cpu(), SCALING_GROUP_SIZE) + .to(dtype=x.dtype, device=x.device) + .T.contiguous() + ) w3_w1 = torch.cat([w3, w1], dim=-1) fc1 = torch.matmul(act, w3_w1) fc1, gate = fc1.chunk(2, dim=-1) fc1 = fc1 * torch.nn.functional.silu(gate) fc2 = torch.matmul(fc1, w2) - results[activated_tokens, :] += (fc2 * final_scale).to( - results.dtype) + results[activated_tokens, :] += (fc2 * final_scale).to(results.dtype) return results AutoTuner.get().clear_cache() @@ -481,32 +464,58 @@ def ref(): nvtx.range_pop() from tensorrt_llm._torch.custom_ops.torch_custom_ops import MoERunner + # Get the C++ FusedMoeRunner to query tactic descriptions cpp_runner = next(iter(MoERunner.runner_dict.values())) - cache = AutoTuner.get().profiling_cache.cache - for key, value in cache.items(): - custom_op, runner_cls, runner_id, shape_profile = key - runner_id_val, tactic, min_time = value - gemm_idx = 1 if "gemm1" in custom_op else 2 - desc = cpp_runner.get_tactic_desc(gemm_idx, tactic) - print(f"Op: {custom_op}, Runner: {runner_cls}, Shape: {shape_profile}") - print(f" -> Best tactic: {tactic}, Time: {min_time:.6f}ms") - print(f" -> {desc}") + # cache = AutoTuner.get().profiling_cache.cache + # for key, value in cache.items(): + # custom_op, runner_cls, runner_id, shape_profile = key + # runner_id_val, tactic, min_time = value + # gemm_idx = 1 if "gemm1" in custom_op else 2 + # desc = cpp_runner.get_tactic_desc(gemm_idx, tactic) + # print(f"Op: {custom_op}, Runner: {runner_cls}, Shape: {shape_profile}") + # print(f" -> Best tactic: {tactic}, Time: {min_time:.6f}ms") + # print(f" -> {desc}") # Explicitly capture context for kernel testing with AutoTuner.get().capture() as all_tactics, torch.inference_mode(): output = fused_moe.forward(x, router_logits) + # tactics_list = list(all_tactics) + # for FIXED_TACTIC_INDEX in [15, 16, 17, 18, 19, 20]: + # with AutoTuner.get().replay(tactics_list[FIXED_TACTIC_INDEX]), torch.inference_mode(): + # output = fused_moe.forward(x, router_logits) + # i = 0 + # for ctx_idx, (runner, tactic_value) in enumerate(tactics_list[FIXED_TACTIC_INDEX]): + # custom_op = all_tactics._captured_contexts[ctx_idx]['custom_op'] + # gemm_idx = 1 if "gemm1" in custom_op else 2 + # desc = cpp_runner.get_tactic_desc(gemm_idx, tactic_value) + # print(f"{i} Op: {custom_op}, Runner: {type(runner).__name__}, Tactic: {tactic_value}") + # print(f" -> {desc}", end='') + # i+=1 + # print(f"output: {output}") + # print(f"ref_output: {ref_output}") + + # # check_accuracy(output, ref_output, rtol=1e-2, atol=0.1, percent=0.99) + # # print(f" => passed.") + # Test all kernel tactics for tactic in all_tactics: with AutoTuner.get().replay(tactic), torch.inference_mode(): output = fused_moe.forward(x, router_logits) - check_accuracy(output, - ref_output, - rtol=1e-2, - atol=0.1, - percent=0.99) + # i = 0 + # for ctx_idx, (runner, tactic_value) in enumerate(tactic): + # custom_op = all_tactics._captured_contexts[ctx_idx]['custom_op'] + # gemm_idx = 1 if "gemm1" in custom_op else 2 + # desc = cpp_runner.get_tactic_desc(gemm_idx, tactic_value) + # print(f"{i} Op: {custom_op}, Runner: {type(runner).__name__}, Tactic: {tactic_value}") + # print(f" -> {desc}", end='') + # i+=1 + # print(f"output: {output}") + # print(f"ref_output: {ref_output}") + check_accuracy(output, ref_output, rtol=1e-2, atol=0.1, percent=0.99) + # print(f" => passed.") # compare torch.cuda.synchronize() @@ -527,15 +536,16 @@ def ref(): @pytest.mark.parametrize("fp8_activation", [True, False]) @pytest.mark.parametrize("bias", [True, False]) @pytest.mark.parametrize("dynamic_quant", [True, False]) -def test_fused_moe_triton_mxfp4(experts, hidden_size, intermediate_size, - fp8_activation, bias, dynamic_quant): +def test_fused_moe_triton_mxfp4( + experts, hidden_size, intermediate_size, fp8_activation, bias, dynamic_quant +): if fp8_activation: pytest.skip("Latest Triton requires BF16 activation on Hopper") mapping = Mapping() mapping.rank = mpi_rank() - with torch.device(f'cuda:{mapping.rank}'): + with torch.device(f"cuda:{mapping.rank}"): dtype = torch.bfloat16 SEQ_LEN = 8 HIDDEN_SIZE = hidden_size @@ -548,26 +558,21 @@ def test_fused_moe_triton_mxfp4(experts, hidden_size, intermediate_size, x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype).cuda() router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), dtype=dtype).cuda() - w1_weight = torch.randn((NUM_EXPERTS, INTERMEDIATE_SIZE, HIDDEN_SIZE), - dtype=dtype).cuda() - w2_weight = torch.randn((NUM_EXPERTS, HIDDEN_SIZE, INTERMEDIATE_SIZE), - dtype=dtype).cuda() - w3_weight = torch.randn((NUM_EXPERTS, INTERMEDIATE_SIZE, HIDDEN_SIZE), - dtype=dtype).cuda() - w1_bias = torch.randn((NUM_EXPERTS, INTERMEDIATE_SIZE), - dtype=dtype).cuda() + w1_weight = torch.randn((NUM_EXPERTS, INTERMEDIATE_SIZE, HIDDEN_SIZE), dtype=dtype).cuda() + w2_weight = torch.randn((NUM_EXPERTS, HIDDEN_SIZE, INTERMEDIATE_SIZE), dtype=dtype).cuda() + w3_weight = torch.randn((NUM_EXPERTS, INTERMEDIATE_SIZE, HIDDEN_SIZE), dtype=dtype).cuda() + w1_bias = torch.randn((NUM_EXPERTS, INTERMEDIATE_SIZE), dtype=dtype).cuda() w2_bias = torch.randn((NUM_EXPERTS, HIDDEN_SIZE), dtype=dtype).cuda() - w3_bias = torch.randn((NUM_EXPERTS, INTERMEDIATE_SIZE), - dtype=dtype).cuda() + w3_bias = torch.randn((NUM_EXPERTS, INTERMEDIATE_SIZE), dtype=dtype).cuda() from triton_kernels.numerics_details.mxfp import ( - downcast_to_mxfp_torch, upcast_from_mxfp_torch) + downcast_to_mxfp_torch, + upcast_from_mxfp_torch, + ) def fp32_to_mxfp4(tensor): tensor = tensor.transpose(1, 2).contiguous() - tensor_fp4, tensor_scales = downcast_to_mxfp_torch(tensor, - torch.uint8, - axis=1) + tensor_fp4, tensor_scales = downcast_to_mxfp_torch(tensor, torch.uint8, axis=1) tensor_fp4 = tensor_fp4.transpose(1, 2).contiguous() tensor_scales = tensor_scales.transpose(1, 2).contiguous() return tensor_fp4, tensor_scales @@ -575,10 +580,7 @@ def fp32_to_mxfp4(tensor): def mxfp4_to_fp32(tensor, scales): tensor = tensor.transpose(1, 2).contiguous() scales = scales.transpose(1, 2).contiguous() - tensor = upcast_from_mxfp_torch(tensor, - scales, - torch.float32, - axis=1) + tensor = upcast_from_mxfp_torch(tensor, scales, torch.float32, axis=1) return tensor.transpose(1, 2).contiguous() w1_weight_fp4, w1_weight_scale = fp32_to_mxfp4(w1_weight) @@ -599,13 +601,15 @@ def mxfp4_to_fp32(tensor, scales): weights[f"{expert_id}.w2.bias"] = w2_bias[expert_id] weights[f"{expert_id}.w3.bias"] = w3_bias[expert_id] - ref_fused_moe = RefGatedMLPFusedMoE(num_experts=NUM_EXPERTS, - routing_method=routing_method, - hidden_size=HIDDEN_SIZE, - intermediate_size=INTERMEDIATE_SIZE, - dtype=dtype, - model_config=ModelConfig(), - bias=bias) + ref_fused_moe = RefGatedMLPFusedMoE( + num_experts=NUM_EXPERTS, + routing_method=routing_method, + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + dtype=dtype, + model_config=ModelConfig(), + bias=bias, + ) ref_fused_moe.load_weights([weights]) ref_fused_moe.cuda() @@ -625,12 +629,9 @@ def mxfp4_to_fp32(tensor, scales): weights[f"{expert_id}.w1.weight"] = w1_weight_fp4[expert_id] weights[f"{expert_id}.w2.weight"] = w2_weight_fp4[expert_id] weights[f"{expert_id}.w3.weight"] = w3_weight_fp4[expert_id] - weights[f"{expert_id}.w1.weight_scale"] = w1_weight_scale[ - expert_id] - weights[f"{expert_id}.w2.weight_scale"] = w2_weight_scale[ - expert_id] - weights[f"{expert_id}.w3.weight_scale"] = w3_weight_scale[ - expert_id] + weights[f"{expert_id}.w1.weight_scale"] = w1_weight_scale[expert_id] + weights[f"{expert_id}.w2.weight_scale"] = w2_weight_scale[expert_id] + weights[f"{expert_id}.w3.weight_scale"] = w3_weight_scale[expert_id] if bias: weights[f"{expert_id}.w1.bias"] = w1_bias[expert_id] weights[f"{expert_id}.w2.bias"] = w2_bias[expert_id] @@ -638,16 +639,16 @@ def mxfp4_to_fp32(tensor, scales): quant_algo = QuantAlgo.W4A8_MXFP4_FP8 if fp8_activation else QuantAlgo.W4A16_MXFP4 quant_config = QuantConfig(quant_algo=quant_algo) - fused_moe = TritonFusedMoE(num_experts=NUM_EXPERTS, - routing_method=routing_method, - hidden_size=HIDDEN_SIZE, - intermediate_size=INTERMEDIATE_SIZE, - dtype=dtype, - reduce_results=True, - bias=bias, - model_config=ModelConfig( - quant_config=quant_config, - mapping=mapping)) + fused_moe = TritonFusedMoE( + num_experts=NUM_EXPERTS, + routing_method=routing_method, + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + dtype=dtype, + reduce_results=True, + bias=bias, + model_config=ModelConfig(quant_config=quant_config, mapping=mapping), + ) fused_moe.load_weights([weights]) fused_moe.cuda() @@ -663,19 +664,20 @@ def mxfp4_to_fp32(tensor, scales): class RefGatedMLPFusedMoE(nn.Module): - - def __init__(self, - num_experts: int, - routing_method: BaseMoeRoutingMethod, - hidden_size: int, - intermediate_size: int, - dtype: Optional[torch.dtype] = None, - model_config: ModelConfig = ModelConfig(), - use_cute_dsl_blockscaling_mm: bool = False, - bias=False, - swiglu_alpha: Optional[float] = None, - swiglu_beta: Optional[float] = None, - swiglu_limit: Optional[float] = None): + def __init__( + self, + num_experts: int, + routing_method: BaseMoeRoutingMethod, + hidden_size: int, + intermediate_size: int, + dtype: Optional[torch.dtype] = None, + model_config: ModelConfig = ModelConfig(), + use_cute_dsl_blockscaling_mm: bool = False, + bias=False, + swiglu_alpha: Optional[float] = None, + swiglu_beta: Optional[float] = None, + swiglu_limit: Optional[float] = None, + ): super().__init__() self.num_experts = num_experts self.routing_method = routing_method @@ -699,30 +701,30 @@ def custom_swiglu(x): return gate_act * (value + beta) - self.experts = nn.ModuleList([ - GatedMLP( - hidden_size=self.hidden_size, - intermediate_size=self.intermediate_size, - bias=bias, - dtype=self.dtype, - config=model_config, - use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm, - activation=custom_swiglu - if swiglu_alpha is not None else F.silu, - ) for _ in range(self.num_experts) - ]) - - def forward(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor) -> torch.Tensor: + self.experts = nn.ModuleList( + [ + GatedMLP( + hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size, + bias=bias, + dtype=self.dtype, + config=model_config, + use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm, + activation=custom_swiglu if swiglu_alpha is not None else F.silu, + ) + for _ in range(self.num_experts) + ] + ) + + def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor) -> torch.Tensor: assert hidden_states.shape[-1] == self.hidden_size hidden_states = hidden_states.view(-1, self.hidden_size) - selected_experts, routing_weights = self.routing_method.apply( - router_logits) + selected_experts, routing_weights = self.routing_method.apply(router_logits) - final_hidden_states = torch.zeros(hidden_states.shape, - dtype=hidden_states.dtype, - device=hidden_states.device) + final_hidden_states = torch.zeros( + hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device + ) for expert_id in range(self.num_experts): if not torch.any(selected_experts == expert_id): @@ -731,8 +733,9 @@ def forward(self, hidden_states: torch.Tensor, expert_inputs = hidden_states[batch_idx] output = self.experts[expert_id](expert_inputs) - final_hidden_states[batch_idx] += routing_weights[ - batch_idx, nth_expert, None] * output.float() + final_hidden_states[batch_idx] += ( + routing_weights[batch_idx, nth_expert, None] * output.float() + ) final_hidden_states = final_hidden_states.reshape(hidden_states.shape) return final_hidden_states @@ -745,63 +748,42 @@ def load_weights(self, weights: List[Dict]): gate_up_proj_weights = [{}, {}] down_proj_weights = [{}] - gate_up_proj_weights[0]['weight'] = weights[f"{expert}.w1.weight"] - gate_up_proj_weights[1]['weight'] = weights[f"{expert}.w3.weight"] - down_proj_weights[0]['weight'] = weights[f"{expert}.w2.weight"] + gate_up_proj_weights[0]["weight"] = weights[f"{expert}.w1.weight"] + gate_up_proj_weights[1]["weight"] = weights[f"{expert}.w3.weight"] + down_proj_weights[0]["weight"] = weights[f"{expert}.w2.weight"] if self.bias: - gate_up_proj_weights[0]['bias'] = weights[f"{expert}.w1.bias"] - gate_up_proj_weights[1]['bias'] = weights[f"{expert}.w3.bias"] - down_proj_weights[0]['bias'] = weights[f"{expert}.w2.bias"] + gate_up_proj_weights[0]["bias"] = weights[f"{expert}.w1.bias"] + gate_up_proj_weights[1]["bias"] = weights[f"{expert}.w3.bias"] + down_proj_weights[0]["bias"] = weights[f"{expert}.w2.bias"] if self.quant_config and self.quant_config.quant_algo == QuantAlgo.FP8: - gate_up_proj_weights[0]['weight_scale'] = weights[ - f"{expert}.w1.weight_scale"] - gate_up_proj_weights[1]['weight_scale'] = weights[ - f"{expert}.w3.weight_scale"] - down_proj_weights[0]['weight_scale'] = weights[ - f"{expert}.w2.weight_scale"] - gate_up_proj_weights[0]['input_scale'] = weights[ - f"{expert}.w1.input_scale"] - gate_up_proj_weights[1]['input_scale'] = weights[ - f"{expert}.w3.input_scale"] - down_proj_weights[0]['input_scale'] = weights[ - f"{expert}.w2.input_scale"] + gate_up_proj_weights[0]["weight_scale"] = weights[f"{expert}.w1.weight_scale"] + gate_up_proj_weights[1]["weight_scale"] = weights[f"{expert}.w3.weight_scale"] + down_proj_weights[0]["weight_scale"] = weights[f"{expert}.w2.weight_scale"] + gate_up_proj_weights[0]["input_scale"] = weights[f"{expert}.w1.input_scale"] + gate_up_proj_weights[1]["input_scale"] = weights[f"{expert}.w3.input_scale"] + down_proj_weights[0]["input_scale"] = weights[f"{expert}.w2.input_scale"] elif self.quant_config and self.quant_config.quant_algo in ( - QuantAlgo.NVFP4, QuantAlgo.W4A8_NVFP4_FP8): - gate_up_proj_weights[0]['weight_scale'] = weights[ - f"{expert}.w1.weight_scale"] - gate_up_proj_weights[1]['weight_scale'] = weights[ - f"{expert}.w3.weight_scale"] - down_proj_weights[0]['weight_scale'] = weights[ - f"{expert}.w2.weight_scale"] - gate_up_proj_weights[0]['input_scale'] = weights[ - f"{expert}.w1.input_scale"] - gate_up_proj_weights[1]['input_scale'] = weights[ - f"{expert}.w3.input_scale"] - down_proj_weights[0]['input_scale'] = weights[ - f"{expert}.w2.input_scale"] - gate_up_proj_weights[0]['weight_scale_2'] = weights[ - f"{expert}.w1.weight_scale_2"] - gate_up_proj_weights[1]['weight_scale_2'] = weights[ - f"{expert}.w3.weight_scale_2"] - down_proj_weights[0]['weight_scale_2'] = weights[ - f"{expert}.w2.weight_scale_2"] - elif (self.quant_config and self.quant_config.quant_algo - == QuantAlgo.FP8_BLOCK_SCALES): - gate_up_proj_weights[0]["weight_scale"] = weights[ - f"{expert}.w1.weight_scale"] - gate_up_proj_weights[1]["weight_scale"] = weights[ - f"{expert}.w3.weight_scale"] - down_proj_weights[0]["weight_scale"] = weights[ - f"{expert}.w2.weight_scale"] + QuantAlgo.NVFP4, + QuantAlgo.W4A8_NVFP4_FP8, + ): + gate_up_proj_weights[0]["weight_scale"] = weights[f"{expert}.w1.weight_scale"] + gate_up_proj_weights[1]["weight_scale"] = weights[f"{expert}.w3.weight_scale"] + down_proj_weights[0]["weight_scale"] = weights[f"{expert}.w2.weight_scale"] + gate_up_proj_weights[0]["input_scale"] = weights[f"{expert}.w1.input_scale"] + gate_up_proj_weights[1]["input_scale"] = weights[f"{expert}.w3.input_scale"] + down_proj_weights[0]["input_scale"] = weights[f"{expert}.w2.input_scale"] + gate_up_proj_weights[0]["weight_scale_2"] = weights[f"{expert}.w1.weight_scale_2"] + gate_up_proj_weights[1]["weight_scale_2"] = weights[f"{expert}.w3.weight_scale_2"] + down_proj_weights[0]["weight_scale_2"] = weights[f"{expert}.w2.weight_scale_2"] + elif self.quant_config and self.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES: + gate_up_proj_weights[0]["weight_scale"] = weights[f"{expert}.w1.weight_scale"] + gate_up_proj_weights[1]["weight_scale"] = weights[f"{expert}.w3.weight_scale"] + down_proj_weights[0]["weight_scale"] = weights[f"{expert}.w2.weight_scale"] elif self.quant_config and self.quant_config.quant_algo == QuantAlgo.W4A8_MXFP4_MXFP8: - gate_up_proj_weights[0]['weight_scale'] = weights[ - f"{expert}.w1.weight_scale"] - gate_up_proj_weights[1]['weight_scale'] = weights[ - f"{expert}.w3.weight_scale"] - down_proj_weights[0]['weight_scale'] = weights[ - f"{expert}.w2.weight_scale"] + gate_up_proj_weights[0]["weight_scale"] = weights[f"{expert}.w1.weight_scale"] + gate_up_proj_weights[1]["weight_scale"] = weights[f"{expert}.w3.weight_scale"] + down_proj_weights[0]["weight_scale"] = weights[f"{expert}.w2.weight_scale"] self.experts[expert].gate_up_proj.load_weights(gate_up_proj_weights) self.experts[expert].down_proj.load_weights(down_proj_weights) - From 79315f644952582b25a38cbb024fe5769c6851ed Mon Sep 17 00:00:00 2001 From: Jiang Shao <91270701+StudyingShao@users.noreply.github.com> Date: Thu, 26 Feb 2026 10:35:42 +0000 Subject: [PATCH 6/6] Add 4-bit weights interleave functions Signed-off-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com> --- .../moe_gemm/moe_gemm_mixed_utils.cu | 119 ++++++++++++++++++ .../moe_gemm/moe_gemm_mixed_utils.h | 35 ++++++ 2 files changed, 154 insertions(+) create mode 100644 cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.cu create mode 100644 cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.h diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.cu new file mode 100644 index 000000000000..d9be8aa60fd4 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.cu @@ -0,0 +1,119 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "moe_gemm_mixed_utils.h" + +TRTLLM_NAMESPACE_BEGIN + +namespace kernels::cutlass_kernels +{ + +///////////////////////////////////////////////////////////////////////////////////////////////////////// + +__global__ void interleave_fp4_weights_for_Hopper_mixed_gemm_kernel( + uint8_t* fp4_weight, uint8_t* fp4_weight_interleaved, int const rows, int const cols) +{ + for (int block_id = blockIdx.x; block_id < rows / 2; block_id += gridDim.x) + { + for (int partition_id = threadIdx.y; partition_id < cols / 64; partition_id += blockDim.y) + { + int lane_id = threadIdx.x; + int row_id = block_id / 8 * 16 + block_id % 8; + + int mma_id = lane_id / 8; + int dst_row_id = row_id + (mma_id % 2) * 8; + + int interleaved_lane_id = lane_id / 16 * 16 + (lane_id % 4) * 4 + (lane_id % 8) / 4 * 2; + + int col_id = partition_id * 32 + lane_id; + int dst_col_id = partition_id * 32 + interleaved_lane_id; + + int index_a = row_id * cols / 2 + col_id; + int index_b = (row_id + 8) * cols / 2 + col_id; + + uint8_t fp4x2_a = fp4_weight[index_a]; + uint8_t fp4x2_b = fp4_weight[index_b]; + + uint8_t fp4_temp_a = (fp4x2_a & 0xF0U) >> 4; + uint8_t fp4_temp_b = (fp4x2_b & 0x0FU) << 4; + + fp4x2_a = (fp4x2_a & 0x0FU) | fp4_temp_b; + fp4x2_b = (fp4x2_b & 0xF0U) | fp4_temp_a; + + int dst_id = dst_row_id * cols / 2 + dst_col_id; + + fp4_weight_interleaved[dst_id] = fp4x2_a; + fp4_weight_interleaved[dst_id + 1] = fp4x2_b; + } + } +} + +__global__ void interleave_int4_weights_for_Hopper_mixed_gemm_kernel( + uint8_t* int4_weight, uint8_t* int4_weight_interleaved, int const rows, int const cols) +{ + uint16_t* uint16_ptr = reinterpret_cast(int4_weight); + uint16_t* uint16_interleaved_ptr = reinterpret_cast(int4_weight_interleaved); + + for (int block_id = blockIdx.x; block_id < rows / 2; block_id += gridDim.x) + { + for (int partition_id = threadIdx.y; partition_id < cols / 64; partition_id += blockDim.y) + { + int lane_id = threadIdx.x; + + int row_id = block_id / 8 * 16 + block_id % 8; + int dst_row_id = row_id + (lane_id % 8) / 4 * 8; + + int mma_id = lane_id / 8; + int interleaved_lane_id = mma_id * 8 + lane_id % 4 * 2; + + int col_id = partition_id * 16 + lane_id; + int dst_col_id = partition_id * 16 + interleaved_lane_id; + + int src_id_a = row_id * cols / 4 + col_id; + int src_id_b = (row_id + 8) * cols / 4 + col_id; + + uint16_t int4x2_a = uint16_ptr[src_id_a]; + uint16_t int4x2_b = uint16_ptr[src_id_b]; + + int dst_id = dst_row_id * cols / 4 + dst_col_id; + + uint16_interleaved_ptr[dst_id] = int4x2_a; + uint16_interleaved_ptr[dst_id + 1] = int4x2_b; + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////////////// + +void interleave_fp4_weights_for_Hopper_mixed_gemm( + uint8_t* fp4_weight, uint8_t* fp4_weight_interleaved, int const rows, int const cols) +{ + dim3 block(32, 32); + interleave_fp4_weights_for_Hopper_mixed_gemm_kernel<<<1024, block>>>( + fp4_weight, fp4_weight_interleaved, rows, cols); +} + +void interleave_int4_weights_for_Hopper_mixed_gemm( + uint8_t* int4_weight, uint8_t* int4_weight_interleaved, int const rows, int const cols) +{ + dim3 block(16, 32); + interleave_int4_weights_for_Hopper_mixed_gemm_kernel<<<1024, block>>>( + int4_weight, int4_weight_interleaved, rows, cols); +} + +} // namespace kernels::cutlass_kernels + +TRTLLM_NAMESPACE_END diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.h new file mode 100644 index 000000000000..8ed11ba0ac3a --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "tensorrt_llm/common/config.h" +#include + +TRTLLM_NAMESPACE_BEGIN + +namespace kernels::cutlass_kernels +{ + +void interleave_fp4_weights_for_Hopper_mixed_gemm( + uint8_t* weight, uint8_t* weight_interleaved, int const rows, int const cols); + +void interleave_int4_weights_for_Hopper_mixed_gemm( + uint8_t* weight, uint8_t* weight_interleaved, int const rows, int const cols); + +} // namespace kernels::cutlass_kernels + +TRTLLM_NAMESPACE_END