fix marlin fp4 kernel N-dimension alignment #37296
flutist wants to merge 6 commits into vllm-project:main from fix_marlin_fp4_kernel_dimension_alignment
Conversation
Signed-off-by: xjx <493337577@qq.com>
Code Review
This pull request introduces padding for the N-dimension to meet alignment requirements for the Marlin FP4 kernel, which is a necessary fix. The core logic changes in vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py appear to correctly implement this padding. However, I've identified a critical bug in vllm/model_executor/layers/fused_moe/fused_marlin_moe.py due to an incorrect getattr call that will lead to a runtime error. I've also pointed out a typo in a function name that should be corrected for code clarity.
```diff
-        N = marlin_moe_intermediate_size(w1, w2)
+        N = marlin_moe_intermediate_size(w1, w2, layer)
+        w13_num_shards = 2 if activation.is_gated else 1
+        w13_size_n = getattr(layer, "marlin_moe_w13_size_n", w13_num_shards, *N)
```
This line will cause a TypeError at runtime because it attempts to unpack an integer N with *N. Additionally, layer can be None (e.g., when called from batched_fused_marlin_moe), which would cause an AttributeError on getattr. The logic should safely retrieve marlin_moe_w13_size_n from the layer if it exists, and fall back to the computed value otherwise.
Suggested change:
```diff
-        w13_size_n = getattr(layer, "marlin_moe_w13_size_n", w13_num_shards, *N)
+        if layer and hasattr(layer, "marlin_moe_w13_size_n"):
+            w13_size_n = layer.marlin_moe_w13_size_n
+        else:
+            w13_size_n = w13_num_shards * N
```
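A minimal, self-contained sketch of the pattern the review suggests (the `Layer` class and the numbers here are hypothetical, not vLLM code): `getattr` accepts at most three arguments, and `*N` would attempt to unpack an int, so the original line raises before the attribute lookup even runs. The safe version checks the attribute explicitly and tolerates `layer` being `None`:

```python
class Layer:
    # Hypothetical stand-in for a layer carrying the padded size.
    marlin_moe_w13_size_n = 192


def compute_w13_size_n(layer, w13_num_shards, N):
    # Prefer the padded size stored on the layer when it exists;
    # fall back to the computed value. `layer` may be None (e.g.
    # when called from batched_fused_marlin_moe).
    if layer is not None and hasattr(layer, "marlin_moe_w13_size_n"):
        return layer.marlin_moe_w13_size_n
    return w13_num_shards * N


print(compute_w13_size_n(Layer(), 2, 80))  # stored padded size: 192
print(compute_w13_size_n(None, 2, 80))     # computed fallback: 2 * 80 = 160
```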
```python
    # WEIGHT SCALES
    # Permute scales
    def premute_scales(
```
|
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: xjx <493337577@qq.com>
…alignment' into fix_marlin_fp4_kernel_dimension_alignment
Signed-off-by: xjx <493337577@qq.com>
Purpose
Fix Marlin FP4 kernel N-dimension alignment
When executing
`VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1 vllm serve "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4" --trust_remote_code -tp 2`, the terminal shows an error when tp=2.
The Marlin kernel requires size_n to be divisible by 64. This PR adds N-dimension zero-padding during weight preparation so FP4 models with arbitrary intermediate sizes work correctly.
Changes:
Add MARLIN_TILE_N = 64 constant and _pad_to_marlin_tile() helper in marlin_utils_fp4.py
Pad weight, weight scale, and bias tensors to the next multiple of 64 in prepare_fp4_layer_for_marlin, prepare_nvfp4_moe_layer_for_marlin, and prepare_moe_fp4_layer_for_marlin
Slice output back to original size_n after GEMM in apply_fp4_marlin_linear
Store padded sizes (marlin_moe_w13_size_n, marlin_moe_intermediate_size) on the layer and propagate them via process_weights_after_loading in MarlinExpertsBase
Extend marlin_moe_intermediate_size() and fused_marlin_moe() to accept an optional layer argument for reading the stored sizes
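The padding-and-slice idea in the changes above can be sketched in plain Python (a torch-free illustration; helper names like `next_tile_multiple` are mine, not the PR's actual `_pad_to_marlin_tile`):

```python
MARLIN_TILE_N = 64  # the Marlin kernel requires size_n to be divisible by 64


def next_tile_multiple(size_n: int, tile: int = MARLIN_TILE_N) -> int:
    """Round size_n up to the next multiple of the Marlin tile."""
    return ((size_n + tile - 1) // tile) * tile


def zero_pad_row(row: list, tile: int = MARLIN_TILE_N) -> list:
    """Zero-pad one weight row along N, mirroring the PR's idea."""
    target = next_tile_multiple(len(row), tile)
    return row + [0.0] * (target - len(row))


# size_n = 100 is not tile-aligned: pad to 128 before the kernel runs,
# then slice the GEMM output back to the original 100 columns.
row = [1.0] * 100
padded = zero_pad_row(row)
print(len(padded))          # 128
print(padded[:100] == row)  # True: slicing recovers the original size_n
```

The zero columns contribute nothing to the GEMM result, which is why slicing the output back to the original `size_n` is sufficient.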
Test Result
After the fix, everything works fine.
Essential Elements of an Effective PR Description Checklist
`supported_models.md` and `examples` for a new model.