
fix marlin fp4 kernel N-dimension alignment#37296

Open
flutist wants to merge 6 commits into vllm-project:main from flutist:fix_marlin_fp4_kernel_dimension_alignment

Conversation

Contributor

@flutist flutist commented Mar 17, 2026

Purpose

Fix Marlin FP4 kernel N-dimension alignment.
When running VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1 vllm serve "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4" --trust_remote_code -tp 2,
the server fails at startup with tp=2: the Marlin repack kernel rejects the tensor-parallel shard of the N dimension (RuntimeError: size_n = 5152 is not divisible by tile_n_size = 64).
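The alignment failure can be illustrated with a small sketch. The tile size of 64 comes from the RuntimeError in the log below; the unsharded dimension of 10304 is an assumption inferred from 5152 × tp=2, and `pad_to_tile` is a hypothetical helper for illustration, not vLLM's actual API:

```python
# Sketch of the Marlin N-dimension alignment constraint, assuming the
# repack requirement from the error log: size_n % tile_n_size == 0.

TILE_N_SIZE = 64  # Marlin N-tile width (from the RuntimeError below)

def pad_to_tile(size_n: int, tile: int = TILE_N_SIZE) -> int:
    """Round size_n up to the next multiple of the Marlin tile size
    (hypothetical helper, not vLLM's API)."""
    return ((size_n + tile - 1) // tile) * tile

full_n = 10304          # assumed unsharded N dim (5152 * tp)
tp = 2
shard_n = full_n // tp  # 5152; 5152 % 64 == 32, so repack rejects it
padded_n = pad_to_tile(shard_n)

print(shard_n % TILE_N_SIZE)  # 32 (misaligned shard)
print(padded_n)               # 5184 (next multiple of 64)
```

With tp=1 the full dimension is tile-aligned, but the tensor-parallel split produces a per-rank shard that is not, which is why the crash only appears at tp=2.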

VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1 vllm serve "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4" --trust_remote_code -tp 2

(APIServer pid=461117) INFO 03-17 15:05:45 [utils.py:297] 
(APIServer pid=461117) INFO 03-17 15:05:45 [utils.py:297]        █     █     █▄   ▄█
(APIServer pid=461117) INFO 03-17 15:05:45 [utils.py:297]  ▄▄ ▄█ █     █     █ ▀▄▀ █  version 0.1.dev14910+g20b14095a
(APIServer pid=461117) INFO 03-17 15:05:45 [utils.py:297]   █▄█▀ █     █     █     █  model   nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4
(APIServer pid=461117) INFO 03-17 15:05:45 [utils.py:297]    ▀▀  ▀▀▀▀▀ ▀▀▀▀▀ ▀     ▀
(APIServer pid=461117) INFO 03-17 15:05:45 [utils.py:297] 
(APIServer pid=461117) INFO 03-17 15:05:45 [utils.py:233] non-default args: {'model_tag': 'nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4', 'model': 'nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4', 'trust_remote_code': True, 'tensor_parallel_size': 2}
(APIServer pid=461117) INFO 03-17 15:05:47 [model.py:533] Resolved architecture: NemotronHForCausalLM
(APIServer pid=461117) INFO 03-17 15:05:47 [model.py:1582] Using max model len 262144
(APIServer pid=461117) INFO 03-17 15:05:47 [cache.py:212] Using fp8 data type to store kv cache. It reduces the GPU memory footprint and boosts the performance. Meanwhile, it may cause accuracy drop without a proper scaling factor.
(APIServer pid=461117) INFO 03-17 15:05:47 [scheduler.py:231] Chunked prefill is enabled with max_num_batched_tokens=2048.
(APIServer pid=461117) INFO 03-17 15:05:47 [config.py:427] Updating mamba_ssm_cache_dtype to 'float32' for NemotronH model
(APIServer pid=461117) INFO 03-17 15:05:47 [config.py:212] Setting attention block size to 4176 tokens to ensure that attention page size is >= mamba page size.
(APIServer pid=461117) INFO 03-17 15:05:47 [config.py:243] Padding mamba page size by 0.19% to ensure that mamba page size and attention page size are exactly equal.
(APIServer pid=461117) WARNING 03-17 15:05:47 [modelopt.py:995] Detected ModelOpt NVFP4 checkpoint. Please note that the format is experimental and could change in future.
(APIServer pid=461117) INFO 03-17 15:05:47 [vllm.py:754] Asynchronous scheduling is enabled.
(APIServer pid=461117) INFO 03-17 15:05:47 [compilation.py:289] Enabled custom fusions: act_quant
(EngineCore pid=461501) INFO 03-17 15:05:55 [core.py:103] Initializing a V1 LLM engine (v0.1.dev14910+g20b14095a) with config: model='nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4', speculative_config=None, tokenizer='nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=262144, download_dir=None, load_format=auto, tensor_parallel_size=2, pipeline_parallel_size=1, data_parallel_size=1, decode_context_parallel_size=1, dcp_comm_backend=ag_rs, disable_custom_all_reduce=False, quantization=modelopt_fp4, enforce_eager=False, enable_return_routed_experts=False, kv_cache_dtype=fp8_e4m3, device_config=cuda, structured_outputs_config=StructuredOutputsConfig(backend='auto', disable_any_whitespace=False, disable_additional_properties=False, reasoning_parser='', reasoning_parser_plugin='', enable_in_reasoning=False), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None, kv_cache_metrics=False, kv_cache_metrics_sample=0.01, cudagraph_metrics=False, enable_layerwise_nvtx_tracing=False, enable_mfu_metrics=False, enable_mm_processor_stats=False, enable_logging_iteration_details=False), seed=0, served_model_name=nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4, enable_prefix_caching=False, enable_chunked_prefill=True, pooler_config=None, compilation_config={'mode': <CompilationMode.VLLM_COMPILE: 3>, 'debug_dump_path': None, 'cache_dir': '', 'compile_cache_save_format': 'binary', 'backend': 'inductor', 'custom_ops': ['none'], 'splitting_ops': ['vllm::unified_attention', 'vllm::unified_attention_with_output', 'vllm::unified_mla_attention', 'vllm::unified_mla_attention_with_output', 'vllm::mamba_mixer2', 'vllm::mamba_mixer', 'vllm::short_conv', 'vllm::linear_attention', 'vllm::plamo2_mamba_mixer', 'vllm::gdn_attention_core', 'vllm::olmo_hybrid_gdn_full_forward', 
'vllm::kda_attention', 'vllm::sparse_attn_indexer', 'vllm::rocm_aiter_sparse_attn_indexer', 'vllm::unified_kv_cache_update', 'vllm::unified_mla_kv_cache_update'], 'compile_mm_encoder': False, 'compile_sizes': [], 'compile_ranges_endpoints': [2048], 'inductor_compile_config': {'enable_auto_functionalized_v2': False, 'combo_kernels': True, 'benchmark_combo_kernel': True}, 'inductor_passes': {}, 'cudagraph_mode': <CUDAGraphMode.FULL_AND_PIECEWISE: (2, 1)>, 'cudagraph_num_of_warmups': 1, 'cudagraph_capture_sizes': [1, 2, 4, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248, 256, 272, 288, 304, 320, 336, 352, 368, 384, 400, 416, 432, 448, 464, 480, 496, 512], 'cudagraph_copy_inputs': False, 'cudagraph_specialize_lora': True, 'use_inductor_graph_partition': False, 'pass_config': {'fuse_norm_quant': False, 'fuse_act_quant': True, 'fuse_attn_quant': False, 'enable_sp': False, 'fuse_gemm_comms': False, 'fuse_allreduce_rms': False}, 'max_cudagraph_capture_size': 512, 'dynamic_shapes_config': {'type': <DynamicShapesType.BACKED: 'backed'>, 'evaluate_guards': False, 'assume_32_bit_indexing': False}, 'local_cache_dir': None, 'fast_moe_cold_start': True, 'static_all_moe_layers': []}
(EngineCore pid=461501) WARNING 03-17 15:05:55 [multiproc_executor.py:997] Reducing Torch parallelism from 64 threads to 1 to avoid unnecessary CPU contention. Set OMP_NUM_THREADS in the external environment to tune this value as needed.
(EngineCore pid=461501) INFO 03-17 15:05:55 [multiproc_executor.py:134] DP group leader: node_rank=0, node_rank_within_dp=0, master_addr=127.0.0.1, mq_connect_ip=33.1.35.33 (local), world_size=2, local_world_size=2
(Worker pid=461780) INFO 03-17 15:06:01 [parallel_state.py:1395] world_size=2 rank=0 local_rank=0 distributed_init_method=tcp://127.0.0.1:55435 backend=nccl
(Worker pid=461912) INFO 03-17 15:06:06 [parallel_state.py:1395] world_size=2 rank=1 local_rank=1 distributed_init_method=tcp://127.0.0.1:55435 backend=nccl
(Worker pid=461912) <frozen importlib._bootstrap_external>:1301: FutureWarning: The cuda.cudart module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.runtime module instead.
(Worker pid=461912) <frozen importlib._bootstrap_external>:1301: FutureWarning: The cuda.nvrtc module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.nvrtc module instead.
(Worker pid=461780) <frozen importlib._bootstrap_external>:1301: FutureWarning: The cuda.cudart module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.runtime module instead.
(Worker pid=461780) <frozen importlib._bootstrap_external>:1301: FutureWarning: The cuda.nvrtc module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.nvrtc module instead.
(Worker pid=461780) INFO 03-17 15:06:07 [pynccl.py:111] vLLM is using nccl==2.27.5
(Worker pid=461780) WARNING 03-17 15:06:07 [symm_mem.py:67] SymmMemCommunicator: Device capability 8.9 not supported, communicator is not available.
(Worker pid=461912) WARNING 03-17 15:06:07 [symm_mem.py:67] SymmMemCommunicator: Device capability 8.9 not supported, communicator is not available.
(Worker pid=461780) INFO 03-17 15:06:07 [parallel_state.py:1717] rank 0 in world size 2 is assigned as DP rank 0, PP rank 0, PCP rank 0, TP rank 0, EP rank 0, EPLB rank N/A
(Worker pid=461912) INFO 03-17 15:06:07 [parallel_state.py:1717] rank 1 in world size 2 is assigned as DP rank 0, PP rank 0, PCP rank 0, TP rank 1, EP rank 1, EPLB rank N/A
(Worker_TP0 pid=461780) INFO 03-17 15:06:08 [gpu_model_runner.py:4481] Starting to load model nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4...
(Worker_TP0 pid=461780) INFO 03-17 15:06:08 [nvfp4_utils.py:85] Using NvFp4LinearBackend.MARLIN for NVFP4 GEMM
(Worker_TP1 pid=461912) INFO 03-17 15:06:08 [nvfp4_utils.py:85] Using NvFp4LinearBackend.MARLIN for NVFP4 GEMM
(Worker_TP0 pid=461780) INFO 03-17 15:06:08 [nvfp4.py:257] Using 'MARLIN' NvFp4 MoE backend out of potential backends: ['FLASHINFER_TRTLLM', 'FLASHINFER_CUTEDSL', 'FLASHINFER_CUTLASS', 'VLLM_CUTLASS', 'MARLIN'].
(Worker_TP0 pid=461780) INFO 03-17 15:06:08 [cuda.py:317] Using FLASHINFER attention backend out of potential backends: ['FLASHINFER', 'TRITON_ATTN'].
Loading safetensors checkpoint shards:   0% Completed | 0/5 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  20% Completed | 1/5 [00:00<00:03,  1.04it/s]
Loading safetensors checkpoint shards:  40% Completed | 2/5 [00:02<00:03,  1.16s/it]
Loading safetensors checkpoint shards:  60% Completed | 3/5 [00:03<00:02,  1.23s/it]
Loading safetensors checkpoint shards:  80% Completed | 4/5 [00:04<00:01,  1.26s/it]
(Worker_TP1 pid=461912) WARNING 03-17 15:06:16 [nvfp4_utils.py:144] Your GPU does not have native support for FP4 computation but FP4 quantization is being used. Weight-only FP4 compression will be used leveraging the Marlin kernel. This may degrade performance for compute-heavy workloads.
Loading safetensors checkpoint shards: 100% Completed | 5/5 [00:05<00:00,  1.21s/it]
Loading safetensors checkpoint shards: 100% Completed | 5/5 [00:05<00:00,  1.20s/it]
(Worker_TP0 pid=461780) 
(Worker_TP0 pid=461780) INFO 03-17 15:06:16 [default_loader.py:373] Loading weights took 6.10 seconds
(Worker_TP0 pid=461780) WARNING 03-17 15:06:17 [nvfp4_utils.py:144] Your GPU does not have native support for FP4 computation but FP4 quantization is being used. Weight-only FP4 compression will be used leveraging the Marlin kernel. This may degrade performance for compute-heavy workloads.
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852] WorkerProc failed to start.
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852] Traceback (most recent call last):
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/v1/executor/multiproc_executor.py", line 821, in worker_main
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]     worker = WorkerProc(*args, **kwargs)
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/tracing/otel.py", line 178, in sync_wrapper
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]     return func(*args, **kwargs)
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]            ^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/v1/executor/multiproc_executor.py", line 619, in __init__
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]     self.worker.load_model()
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/v1/worker/gpu_worker.py", line 335, in load_model
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]     self.model_runner.load_model(load_dummy_weights=dummy_weights)
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/tracing/otel.py", line 178, in sync_wrapper
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]     return func(*args, **kwargs)
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]            ^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/v1/worker/gpu_model_runner.py", line 4497, in load_model
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]     self.model = model_loader.load_model(
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]                  ^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/tracing/otel.py", line 178, in sync_wrapper
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]     return func(*args, **kwargs)
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]            ^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/model_executor/model_loader/base_loader.py", line 74, in load_model
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]     process_weights_after_loading(model, model_config, target_device)
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/model_executor/model_loader/utils.py", line 106, in process_weights_after_loading
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]     quant_method.process_weights_after_loading(module)
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/model_executor/layers/quantization/modelopt.py", line 1168, in process_weights_after_loading
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]     convert_to_nvfp4_linear_kernel_format(self.backend, layer)
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/model_executor/layers/quantization/utils/nvfp4_utils.py", line 150, in convert_to_nvfp4_linear_kernel_format
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]     prepare_fp4_layer_for_marlin(layer)
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py", line 176, in prepare_fp4_layer_for_marlin
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]     marlin_qweight = ops.gptq_marlin_repack(
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]                      ^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/_custom_ops.py", line 1259, in gptq_marlin_repack
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]     return torch.ops._C.gptq_marlin_repack(
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]   File "/home/admin/miniconda3/envs/for_test/lib/python3.12/site-packages/torch/_ops.py", line 1209, in __call__
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]     return self._op(*args, **kwargs)
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852]            ^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=461912) ERROR 03-17 15:06:17 [multiproc_executor.py:852] RuntimeError: size_n = 5152 is not divisible by tile_n_size = 64
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852] WorkerProc failed to start.
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852] Traceback (most recent call last):
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/v1/executor/multiproc_executor.py", line 821, in worker_main
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]     worker = WorkerProc(*args, **kwargs)
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/tracing/otel.py", line 178, in sync_wrapper
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]     return func(*args, **kwargs)
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]            ^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/v1/executor/multiproc_executor.py", line 619, in __init__
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]     self.worker.load_model()
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/v1/worker/gpu_worker.py", line 335, in load_model
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]     self.model_runner.load_model(load_dummy_weights=dummy_weights)
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/tracing/otel.py", line 178, in sync_wrapper
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]     return func(*args, **kwargs)
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]            ^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/v1/worker/gpu_model_runner.py", line 4497, in load_model
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]     self.model = model_loader.load_model(
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]                  ^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/tracing/otel.py", line 178, in sync_wrapper
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]     return func(*args, **kwargs)
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]            ^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/model_executor/model_loader/base_loader.py", line 74, in load_model
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]     process_weights_after_loading(model, model_config, target_device)
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/model_executor/model_loader/utils.py", line 106, in process_weights_after_loading
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]     quant_method.process_weights_after_loading(module)
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/model_executor/layers/quantization/modelopt.py", line 1168, in process_weights_after_loading
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]     convert_to_nvfp4_linear_kernel_format(self.backend, layer)
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/model_executor/layers/quantization/utils/nvfp4_utils.py", line 150, in convert_to_nvfp4_linear_kernel_format
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]     prepare_fp4_layer_for_marlin(layer)
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py", line 176, in prepare_fp4_layer_for_marlin
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]     marlin_qweight = ops.gptq_marlin_repack(
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]                      ^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/_custom_ops.py", line 1259, in gptq_marlin_repack
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]     return torch.ops._C.gptq_marlin_repack(
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]   File "/home/admin/miniconda3/envs/for_test/lib/python3.12/site-packages/torch/_ops.py", line 1209, in __call__
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]     return self._op(*args, **kwargs)
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852]            ^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=461780) ERROR 03-17 15:06:17 [multiproc_executor.py:852] RuntimeError: size_n = 5152 is not divisible by tile_n_size = 64
[rank0]:[W317 15:06:18.459733056 ProcessGroupNCCL.cpp:1553] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
(EngineCore pid=461501) ERROR 03-17 15:06:19 [core.py:1099] EngineCore failed to start.
(EngineCore pid=461501) ERROR 03-17 15:06:19 [core.py:1099] Traceback (most recent call last):
(EngineCore pid=461501) ERROR 03-17 15:06:19 [core.py:1099]   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/v1/engine/core.py", line 1073, in run_engine_core
(EngineCore pid=461501) ERROR 03-17 15:06:19 [core.py:1099]     engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs)
(EngineCore pid=461501) ERROR 03-17 15:06:19 [core.py:1099]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=461501) ERROR 03-17 15:06:19 [core.py:1099]   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/tracing/otel.py", line 178, in sync_wrapper
(EngineCore pid=461501) ERROR 03-17 15:06:19 [core.py:1099]     return func(*args, **kwargs)
(EngineCore pid=461501) ERROR 03-17 15:06:19 [core.py:1099]            ^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=461501) ERROR 03-17 15:06:19 [core.py:1099]   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/v1/engine/core.py", line 839, in __init__
(EngineCore pid=461501) ERROR 03-17 15:06:19 [core.py:1099]     super().__init__(
(EngineCore pid=461501) ERROR 03-17 15:06:19 [core.py:1099]   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/v1/engine/core.py", line 112, in __init__
(EngineCore pid=461501) ERROR 03-17 15:06:19 [core.py:1099]     self.model_executor = executor_class(vllm_config)
(EngineCore pid=461501) ERROR 03-17 15:06:19 [core.py:1099]                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=461501) ERROR 03-17 15:06:19 [core.py:1099]   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/v1/executor/multiproc_executor.py", line 101, in __init__
(EngineCore pid=461501) ERROR 03-17 15:06:19 [core.py:1099]     super().__init__(vllm_config)
(EngineCore pid=461501) ERROR 03-17 15:06:19 [core.py:1099]   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/tracing/otel.py", line 178, in sync_wrapper
(EngineCore pid=461501) ERROR 03-17 15:06:19 [core.py:1099]     return func(*args, **kwargs)
(EngineCore pid=461501) ERROR 03-17 15:06:19 [core.py:1099]            ^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=461501) ERROR 03-17 15:06:19 [core.py:1099]   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/v1/executor/abstract.py", line 103, in __init__
(EngineCore pid=461501) ERROR 03-17 15:06:19 [core.py:1099]     self._init_executor()
(EngineCore pid=461501) ERROR 03-17 15:06:19 [core.py:1099]   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/v1/executor/multiproc_executor.py", line 190, in _init_executor
(EngineCore pid=461501) ERROR 03-17 15:06:19 [core.py:1099]     self.workers = WorkerProc.wait_for_ready(unready_workers)
(EngineCore pid=461501) ERROR 03-17 15:06:19 [core.py:1099]                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=461501) ERROR 03-17 15:06:19 [core.py:1099]   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/v1/executor/multiproc_executor.py", line 731, in wait_for_ready
(EngineCore pid=461501) ERROR 03-17 15:06:19 [core.py:1099]     raise e from None
(EngineCore pid=461501) ERROR 03-17 15:06:19 [core.py:1099] Exception: WorkerProc initialization failed due to an exception in a background process. See stack trace for root cause.
(EngineCore pid=461501) Process EngineCore:
(EngineCore pid=461501) Traceback (most recent call last):
(EngineCore pid=461501)   File "/home/admin/miniconda3/envs/for_test/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
(EngineCore pid=461501)     self.run()
(EngineCore pid=461501)   File "/home/admin/miniconda3/envs/for_test/lib/python3.12/multiprocessing/process.py", line 108, in run
(EngineCore pid=461501)     self._target(*self._args, **self._kwargs)
(EngineCore pid=461501)   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/v1/engine/core.py", line 1103, in run_engine_core
(EngineCore pid=461501)     raise e
(EngineCore pid=461501)   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/v1/engine/core.py", line 1073, in run_engine_core
(EngineCore pid=461501)     engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs)
(EngineCore pid=461501)                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=461501)   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/tracing/otel.py", line 178, in sync_wrapper
(EngineCore pid=461501)     return func(*args, **kwargs)
(EngineCore pid=461501)            ^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=461501)   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/v1/engine/core.py", line 839, in __init__
(EngineCore pid=461501)     super().__init__(
(EngineCore pid=461501)   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/v1/engine/core.py", line 112, in __init__
(EngineCore pid=461501)     self.model_executor = executor_class(vllm_config)
(EngineCore pid=461501)                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=461501)   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/v1/executor/multiproc_executor.py", line 101, in __init__
(EngineCore pid=461501)     super().__init__(vllm_config)
(EngineCore pid=461501)   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/tracing/otel.py", line 178, in sync_wrapper
(EngineCore pid=461501)     return func(*args, **kwargs)
(EngineCore pid=461501)            ^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=461501)   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/v1/executor/abstract.py", line 103, in __init__
(EngineCore pid=461501)     self._init_executor()
(EngineCore pid=461501)   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/v1/executor/multiproc_executor.py", line 190, in _init_executor
(EngineCore pid=461501)     self.workers = WorkerProc.wait_for_ready(unready_workers)
(EngineCore pid=461501)                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=461501)   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/v1/executor/multiproc_executor.py", line 731, in wait_for_ready
(EngineCore pid=461501)     raise e from None
(EngineCore pid=461501) Exception: WorkerProc initialization failed due to an exception in a background process. See stack trace for root cause.
(APIServer pid=461117) Traceback (most recent call last):
(APIServer pid=461117)   File "/home/admin/miniconda3/envs/for_test/bin/vllm", line 6, in <module>
(APIServer pid=461117)     sys.exit(main())
(APIServer pid=461117)              ^^^^^^
(APIServer pid=461117)   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/entrypoints/cli/main.py", line 75, in main
(APIServer pid=461117)     args.dispatch_function(args)
(APIServer pid=461117)   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/entrypoints/cli/serve.py", line 118, in cmd
(APIServer pid=461117)     uvloop.run(run_server(args))
(APIServer pid=461117)   File "/home/admin/miniconda3/envs/for_test/lib/python3.12/site-packages/uvloop/__init__.py", line 96, in run
(APIServer pid=461117)     return __asyncio.run(
(APIServer pid=461117)            ^^^^^^^^^^^^^^
(APIServer pid=461117)   File "/home/admin/miniconda3/envs/for_test/lib/python3.12/asyncio/runners.py", line 195, in run
(APIServer pid=461117)     return runner.run(main)
(APIServer pid=461117)            ^^^^^^^^^^^^^^^^
(APIServer pid=461117)   File "/home/admin/miniconda3/envs/for_test/lib/python3.12/asyncio/runners.py", line 118, in run
(APIServer pid=461117)     return self._loop.run_until_complete(task)
(APIServer pid=461117)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=461117)   File "uvloop/loop.pyx", line 1518, in uvloop.loop.Loop.run_until_complete
(APIServer pid=461117)   File "/home/admin/miniconda3/envs/for_test/lib/python3.12/site-packages/uvloop/__init__.py", line 48, in wrapper
(APIServer pid=461117)     return await main
(APIServer pid=461117)            ^^^^^^^^^^
(APIServer pid=461117)   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/entrypoints/openai/api_server.py", line 656, in run_server
(APIServer pid=461117)     await run_server_worker(listen_address, sock, args, **uvicorn_kwargs)
(APIServer pid=461117)   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/entrypoints/openai/api_server.py", line 670, in run_server_worker
(APIServer pid=461117)     async with build_async_engine_client(
(APIServer pid=461117)                ^^^^^^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=461117)   File "/home/admin/miniconda3/envs/for_test/lib/python3.12/contextlib.py", line 210, in __aenter__
(APIServer pid=461117)     return await anext(self.gen)
(APIServer pid=461117)            ^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=461117)   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/entrypoints/openai/api_server.py", line 103, in build_async_engine_client
(APIServer pid=461117)     async with build_async_engine_client_from_engine_args(
(APIServer pid=461117)                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=461117)   File "/home/admin/miniconda3/envs/for_test/lib/python3.12/contextlib.py", line 210, in __aenter__
(APIServer pid=461117)     return await anext(self.gen)
(APIServer pid=461117)            ^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=461117)   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/entrypoints/openai/api_server.py", line 144, in build_async_engine_client_from_engine_args
(APIServer pid=461117)     async_llm = AsyncLLM.from_vllm_config(
(APIServer pid=461117)                 ^^^^^^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=461117)   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/v1/engine/async_llm.py", line 225, in from_vllm_config
(APIServer pid=461117)     return cls(
(APIServer pid=461117)            ^^^^
(APIServer pid=461117)   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/v1/engine/async_llm.py", line 154, in __init__
(APIServer pid=461117)     self.engine_core = EngineCoreClient.make_async_mp_client(
(APIServer pid=461117)                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=461117)   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/tracing/otel.py", line 178, in sync_wrapper
(APIServer pid=461117)     return func(*args, **kwargs)
(APIServer pid=461117)            ^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=461117)   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/v1/engine/core_client.py", line 128, in make_async_mp_client
(APIServer pid=461117)     return AsyncMPClient(*client_args)
(APIServer pid=461117)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=461117)   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/tracing/otel.py", line 178, in sync_wrapper
(APIServer pid=461117)     return func(*args, **kwargs)
(APIServer pid=461117)            ^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=461117)   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/v1/engine/core_client.py", line 924, in __init__
(APIServer pid=461117)     super().__init__(
(APIServer pid=461117)   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/v1/engine/core_client.py", line 583, in __init__
(APIServer pid=461117)     with launch_core_engines(
(APIServer pid=461117)          ^^^^^^^^^^^^^^^^^^^^
(APIServer pid=461117)   File "/home/admin/miniconda3/envs/for_test/lib/python3.12/contextlib.py", line 144, in __exit__
(APIServer pid=461117)     next(self.gen)
(APIServer pid=461117)   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/v1/engine/utils.py", line 972, in launch_core_engines
(APIServer pid=461117)     wait_for_engine_startup(
(APIServer pid=461117)   File "/home/admin/workspace/aop_lab/app_source/vllm_custom_dataset_img_support_base64/vllm/v1/engine/utils.py", line 1031, in wait_for_engine_startup
(APIServer pid=461117)     raise RuntimeError(
(APIServer pid=461117) RuntimeError: Engine core initialization failed. See root cause above. Failed core proc(s): {}
/home/admin/miniconda3/envs/for_test/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 1 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '



The Marlin kernel requires size_n to be divisible by 64. This PR adds N-dimension zero-padding during weight preparation so FP4 models with arbitrary intermediate sizes work correctly.

Changes:

  • Add a MARLIN_TILE_N = 64 constant and a _pad_to_marlin_tile() helper in marlin_utils_fp4.py
  • Pad the weight, weight-scale, and bias tensors to the next multiple of 64 in prepare_fp4_layer_for_marlin, prepare_nvfp4_moe_layer_for_marlin, and prepare_moe_fp4_layer_for_marlin
  • Slice the output back to the original size_n after the GEMM in apply_fp4_marlin_linear
  • Store the padded sizes (marlin_moe_w13_size_n, marlin_moe_intermediate_size) on the layer and propagate them via process_weights_after_loading in MarlinExpertsBase
  • Extend marlin_moe_intermediate_size() and fused_marlin_moe() to accept an optional layer argument for reading the stored sizes

Test Result

After the fix, everything works fine.

VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1 vllm serve "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4" --trust_remote_code -tp 2
(APIServer pid=645740) INFO 03-17 17:06:23 [utils.py:297] 
(APIServer pid=645740) INFO 03-17 17:06:23 [utils.py:297]        █     █     █▄   ▄█
(APIServer pid=645740) INFO 03-17 17:06:23 [utils.py:297]  ▄▄ ▄█ █     █     █ ▀▄▀ █  version 0.1.dev14910+g20b14095a
(APIServer pid=645740) INFO 03-17 17:06:23 [utils.py:297]   █▄█▀ █     █     █     █  model   nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4
(APIServer pid=645740) INFO 03-17 17:06:23 [utils.py:297]    ▀▀  ▀▀▀▀▀ ▀▀▀▀▀ ▀     ▀
(APIServer pid=645740) INFO 03-17 17:06:23 [utils.py:297] 
(APIServer pid=645740) INFO 03-17 17:06:23 [utils.py:233] non-default args: {'model_tag': 'nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4', 'model': 'nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4', 'trust_remote_code': True, 'tensor_parallel_size': 2}
(APIServer pid=645740) INFO 03-17 17:06:25 [model.py:533] Resolved architecture: NemotronHForCausalLM
(APIServer pid=645740) INFO 03-17 17:06:25 [model.py:1582] Using max model len 262144
(APIServer pid=645740) INFO 03-17 17:06:25 [cache.py:212] Using fp8 data type to store kv cache. It reduces the GPU memory footprint and boosts the performance. Meanwhile, it may cause accuracy drop without a proper scaling factor.
(APIServer pid=645740) INFO 03-17 17:06:25 [scheduler.py:231] Chunked prefill is enabled with max_num_batched_tokens=2048.
(APIServer pid=645740) INFO 03-17 17:06:25 [config.py:427] Updating mamba_ssm_cache_dtype to 'float32' for NemotronH model
(APIServer pid=645740) INFO 03-17 17:06:25 [config.py:212] Setting attention block size to 4176 tokens to ensure that attention page size is >= mamba page size.
(APIServer pid=645740) INFO 03-17 17:06:25 [config.py:243] Padding mamba page size by 0.19% to ensure that mamba page size and attention page size are exactly equal.
(APIServer pid=645740) WARNING 03-17 17:06:25 [modelopt.py:995] Detected ModelOpt NVFP4 checkpoint. Please note that the format is experimental and could change in future.
(APIServer pid=645740) INFO 03-17 17:06:25 [vllm.py:754] Asynchronous scheduling is enabled.
(APIServer pid=645740) INFO 03-17 17:06:25 [compilation.py:289] Enabled custom fusions: act_quant
(EngineCore pid=646534) INFO 03-17 17:06:33 [core.py:103] Initializing a V1 LLM engine (v0.1.dev14910+g20b14095a) with config: model='nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4', speculative_config=None, tokenizer='nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=262144, download_dir=None, load_format=auto, tensor_parallel_size=2, pipeline_parallel_size=1, data_parallel_size=1, decode_context_parallel_size=1, dcp_comm_backend=ag_rs, disable_custom_all_reduce=False, quantization=modelopt_fp4, enforce_eager=False, enable_return_routed_experts=False, kv_cache_dtype=fp8_e4m3, device_config=cuda, structured_outputs_config=StructuredOutputsConfig(backend='auto', disable_any_whitespace=False, disable_additional_properties=False, reasoning_parser='', reasoning_parser_plugin='', enable_in_reasoning=False), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None, kv_cache_metrics=False, kv_cache_metrics_sample=0.01, cudagraph_metrics=False, enable_layerwise_nvtx_tracing=False, enable_mfu_metrics=False, enable_mm_processor_stats=False, enable_logging_iteration_details=False), seed=0, served_model_name=nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4, enable_prefix_caching=False, enable_chunked_prefill=True, pooler_config=None, compilation_config={'mode': <CompilationMode.VLLM_COMPILE: 3>, 'debug_dump_path': None, 'cache_dir': '', 'compile_cache_save_format': 'binary', 'backend': 'inductor', 'custom_ops': ['none'], 'splitting_ops': ['vllm::unified_attention', 'vllm::unified_attention_with_output', 'vllm::unified_mla_attention', 'vllm::unified_mla_attention_with_output', 'vllm::mamba_mixer2', 'vllm::mamba_mixer', 'vllm::short_conv', 'vllm::linear_attention', 'vllm::plamo2_mamba_mixer', 'vllm::gdn_attention_core', 'vllm::olmo_hybrid_gdn_full_forward', 
'vllm::kda_attention', 'vllm::sparse_attn_indexer', 'vllm::rocm_aiter_sparse_attn_indexer', 'vllm::unified_kv_cache_update', 'vllm::unified_mla_kv_cache_update'], 'compile_mm_encoder': False, 'compile_sizes': [], 'compile_ranges_endpoints': [2048], 'inductor_compile_config': {'enable_auto_functionalized_v2': False, 'combo_kernels': True, 'benchmark_combo_kernel': True}, 'inductor_passes': {}, 'cudagraph_mode': <CUDAGraphMode.FULL_AND_PIECEWISE: (2, 1)>, 'cudagraph_num_of_warmups': 1, 'cudagraph_capture_sizes': [1, 2, 4, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248, 256, 272, 288, 304, 320, 336, 352, 368, 384, 400, 416, 432, 448, 464, 480, 496, 512], 'cudagraph_copy_inputs': False, 'cudagraph_specialize_lora': True, 'use_inductor_graph_partition': False, 'pass_config': {'fuse_norm_quant': False, 'fuse_act_quant': True, 'fuse_attn_quant': False, 'enable_sp': False, 'fuse_gemm_comms': False, 'fuse_allreduce_rms': False}, 'max_cudagraph_capture_size': 512, 'dynamic_shapes_config': {'type': <DynamicShapesType.BACKED: 'backed'>, 'evaluate_guards': False, 'assume_32_bit_indexing': False}, 'local_cache_dir': None, 'fast_moe_cold_start': True, 'static_all_moe_layers': []}
(EngineCore pid=646534) WARNING 03-17 17:06:33 [multiproc_executor.py:997] Reducing Torch parallelism from 64 threads to 1 to avoid unnecessary CPU contention. Set OMP_NUM_THREADS in the external environment to tune this value as needed.
(EngineCore pid=646534) INFO 03-17 17:06:33 [multiproc_executor.py:134] DP group leader: node_rank=0, node_rank_within_dp=0, master_addr=127.0.0.1, mq_connect_ip=33.1.35.33 (local), world_size=2, local_world_size=2
(Worker pid=646817) INFO 03-17 17:06:38 [parallel_state.py:1395] world_size=2 rank=0 local_rank=0 distributed_init_method=tcp://127.0.0.1:59855 backend=nccl
(Worker pid=646936) INFO 03-17 17:06:43 [parallel_state.py:1395] world_size=2 rank=1 local_rank=1 distributed_init_method=tcp://127.0.0.1:59855 backend=nccl
(Worker pid=646936) <frozen importlib._bootstrap_external>:1301: FutureWarning: The cuda.cudart module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.runtime module instead.
(Worker pid=646936) <frozen importlib._bootstrap_external>:1301: FutureWarning: The cuda.nvrtc module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.nvrtc module instead.
(Worker pid=646817) <frozen importlib._bootstrap_external>:1301: FutureWarning: The cuda.cudart module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.runtime module instead.
(Worker pid=646817) <frozen importlib._bootstrap_external>:1301: FutureWarning: The cuda.nvrtc module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.nvrtc module instead.
(Worker pid=646817) INFO 03-17 17:06:44 [pynccl.py:111] vLLM is using nccl==2.27.5
(Worker pid=646817) WARNING 03-17 17:06:44 [symm_mem.py:67] SymmMemCommunicator: Device capability 8.9 not supported, communicator is not available.
(Worker pid=646936) WARNING 03-17 17:06:44 [symm_mem.py:67] SymmMemCommunicator: Device capability 8.9 not supported, communicator is not available.
(Worker pid=646936) INFO 03-17 17:06:44 [parallel_state.py:1717] rank 1 in world size 2 is assigned as DP rank 0, PP rank 0, PCP rank 0, TP rank 1, EP rank 1, EPLB rank N/A
(Worker pid=646817) INFO 03-17 17:06:44 [parallel_state.py:1717] rank 0 in world size 2 is assigned as DP rank 0, PP rank 0, PCP rank 0, TP rank 0, EP rank 0, EPLB rank N/A
(Worker_TP0 pid=646817) INFO 03-17 17:06:45 [gpu_model_runner.py:4506] Starting to load model nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4...
(Worker_TP1 pid=646936) INFO 03-17 17:06:45 [nvfp4_utils.py:85] Using NvFp4LinearBackend.MARLIN for NVFP4 GEMM
(Worker_TP0 pid=646817) INFO 03-17 17:06:45 [nvfp4_utils.py:85] Using NvFp4LinearBackend.MARLIN for NVFP4 GEMM
(Worker_TP0 pid=646817) INFO 03-17 17:06:45 [nvfp4.py:257] Using 'MARLIN' NvFp4 MoE backend out of potential backends: ['FLASHINFER_TRTLLM', 'FLASHINFER_CUTEDSL', 'FLASHINFER_CUTLASS', 'VLLM_CUTLASS', 'MARLIN'].
(Worker_TP0 pid=646817) INFO 03-17 17:06:45 [cuda.py:317] Using FLASHINFER attention backend out of potential backends: ['FLASHINFER', 'TRITON_ATTN'].
Loading safetensors checkpoint shards:   0% Completed | 0/5 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  20% Completed | 1/5 [00:00<00:03,  1.09it/s]
Loading safetensors checkpoint shards:  40% Completed | 2/5 [00:02<00:03,  1.08s/it]
Loading safetensors checkpoint shards:  60% Completed | 3/5 [00:03<00:02,  1.14s/it]
Loading safetensors checkpoint shards:  80% Completed | 4/5 [00:04<00:01,  1.22s/it]
Loading safetensors checkpoint shards: 100% Completed | 5/5 [00:05<00:00,  1.21s/it]
Loading safetensors checkpoint shards: 100% Completed | 5/5 [00:05<00:00,  1.17s/it]
(Worker_TP0 pid=646817) 
(Worker_TP0 pid=646817) INFO 03-17 17:06:53 [default_loader.py:373] Loading weights took 5.95 seconds
(Worker_TP1 pid=646936) WARNING 03-17 17:06:53 [nvfp4_utils.py:144] Your GPU does not have native support for FP4 computation but FP4 quantization is being used. Weight-only FP4 compression will be used leveraging the Marlin kernel. This may degrade performance for compute-heavy workloads.
(Worker_TP0 pid=646817) WARNING 03-17 17:06:53 [nvfp4_utils.py:144] Your GPU does not have native support for FP4 computation but FP4 quantization is being used. Weight-only FP4 compression will be used leveraging the Marlin kernel. This may degrade performance for compute-heavy workloads.
(Worker_TP1 pid=646936) WARNING 03-17 17:06:53 [marlin_utils_fp4.py:257] Your GPU does not have native support for FP4 computation but FP4 quantization is being used. Weight-only FP4 compression will be used leveraging the Marlin kernel. This may degrade performance for compute-heavy workloads.
(Worker_TP0 pid=646817) WARNING 03-17 17:06:53 [marlin_utils_fp4.py:257] Your GPU does not have native support for FP4 computation but FP4 quantization is being used. Weight-only FP4 compression will be used leveraging the Marlin kernel. This may degrade performance for compute-heavy workloads.
(Worker_TP1 pid=646936) INFO 03-17 17:06:53 [nvfp4.py:412] Using MoEPrepareAndFinalizeNoDPEPModular
(Worker_TP0 pid=646817) INFO 03-17 17:06:53 [nvfp4.py:412] Using MoEPrepareAndFinalizeNoDPEPModular
(Worker_TP1 pid=646936) WARNING 03-17 17:06:53 [kv_cache.py:94] Checkpoint does not provide a q scaling factor. Setting it to k_scale. This only matters for FP8 Attention backends (flash-attn or flashinfer).
(Worker_TP1 pid=646936) WARNING 03-17 17:06:53 [kv_cache.py:108] Using KV cache scaling factor 1.0 for fp8_e4m3. If this is unintended, verify that k/v_scale scaling factors are properly set in the checkpoint.
(Worker_TP0 pid=646817) WARNING 03-17 17:06:53 [kv_cache.py:94] Checkpoint does not provide a q scaling factor. Setting it to k_scale. This only matters for FP8 Attention backends (flash-attn or flashinfer).
(Worker_TP0 pid=646817) WARNING 03-17 17:06:53 [kv_cache.py:108] Using KV cache scaling factor 1.0 for fp8_e4m3. If this is unintended, verify that k/v_scale scaling factors are properly set in the checkpoint.
(Worker_TP0 pid=646817) INFO 03-17 17:06:55 [gpu_model_runner.py:4591] Model loading took 9.34 GiB memory and 9.380721 seconds
(Worker_TP0 pid=646817) WARNING 03-17 17:06:56 [decorators.py:304] Compiling model again due to a load failure from /home/admin/.cache/vllm/torch_compile_cache/torch_aot_compile/357473dcb99d3b7aac699da3cd4fdf53268f94f19476e9b2b9eb1ac426b15212/rank_0_0/model, reason: Source code has changed since the last compilation. Recompiling the model.
(Worker_TP1 pid=646936) WARNING 03-17 17:06:56 [decorators.py:304] Compiling model again due to a load failure from /home/admin/.cache/vllm/torch_compile_cache/torch_aot_compile/357473dcb99d3b7aac699da3cd4fdf53268f94f19476e9b2b9eb1ac426b15212/rank_1_0/model, reason: Source code has changed since the last compilation. Recompiling the model.
(Worker_TP0 pid=646817) INFO 03-17 17:07:00 [backends.py:988] Using cache directory: /home/admin/.cache/vllm/torch_compile_cache/87293e38bb/rank_0_0/backbone for vLLM's torch.compile
(Worker_TP0 pid=646817) INFO 03-17 17:07:00 [backends.py:1048] Dynamo bytecode transform time: 3.85 s
(Worker_TP1 pid=646936) INFO 03-17 17:07:00 [backends.py:371] Cache the graph of compile range (1, 2048) for later use
(Worker_TP0 pid=646817) INFO 03-17 17:07:00 [backends.py:371] Cache the graph of compile range (1, 2048) for later use
(Worker_TP0 pid=646817) INFO 03-17 17:07:02 [backends.py:387] Compiling a graph for compile range (1, 2048) takes 1.51 s
(Worker_TP0 pid=646817) INFO 03-17 17:07:03 [decorators.py:627] saved AOT compiled function to /home/admin/.cache/vllm/torch_compile_cache/torch_aot_compile/357473dcb99d3b7aac699da3cd4fdf53268f94f19476e9b2b9eb1ac426b15212/rank_0_0/model
(Worker_TP0 pid=646817) INFO 03-17 17:07:03 [monitor.py:48] torch.compile took 6.54 s in total
(Worker_TP0 pid=646817) INFO 03-17 17:07:03 [marlin_utils.py:446] Marlin kernel can achieve better performance for small size_n with experimental use_atomic_add feature. You can consider set environment variable VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible.
(Worker_TP1 pid=646936) INFO 03-17 17:07:03 [marlin_utils.py:446] Marlin kernel can achieve better performance for small size_n with experimental use_atomic_add feature. You can consider set environment variable VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible.
(Worker_TP0 pid=646817) INFO 03-17 17:07:04 [monitor.py:76] Initial profiling/warmup run took 0.88 s
(Worker_TP1 pid=646936) WARNING 03-17 17:07:04 [kv_cache_utils.py:1056] Add 1 padding layers, may waste at most 4.35% KV cache memory
(Worker_TP1 pid=646936) INFO 03-17 17:07:04 [kv_cache_utils.py:826] Overriding num_gpu_blocks=0 with num_gpu_blocks_override=512
(Worker_TP0 pid=646817) WARNING 03-17 17:07:04 [kv_cache_utils.py:1056] Add 1 padding layers, may waste at most 4.35% KV cache memory
(Worker_TP0 pid=646817) INFO 03-17 17:07:04 [kv_cache_utils.py:826] Overriding num_gpu_blocks=0 with num_gpu_blocks_override=512
(Worker_TP1 pid=646936) INFO 03-17 17:07:04 [gpu_model_runner.py:5632] Profiling CUDA graph memory: PIECEWISE=51 (largest=512), FULL=35 (largest=256)
(Worker_TP0 pid=646817) INFO 03-17 17:07:04 [gpu_model_runner.py:5632] Profiling CUDA graph memory: PIECEWISE=51 (largest=512), FULL=35 (largest=256)
(Worker_TP1 pid=646936) INFO 03-17 17:07:05 [custom_all_reduce.py:216] Registering 212 cuda graph addresses
(Worker_TP0 pid=646817) INFO 03-17 17:07:05 [custom_all_reduce.py:216] Registering 212 cuda graph addresses
(Worker_TP1 pid=646936) INFO 03-17 17:07:06 [gpu_model_runner.py:5711] Estimated CUDA graph memory: 4.30 GiB total
(Worker_TP0 pid=646817) INFO 03-17 17:07:06 [gpu_model_runner.py:5711] Estimated CUDA graph memory: 4.38 GiB total
(Worker_TP1 pid=646936) INFO 03-17 17:07:06 [gpu_worker.py:468] CUDA graph memory profiling is enabled (VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1). This will become the default in v0.19. The current --gpu-memory-utilization=0.9000 is equivalent to --gpu-memory-utilization=0.8033 without CUDA graph memory profiling. To maintain the same effective KV cache size as before, increase --gpu-memory-utilization to 0.9967.
(Worker_TP0 pid=646817) INFO 03-17 17:07:06 [gpu_worker.py:452] Available KV cache memory: 25.46 GiB
(Worker_TP0 pid=646817) INFO 03-17 17:07:06 [gpu_worker.py:468] CUDA graph memory profiling is enabled (VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1). This will become the default in v0.19. The current --gpu-memory-utilization=0.9000 is equivalent to --gpu-memory-utilization=0.8015 without CUDA graph memory profiling. To maintain the same effective KV cache size as before, increase --gpu-memory-utilization to 0.9985.
(EngineCore pid=646534) WARNING 03-17 17:07:06 [kv_cache_utils.py:1056] Add 1 padding layers, may waste at most 4.35% KV cache memory
(EngineCore pid=646534) INFO 03-17 17:07:06 [kv_cache_utils.py:1316] GPU KV cache size: 3,553,776 tokens
(EngineCore pid=646534) INFO 03-17 17:07:06 [kv_cache_utils.py:1321] Maximum concurrency for 262,144 tokens per request: 63.54x
Capturing CUDA graphs (mixed prefill-decode, PIECEWISE): 100%|████████████████████████████████████████████████████████████████████████████████| 51/51 [00:04<00:00, 12.56it/s]
Capturing CUDA graphs (decode, FULL):  97%|████████████████████████████████████████████████████████████████████████████████████████████████▏  | 34/35 [00:03<00:00, 12.69it/s](Worker_TP1 pid=646936) INFO 03-17 17:07:14 [custom_all_reduce.py:216] Registering 4558 cuda graph addresses
Capturing CUDA graphs (decode, FULL): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:03<00:00,  9.93it/s]
(Worker_TP0 pid=646817) INFO 03-17 17:07:14 [custom_all_reduce.py:216] Registering 4558 cuda graph addresses
(Worker_TP1 pid=646936) INFO 03-17 17:07:15 [gpu_worker.py:614] CUDA graph pool memory: 4.68 GiB (actual), 4.3 GiB (estimated), difference: 0.38 GiB (8.0%).
(Worker_TP0 pid=646817) INFO 03-17 17:07:15 [gpu_model_runner.py:5771] Graph capturing finished in 8 secs, took 4.68 GiB
(Worker_TP0 pid=646817) INFO 03-17 17:07:15 [gpu_worker.py:614] CUDA graph pool memory: 4.68 GiB (actual), 4.38 GiB (estimated), difference: 0.29 GiB (6.3%).
(EngineCore pid=646534) INFO 03-17 17:07:15 [core.py:281] init engine (profile, create kv cache, warmup model) took 19.46 seconds
(EngineCore pid=646534) INFO 03-17 17:07:17 [vllm.py:754] Asynchronous scheduling is enabled.
(EngineCore pid=646534) INFO 03-17 17:07:17 [compilation.py:289] Enabled custom fusions: act_quant
(APIServer pid=645740) INFO 03-17 17:07:17 [api_server.py:576] Supported tasks: ['generate']
(APIServer pid=645740) INFO 03-17 17:07:22 [hf.py:320] Detected the chat template content format to be 'string'. You can set `--chat-template-content-format` to override this.
(APIServer pid=645740) INFO 03-17 17:07:22 [api_server.py:580] Starting vLLM server on http://0.0.0.0:8000
(APIServer pid=645740) INFO 03-17 17:07:22 [launcher.py:37] Available routes are:
(APIServer pid=645740) INFO 03-17 17:07:22 [launcher.py:46] Route: /openapi.json, Methods: HEAD, GET
(APIServer pid=645740) INFO 03-17 17:07:22 [launcher.py:46] Route: /docs, Methods: HEAD, GET
(APIServer pid=645740) INFO 03-17 17:07:22 [launcher.py:46] Route: /docs/oauth2-redirect, Methods: HEAD, GET
(APIServer pid=645740) INFO 03-17 17:07:22 [launcher.py:46] Route: /redoc, Methods: HEAD, GET
(APIServer pid=645740) INFO 03-17 17:07:22 [launcher.py:46] Route: /tokenize, Methods: POST
(APIServer pid=645740) INFO 03-17 17:07:22 [launcher.py:46] Route: /detokenize, Methods: POST
(APIServer pid=645740) INFO 03-17 17:07:22 [launcher.py:46] Route: /load, Methods: GET
(APIServer pid=645740) INFO 03-17 17:07:22 [launcher.py:46] Route: /version, Methods: GET
(APIServer pid=645740) INFO 03-17 17:07:22 [launcher.py:46] Route: /health, Methods: GET
(APIServer pid=645740) INFO 03-17 17:07:22 [launcher.py:46] Route: /metrics, Methods: GET
(APIServer pid=645740) INFO 03-17 17:07:22 [launcher.py:46] Route: /v1/models, Methods: GET
(APIServer pid=645740) INFO 03-17 17:07:22 [launcher.py:46] Route: /ping, Methods: GET
(APIServer pid=645740) INFO 03-17 17:07:22 [launcher.py:46] Route: /ping, Methods: POST
(APIServer pid=645740) INFO 03-17 17:07:22 [launcher.py:46] Route: /invocations, Methods: POST
(APIServer pid=645740) INFO 03-17 17:07:22 [launcher.py:46] Route: /v1/chat/completions, Methods: POST
(APIServer pid=645740) INFO 03-17 17:07:22 [launcher.py:46] Route: /v1/responses, Methods: POST
(APIServer pid=645740) INFO 03-17 17:07:22 [launcher.py:46] Route: /v1/responses/{response_id}, Methods: GET
(APIServer pid=645740) INFO 03-17 17:07:22 [launcher.py:46] Route: /v1/responses/{response_id}/cancel, Methods: POST
(APIServer pid=645740) INFO 03-17 17:07:22 [launcher.py:46] Route: /v1/completions, Methods: POST
(APIServer pid=645740) INFO 03-17 17:07:22 [launcher.py:46] Route: /v1/messages, Methods: POST
(APIServer pid=645740) INFO 03-17 17:07:22 [launcher.py:46] Route: /v1/messages/count_tokens, Methods: POST
(APIServer pid=645740) INFO 03-17 17:07:22 [launcher.py:46] Route: /inference/v1/generate, Methods: POST
(APIServer pid=645740) INFO 03-17 17:07:22 [launcher.py:46] Route: /scale_elastic_ep, Methods: POST
(APIServer pid=645740) INFO 03-17 17:07:22 [launcher.py:46] Route: /is_scaling_elastic_ep, Methods: POST
(APIServer pid=645740) INFO 03-17 17:07:22 [launcher.py:46] Route: /v1/chat/completions/render, Methods: POST
(APIServer pid=645740) INFO 03-17 17:07:22 [launcher.py:46] Route: /v1/completions/render, Methods: POST
(APIServer pid=645740) INFO:     Started server process [645740]
(APIServer pid=645740) INFO:     Waiting for application startup.
(APIServer pid=645740) INFO:     Application startup complete.

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: xjx <493337577@qq.com>
@gemini-code-assist bot left a comment

Code Review

This pull request introduces padding for the N-dimension to meet alignment requirements for the Marlin FP4 kernel, which is a necessary fix. The core logic changes in vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py appear to correctly implement this padding. However, I've identified a critical bug in vllm/model_executor/layers/fused_moe/fused_marlin_moe.py due to an incorrect getattr call that will lead to a runtime error. I've also pointed out a typo in a function name that should be corrected for code clarity.

N = marlin_moe_intermediate_size(w1, w2)
N = marlin_moe_intermediate_size(w1, w2, layer)
w13_num_shards = 2 if activation.is_gated else 1
w13_size_n = getattr(layer, "marlin_moe_w13_size_n", w13_num_shards, *N)

critical

This line will cause a TypeError at runtime because it attempts to unpack an integer N with *N. Additionally, layer can be None (e.g., when called from batched_fused_marlin_moe), which would cause an AttributeError on getattr. The logic should safely retrieve marlin_moe_w13_size_n from the layer if it exists, and fall back to the computed value otherwise.

Suggested change
w13_size_n = getattr(layer, "marlin_moe_w13_size_n", w13_num_shards, *N)
if layer and hasattr(layer, "marlin_moe_w13_size_n"):
    w13_size_n = layer.marlin_moe_w13_size_n
else:
    w13_size_n = w13_num_shards * N
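For context on why the flagged line fails even before any attribute lookup: `*` unpacking in a call requires an iterable, and getattr itself accepts at most three arguments. A standalone reproduction (illustrative sketch, not vLLM code):

```python
N = 1216  # an int, as returned by marlin_moe_intermediate_size()

# `*N` is evaluated while the argument list is built, so an int here
# raises TypeError before getattr ever runs.
try:
    getattr(object(), "marlin_moe_w13_size_n", 2, *N)
except TypeError as e:
    print("raised:", e)

# Even with an iterable, getattr rejects a fourth argument:
try:
    getattr(object(), "marlin_moe_w13_size_n", 2, *[N])
except TypeError as e:
    print("raised:", e)
```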


# WEIGHT SCALES
# Permute scales
def premute_scales(

high

There's a typo in the function name. premute_scales should be permute_scales. Please also update the call sites at lines 377 and 378.

Suggested change
def premute_scales(
def permute_scales(

@mergify

mergify bot commented Mar 17, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @flutist.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 17, 2026
flutist added 3 commits March 18, 2026 10:09
Signed-off-by: xjx <493337577@qq.com>
…alignment' into fix_marlin_fp4_kernel_dimension_alignment
Signed-off-by: xjx <493337577@qq.com>
@mergify mergify bot removed the needs-rebase label Mar 18, 2026