Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 52 additions & 30 deletions python/sglang/srt/layers/quantization/modelopt_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,7 @@ def __init__(self, quant_config: ModelOptFp4Config):
" above."
)
self.enable_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
self._cache_permute_indices = {}

@property
def enable_flashinfer_cutlass_moe(self) -> bool:
Expand Down Expand Up @@ -900,10 +901,15 @@ def prepare_static_weights_for_kernel(
e2m1_and_ufp8sf_scale_to_float,
fp4_quantize,
next_positive_power_of_2,
nvfp4_block_scale_interleave,
reorder_rows_for_gated_act_gemm,
shuffle_matrix_a,
shuffle_matrix_sf_a,
)
from flashinfer.fused_moe.core import (
_maybe_get_cached_w2_permute_indices,
_maybe_get_cached_w3_w1_permute_indices,
)

"""Prepare quantized weights for kernel (done offline with weights)."""
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
Expand All @@ -927,50 +933,66 @@ def prepare_static_weights_for_kernel(
num_experts, hidden_size, intermediate_size // 16
) # fp8 scaling factors

# Reorder rows of W1 and scales for fused gated activation
gemm1_weights_fp4_interleaved = []
gemm1_scales_fp4_interleaved = []
for i in range(num_experts):
gemm1_weights_fp4_interleaved.append(
reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone())
)
gemm1_scales_fp4_interleaved.append(
reorder_rows_for_gated_act_gemm(gemm1_scales_linear_fp4[i].clone())
)

# Stack weights and scales for all experts
gemm1_weights_fp4_interleaved = torch.stack(
gemm1_weights_fp4_interleaved
).reshape(num_experts, 2 * intermediate_size, hidden_size // 2)
gemm1_scales_fp4_interleaved = torch.stack(
gemm1_scales_fp4_interleaved
).reshape(num_experts, 2 * intermediate_size, hidden_size // 16)

# Shuffle weights and scaling factors for transposed mma output
gemm1_weights_fp4_shuffled = []
gemm1_scales_fp4_shuffled = []
gemm2_weights_fp4_shuffled = []
gemm2_scales_fp4_shuffled = []
for i in range(num_experts):
# Calculate the permute indices for the following:
# 1. Reorder rows of W1 and scales for fused gated activation
# 2. Shuffle weights and scaling factors for transposed mma output
# for both w3_w1 and w2 weights and scale factors
permute_indices = _maybe_get_cached_w3_w1_permute_indices(
self._cache_permute_indices,
gemm1_weights_fp4[i].view(torch.uint8),
epilogue_tile_m,
)
gemm1_weights_fp4_shuffled.append(
shuffle_matrix_a(
gemm1_weights_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
)
gemm1_weights_fp4[i]
.view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)]
.contiguous()
)

permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
self._cache_permute_indices,
gemm1_scales_linear_fp4[i].view(torch.uint8),
epilogue_tile_m,
num_elts_per_sf=16,
)
gemm1_scales_fp4_shuffled.append(
shuffle_matrix_sf_a(
gemm1_scales_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
nvfp4_block_scale_interleave(
gemm1_scales_linear_fp4[i]
.view(torch.uint8)[
permute_sf_indices.to(gemm1_scales_linear_fp4.device)
]
.contiguous()
)
)

permute_indices = _maybe_get_cached_w2_permute_indices(
self._cache_permute_indices,
gemm2_weights_fp4[i].view(torch.uint8),
epilogue_tile_m,
)
gemm2_weights_fp4_shuffled.append(
shuffle_matrix_a(
gemm2_weights_fp4[i].view(torch.uint8), epilogue_tile_m
)
gemm2_weights_fp4[i]
.view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)]
.contiguous()
)

permute_sf_indices = _maybe_get_cached_w2_permute_indices(
self._cache_permute_indices,
gemm2_scales_linear_fp4[i].view(torch.uint8),
epilogue_tile_m,
num_elts_per_sf=16,
)
gemm2_scales_fp4_shuffled.append(
shuffle_matrix_sf_a(
gemm2_scales_linear_fp4[i].view(torch.uint8), epilogue_tile_m
nvfp4_block_scale_interleave(
gemm2_scales_linear_fp4[i]
.view(torch.uint8)[
permute_sf_indices.to(gemm2_scales_linear_fp4.device)
]
.contiguous()
)
)

Expand Down
Loading