diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
index 5370b9e28bd2..a6d86e738462 100644
--- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -35,6 +35,10 @@
     marlin_moe_intermediate_size,
     marlin_quant_input,
 )
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
+    MARLIN_TILE_K,
+    MARLIN_TILE_N,
+)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey,
     kFp8Static128BlockSym,
@@ -88,12 +92,16 @@ def _fused_marlin_moe(
     M, K = hidden_states.size()
     N = marlin_moe_intermediate_size(w1, w2)
     w13_num_shards = 2 if activation.is_gated else 1
+    _w13_n = w13_num_shards * N
+    # Round sizes up to Marlin tile multiples; must mirror the weight-load padding.
+    _w13_n_padded = _w13_n + ((-_w13_n) % MARLIN_TILE_N)  # for w13 GEMM size_n
+    _N_padded = N + ((-N) % MARLIN_TILE_K)  # for w2 GEMM size_k
     if workspace is None:
         workspace = marlin_make_workspace_new(hidden_states.device, 4)
 
     if intermediate_cache13 is None:
         intermediate_cache13 = torch.empty(
-            (M * num_topk * max(w13_num_shards * N, K),),
+            (M * num_topk * max(_w13_n_padded, K),),
             device=hidden_states.device,
             dtype=hidden_states.dtype,
         )
@@ -106,7 +114,7 @@ def _fused_marlin_moe(
         )
 
     intermediate_cache1 = _resize_cache(
-        intermediate_cache13, (M * num_topk, w13_num_shards * N)
+        intermediate_cache13, (M * num_topk, _w13_n_padded)
     )
     intermediate_cache3 = _resize_cache(intermediate_cache13, (M * num_topk, K))
 
@@ -143,17 +151,21 @@ def _fused_marlin_moe(
         mul_topk_weights=apply_router_weight_on_input,
         b_q_type=quant_type,
         size_m=M,
-        size_n=w13_num_shards * N,
+        size_n=_w13_n_padded,  # padded to Marlin tile_n boundary
         size_k=K,
         is_k_full=is_k_full,
         use_atomic_add=False,
         use_fp32_reduce=True,
         is_zp_float=False,
     )
+    # The w13 GEMM wrote _w13_n_padded columns; keep only the first _w13_n so the
+    # activation sees the true width. NOTE(review): .contiguous() copies per call.
+    if _w13_n_padded != _w13_n:
+        intermediate_cache1 = intermediate_cache1[:, :_w13_n].contiguous()
     activation_func(
         activation,
         intermediate_cache2,
-        intermediate_cache1.view(-1, w13_num_shards * N),
+        intermediate_cache1.view(-1, _w13_n),
     )
 
     if output is None:
@@ -174,6 +186,13 @@ def _fused_marlin_moe(
             intermediate_cache2, input_dtype
         )
 
+    # Zero-pad the activation output to _N_padded columns so the w2 GEMM's size_k
+    # matches the padded w2 weights; the zero columns add nothing to the output.
+    # NOTE(review): this pad allocates a new tensor on every call; confirm the cost.
+    if _N_padded != N:
+        intermediate_cache2 = torch.nn.functional.pad(
+            intermediate_cache2, (0, _N_padded - N)
+        )
     output = ops.moe_wna16_marlin_gemm(
         intermediate_cache2,
         output,
@@ -196,7 +215,7 @@ def _fused_marlin_moe(
         b_q_type=quant_type,
         size_m=M * num_topk,
         size_n=K,
-        size_k=N,
+        size_k=_N_padded,  # padded to Marlin tile_k boundary
         is_k_full=is_k_full,
         use_atomic_add=False,
         use_fp32_reduce=True,
diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
index b5a557ce999d..e87e43c743f2 100644
--- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
+++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
@@ -3,6 +3,7 @@
 
 
 import torch
+import torch.nn.functional as F
 
 import vllm._custom_ops as ops
 from vllm.logger import init_logger
@@ -20,6 +21,21 @@
 
 logger = init_logger(__name__)
 
+# Marlin tile alignment. NOTE(review): confirm the GEMM accepts K aligned to 16.
+MARLIN_TILE_N = 64  # size_n must be divisible by this
+MARLIN_TILE_K = 16  # size_k must be divisible by this
+
+
+def _pad_to_marlin_tile(size_n: int, size_k: int) -> tuple[int, int, int, int]:
+    """Round size_n / size_k up to multiples of MARLIN_TILE_N / MARLIN_TILE_K.
+
+    Returns (padded_size_n, padded_size_k, pad_n, pad_k); pad_n and pad_k are
+    the amounts added and are zero when a dimension is already aligned.
+    """
+    pad_n = (-size_n) % MARLIN_TILE_N
+    pad_k = (-size_k) % MARLIN_TILE_K
+    return size_n + pad_n, size_k + pad_k, pad_n, pad_k
+
 
 def is_fp8_marlin_supported():
     return current_platform.has_device_capability(75)
@@ -247,12 +263,28 @@ def repack_weight(name: str, weight: torch.Tensor) -> torch.Tensor:
 
         assert weight.shape == (e, size_n, size_k)
 
+        # Zero-pad size_n and size_k up to Marlin tile multiples so gptq_marlin_repack
+        # does not crash when TP sharding produces non-aligned per-rank dimensions:
+        #   tile_n_size = 64  (affects w13 gate+up, e.g. 464 → 512)
+        #   tile_k_size = 16  (affects w2 down-proj, e.g. 232 → 240)
+        _padded_size_n, _padded_size_k, _pad_n, _pad_k = _pad_to_marlin_tile(
+            size_n, size_k
+        )
         for i in range(e):
             qweight = pack_fp8_to_int32(weight[i], size_k_first=False)
+            # Pad K first: qweight is (size_n, size_k // 4), so add _pad_k // 4 cols.
+            if _pad_k > 0:
+                qweight = F.pad(qweight, (0, _pad_k // 4))
             qweight = qweight.T.contiguous()
-
+            # Pad N after the transpose: qweight is now (padded_size_k // 4, size_n).
+            if _pad_n > 0:
+                qweight = F.pad(qweight, (0, _pad_n))
             marlin_qweight = ops.gptq_marlin_repack(
-                b_q_weight=qweight, perm=perm, size_k=size_k, size_n=size_n, num_bits=8
+                b_q_weight=qweight,
+                perm=perm,
+                size_k=_padded_size_k,
+                size_n=_padded_size_n,
+                num_bits=8,
             )
             tensor_list.append(marlin_qweight)
 
@@ -302,9 +334,16 @@ def permute_scales(scales: torch.Tensor, name: str) -> torch.Tensor:
             # size_n may not divisible by block_size[0]
             scales = scales[..., :size_n].contiguous()
 
+        _padded_size_n, _padded_size_k, _pad_n, _ = _pad_to_marlin_tile(size_n, size_k)
         for i in range(e):
+            _s = scales[i]
+            if _pad_n > 0:
+                _s = F.pad(_s, (0, _pad_n))
             marlin_scales = marlin_permute_scales(
-                s=scales[i], size_k=size_k, size_n=size_n, group_size=group_size
+                s=_s,
+                size_k=_padded_size_k,
+                size_n=_padded_size_n,
+                group_size=group_size,
             )
             tensor_list.append(marlin_scales)
 