@@ -361,8 +361,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
         num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
         base_activation_type, parallelism_config, min_latency_mode);
 
-    auto const quant_params =
-        getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
+    auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size,
+                                             quant_scales, base_activation_type);
     kernels::MoeMinLatencyParams min_latency_params{};
 
     // TODO: support lora in the future
@@ -542,8 +542,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
         num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
         base_activation_type, parallelism_config, min_latency_mode);
 
-    auto const quant_params =
-        getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
+    auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size,
+                                             quant_scales, base_activation_type);
 
     // TODO: support lora in the future
     ::tensorrt_llm::kernels::LoraParams lora_params{};
@@ -809,9 +809,10 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
     return info;
   }
 
-  kernels::QuantParams getQuantParams(int64_t num_experts_on_rank, int64_t hidden_size,
-                                      int64_t inter_size,
-                                      Optional<Array<Tensor>> quant_scales) const {
+  kernels::QuantParams getQuantParams(
+      int64_t num_experts_on_rank, int64_t hidden_size, int64_t inter_size,
+      Optional<Array<Tensor>> quant_scales,
+      ActivationType base_activation_type = ActivationType::Swiglu) const {
     if (isFp8Quant()) {
       TVM_FFI_ICHECK(quant_scales.has_value()) << "Expecting quant scales for fp8 quantization";
       TVM_FFI_ICHECK_EQ(quant_scales.value().size(), 4)
@@ -1013,18 +1014,34 @@
       // Check shapes
       TVM_FFI_ICHECK(fc1_act_global.ndim() == 0 || fc1_act_global.size(0) == num_experts_on_rank)
           << "fc1 act global must be scalar or (num_experts_on_rank,)";
-      TVM_FFI_ICHECK(
-          fc1_weight_block.size(0) == num_experts_on_rank &&
-          fc1_weight_block.size(1) ==
-              TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
-                  inter_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4) *
-                  2 &&
-          fc1_weight_block.size(2) * FP8_PER_INT32 *
-                  TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize ==
-              TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
-                  hidden_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4))
-          << "fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // 4 "
-             "// block_scale_vector_size)";
+      if (isGatedActivation(base_activation_type)) {
+        TVM_FFI_ICHECK(
+            fc1_weight_block.size(0) == num_experts_on_rank &&
+            fc1_weight_block.size(1) ==
+                TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
+                    inter_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4) *
+                    2 &&
+            fc1_weight_block.size(2) * FP8_PER_INT32 *
+                    TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize ==
+                TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
+                    hidden_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4))
+            << "fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size "
+               "// 4 // block_scale_vector_size)";
+      } else {
+        TVM_FFI_ICHECK(
+            fc1_weight_block.size(0) == num_experts_on_rank &&
+            fc1_weight_block.size(1) ==
+                TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
+                    inter_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4) &&
+            fc1_weight_block.size(2) * FP8_PER_INT32 *
+                    TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize ==
+                TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
+                    hidden_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4))
+            << "fc1 weight block size must be (num_experts_on_rank, inter_size, hidden_size // 4 "
+               "// block_scale_vector_size)";
+      }
 
       TVM_FFI_ICHECK_EQ(fc1_global.size(0), num_experts_on_rank)
           << "fc1 global size must be (num_experts_on_rank,)";
       TVM_FFI_ICHECK(fc2_act_global.ndim() == 0 || fc2_act_global.size(0) == num_experts_on_rank)
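Note: the branch added above encodes the NVFP4 fc1 shape rule: gated activations (Swiglu) fuse the gate and up projections into fc1, so its scale-factor rows cover inter_size * 2, while non-gated activations (Relu2) cover only inter_size. A minimal sketch of the expected fc1 block-scale shape, assuming 16-element NVFP4 scale blocks and the 128-row / 4-column swizzled scale-factor alignment used by the test below (the helper names are illustrative, not part of the codebase):

def round_up(x, m):
    # Round x up to a multiple of m.
    return (x + m - 1) // m * m

def expected_fc1_blockscale_shape(num_experts, inter_size, hidden_size, gated):
    # Illustrative only: 16 mirrors the NVFP4 block-scale vector size; 128 and 4
    # mirror the swizzled scale-factor alignments the test applies via round_up.
    rows = 2 * inter_size if gated else inter_size  # gate+up fused only when gated
    cols = round_up(hidden_size // 16, 4)
    return (num_experts, round_up(rows, 128), cols)

expected_fc1_blockscale_shape(8, 512, 1024, gated=True)   # (8, 1024, 64)
expected_fc1_blockscale_shape(8, 512, 1024, gated=False)  # (8, 512, 64)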
tests/moe/test_trtllm_cutlass_fused_moe.py (57 changes: 40 additions & 17 deletions)
@@ -17,6 +17,7 @@
 from contextlib import nullcontext
 
 import pytest
+from flashinfer.fused_moe.core import ActivationType
 import torch
 from torch.nn import functional as F

@@ -137,7 +138,7 @@ def compute_routing(
     return routing_weights, selected_experts
 
 
-def torch_moe_nvfp4(a, w1, w2, topk, topk_weight, topk_ids):
+def torch_moe_nvfp4(a, w1, w2, topk, topk_weight, topk_ids, activation_type):
     B, D = a.shape
     a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
     out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
@@ -147,13 +148,26 @@ def torch_moe_nvfp4(a, w1, w2, topk, topk_weight, topk_ids):
     topk_ids = topk_ids.view(-1)
     # w1 needs to be swapped in terms of gate and up_proj
 
+    if activation_type == ActivationType.Swiglu:
+
+        def act(weight, mask):
+            m = weight.shape[0]
+            assert m % 2 == 0
+            w1_expert, w3_expert = weight[m // 2 :, :], weight[: m // 2, :]
+            return F.silu(a[mask] @ w1_expert.t()) * (a[mask] @ w3_expert.t())
+
+    elif activation_type == ActivationType.Relu2:
+
+        def act(weight, mask):
+            return F.relu(a[mask] @ weight.t()) ** 2
+
+    else:
+        raise ValueError(f"Unsupported activation type {activation_type}")
+
     for i in range(w1.shape[0]):
         mask = topk_ids == i
         if mask.sum():
-            m = w1[i].shape[0]
-            assert m % 2 == 0
-            w1_expert, w3_expert = w1[i][m // 2 :, :], w1[i][: m // 2, :]
-            inter = F.silu(a[mask] @ w1_expert.t()) * (a[mask] @ w3_expert.t())
+            inter = act(w1[i], mask)
             inter_gs = torch.tensor(1.0).cuda()
             inter_q, inter_blockscale = fp4_quantize(inter, inter_gs)
             inter = dequantize_nvfp4_to_dtype(
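As a quick sanity reference, a toy invocation of the updated helper under the non-gated path. This is a sketch only: it assumes CUDA plus the fp4_quantize / dequantize_nvfp4_to_dtype helpers imported at the top of this file, and the shapes are illustrative:

e, n, k, topk, B = 2, 64, 128, 1, 4
x = torch.randn(B, k, device="cuda", dtype=torch.bfloat16) / 10
w1 = torch.randn(e, n, k, device="cuda", dtype=torch.bfloat16) / 10  # Relu2: n rows, no gate/up fusion
w2 = torch.randn(e, k, n, device="cuda", dtype=torch.bfloat16) / 10
weights = torch.ones(B, topk, device="cuda", dtype=torch.bfloat16)
ids = torch.zeros(B, topk, device="cuda", dtype=torch.long)  # route every token to expert 0
ref = torch_moe_nvfp4(x, w1, w2, topk, weights, ids, ActivationType.Relu2)

With ActivationType.Swiglu the same call expects w1 to carry 2 * n rows, matching the fused two-projection layout that act() splits in half.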
@@ -363,6 +377,11 @@ def test_moe_fp8(
     [(torch.float16, torch.float8_e4m3fn), (torch.bfloat16, torch.float8_e4m3fn)],
 )
 @pytest.mark.parametrize("quantized_input", [False, True])
+@pytest.mark.parametrize(
+    "activation_type",
+    [ActivationType.Swiglu, ActivationType.Relu2],
+    ids=["swiglu", "relu2"],
+)
 @pytest.mark.skipif(
     torch.cuda.get_device_capability()[0] not in [10, 11, 12],
     reason="NVFP4 is only supported on SM100, SM110 and SM120",
@@ -376,6 +395,7 @@ def test_moe_nvfp4(
     otype,
     wtype,
     quantized_input,
+    activation_type,
 ):
     # Skip invalid configurations
     if top_k > num_experts:
@@ -391,10 +411,10 @@
     n = intermediate_size
     k = hidden_size
 
-    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=otype) / 10
-    w1_cutlass = torch.cat((w1[:, n:, :], w1[:, :n, :]), dim=1).contiguous()
+    w1_n = 2 * n if activation_type == ActivationType.Swiglu else n
+    w1 = torch.randn((e, w1_n, k), device="cuda", dtype=otype) / 10
 
-    sf_w1_2n = round_up(2 * n, 128)
+    sf_w1_2n = round_up(w1_n, 128)
     sf_w1_k = round_up(k // quant_blocksize, 4)
     w1_blockscale = torch.empty(
         (e, sf_w1_2n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn
@@ -409,8 +429,8 @@
     w2_blockscale = torch.empty(
         (e, sf_w2_k, sf_w2_n), device="cuda", dtype=torch.float8_e4m3fn
     )
-    w1_q = torch.empty((e, 2 * n, k // 2), device="cuda", dtype=torch.uint8)
-    w1_q_cutlass = torch.empty((e, 2 * n, k // 2), device="cuda", dtype=torch.uint8)
+    w1_q = torch.empty((e, w1_n, k // 2), device="cuda", dtype=torch.uint8)
+    w1_q_cutlass = torch.empty((e, w1_n, k // 2), device="cuda", dtype=torch.uint8)
     w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8)
     w1_gs = torch.empty((e,), device="cuda", dtype=torch.float32)
     w2_gs = torch.empty((e,), device="cuda", dtype=torch.float32)
@@ -424,7 +444,7 @@
         w1_q[expert], w1_blockscale[expert] = fp4_quantize(w1[expert], w1_gs[expert])
 
         w1_q_cutlass[expert], w1_blockscale_cutlass[expert] = fp4_quantize(
-            w1_cutlass[expert], w1_gs[expert]
+            w1[expert], w1_gs[expert]
         )
 
         w2_q[expert], w2_blockscale[expert] = fp4_quantize(w2[expert], w2_gs[expert])
@@ -469,6 +489,7 @@ def test_moe_nvfp4(
         quant_scales=quant_scales,
         input_sf=input_sf,
         output=flash_output,
+        activation_type=activation_type,
     )
 
     # Ref check
@@ -483,7 +504,7 @@
         block_size=quant_blocksize,
     )
 
-    w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=otype)
+    w1_d = torch.empty((e, w1_n, k), device="cuda", dtype=otype)
     w2_d = torch.empty((e, k, n), device="cuda", dtype=otype)
 
     for idx in range(0, e):
@@ -504,12 +525,14 @@
         block_size=quant_blocksize,
     )
 
-    w1_q_cutlass = torch.cat((w1_q[:, n:, :], w1_q[:, :n, :]), dim=1).contiguous()
-    w1_blockscale_cutlass = torch.cat(
-        (w1_blockscale[:, n:, :], w1_blockscale[:, :n, :]), dim=1
-    ).contiguous()
     ref_output = torch_moe_nvfp4(
-        a_in_dtype, w1_d, w2_d, top_k, routing_weights, selected_experts
+        a_in_dtype,
+        w1_d,
+        w2_d,
+        top_k,
+        routing_weights,
+        selected_experts,
+        activation_type,
     )
     torch.testing.assert_close(ref_output, flash_output, rtol=2e-1, atol=2e-1)
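For context on the torch.cat lines deleted above: the test previously exchanged the two n-row halves of the fused fc1 tensors before handing them to the cutlass path. A self-contained sketch of that now-removed swap, with toy shapes:

import torch

w1 = torch.arange(1 * 6 * 4).reshape(1, 6, 4)  # (e=1, 2n=6, k=4) toy fused fc1
n = 3
w1_swapped = torch.cat((w1[:, n:, :], w1[:, :n, :]), dim=1).contiguous()
assert torch.equal(w1_swapped[:, :n, :], w1[:, n:, :])  # halves exchanged

Dropping the swap means the kernel and the reference torch_moe_nvfp4 now read w1 in the same half ordering, consistent with quantizing w1_q_cutlass directly from w1 earlier in the test.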
