
Fix nvfp4 weight update #18085

Merged
Fridge003 merged 17 commits into sgl-project:main from zianglih:nvfp4-weight-sync on Feb 27, 2026
Conversation

@zianglih
Contributor

@zianglih zianglih commented Feb 2, 2026

Motivation


The existing nvfp4 /update_weights_from_disk endpoint does not work. This PR fixes it.

This feature is a prerequisite for nvfp4 RL in miles: radixark/miles#546

Modifications

Introduce a `_copy_or_rebind_param` helper for in-place weight updates, keeping parameter identities (and thus their device pointers) stable so captured CUDA graphs remain valid.
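The helper's body is not shown in this thread; a minimal sketch of the idea (function name, signature, and fallback behavior are assumptions, not the PR's actual code) could look like:

```python
import torch
from torch import nn


def copy_or_rebind_param(module: nn.Module, name: str, new_tensor: torch.Tensor) -> None:
    """Update a parameter in place when possible; rebind only as a fallback.

    Copying into the existing Parameter keeps its storage address stable,
    so device pointers recorded by a captured CUDA graph stay valid.
    """
    old = getattr(module, name, None)
    if (
        isinstance(old, torch.Tensor)
        and old.shape == new_tensor.shape
        and old.dtype == new_tensor.dtype
        and old.device == new_tensor.device
    ):
        with torch.no_grad():
            old.copy_(new_tensor)  # same object, same data_ptr
    else:
        # Shape/dtype/device changed: rebinding is unavoidable, and any
        # captured CUDA graph would need to be re-captured afterwards.
        setattr(module, name, nn.Parameter(new_tensor, requires_grad=False))
```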

Accuracy Tests

Testing on nvidia/Qwen3-30B-A3B-NVFP4

hf download nvidia/Qwen3-30B-A3B-NVFP4 --local-dir /data/models/Qwen3-30B-A3B-NVFP4
python -m sglang.launch_server --kv-cache-dtype bf16 --model /data/models/Qwen3-30B-A3B-NVFP4 &
curl -sS http://localhost:30000/update_weights_from_disk \
  -H 'Content-Type: application/json' \
  -d '{
    "model_path": "/data/models/Qwen3-30B-A3B-NVFP4",
    "flush_cache": true,
    "abort_all_requests": false
  }'
[2026-02-02 01:12:24] Start update_weights. Load format=auto
[2026-02-02 01:12:24] Update engine weights online from disk begin. avail mem=22.72 GB
[2026-02-02 01:12:24] Using ModelOptModelLoader due to ModelOpt quantization config.
[2026-02-02 01:12:24] Beginning to load weights
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:01<00:04,  1.35s/it]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:03<00:03,  1.88s/it]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:06<00:02,  2.18s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:08<00:00,  2.32s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:08<00:00,  2.17s/it]

[2026-02-02 01:12:36] Update weights end.
[2026-02-02 01:12:36] Cache flushed successfully!
[2026-02-02 01:12:36] INFO:     127.0.0.1:51352 - "POST /update_weights_from_disk HTTP/1.1" 200 OK
{"success":true,"message":"Succeeded to update model weights.","num_paused_requests":0}
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
# Trial 1
Accuracy: 0.938
Invalid: 0.000
Latency: 14.080 s
Output throughput: 10445.017 token/s
# Trial 2
Accuracy: 0.935
Invalid: 0.000
Latency: 13.962 s
Output throughput: 10553.795 token/s
# Trial 3
Accuracy: 0.940
Invalid: 0.000
Latency: 13.865 s
Output throughput: 10693.914 token/s
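The curl call above can also be scripted; a minimal Python sketch using only the standard library (the URL and payload mirror the curl examples in this PR; the function names are my own):

```python
import json
from urllib import request


def build_update_request(base_url: str, model_path: str) -> request.Request:
    """Build the POST /update_weights_from_disk request used in the curl examples."""
    payload = json.dumps(
        {
            "model_path": model_path,
            "flush_cache": True,
            "abort_all_requests": False,
        }
    ).encode()
    return request.Request(
        f"{base_url}/update_weights_from_disk",
        data=payload,  # data= makes this a POST
        headers={"Content-Type": "application/json"},
    )


def update_weights_from_disk(base_url: str, model_path: str) -> dict:
    """Issue the request and return the server's parsed JSON reply."""
    with request.urlopen(build_update_request(base_url, model_path)) as resp:
        return json.loads(resp.read())
```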

Testing on PTQ nvfp4 checkpoint from radixark/miles#536

Note: the complete accuracy result depends on #18012.

python -m sglang.launch_server --kv-cache-dtype bf16 --model /data/models/Qwen3-235B-A22B-NVFP4-PTQ-full --tp 8 &
curl -sS http://localhost:30000/update_weights_from_disk \
  -H 'Content-Type: application/json' \
  -d '{
    "model_path": "/data/models/Qwen3-235B-A22B-NVFP4-PTQ-full",
    "flush_cache": true,
    "abort_all_requests": false
  }'
# Trial 1
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
Accuracy: 0.966
Invalid: 0.000
Latency: 26.196 s
Output throughput: 6776.847 token/s
# Trial 2
Accuracy: 0.972
Invalid: 0.000
Latency: 24.188 s
Output throughput: 7365.630 token/s

For reference, the scheduler failure on main before this fix:

[2026-02-02 00:13:21] Scheduler hit an exception: Traceback (most recent call last):
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 1083, in update_weights_from_disk
    model = model_load_weights(self.model, iter)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 1073, in model_load_weights
    loader.load_weights_and_postprocess(model, iter, target_device)
  File "/sgl-workspace/sglang/python/sglang/srt/model_loader/loader.py", line 692, in load_weights_and_postprocess
    quant_method.process_weights_after_loading(module)
  File "/sgl-workspace/sglang/python/sglang/srt/layers/quantization/modelopt_quant.py", line 1510, in process_weights_after_loading
    ("w13", layer.w13_weight_scale),
            ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1964, in __getattr__
    raise AttributeError(
AttributeError: 'FusedMoE' object has no attribute 'w13_weight_scale'. Did you mean: 'w13_weight_scale_2'?

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 3063, in run_scheduler_process
    scheduler.event_loop_overlap()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1103, in event_loop_overlap
    self.process_input_requests(recv_reqs)
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1322, in process_input_requests
    output = self._request_dispatcher(recv_req)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/utils.py", line 507, in __call__
    return fn(obj)
           ^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler_update_weights_mixin.py", line 50, in update_weights_from_disk
    success, message = self.tp_worker.update_weights_from_disk(recv_req)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 93, in update_weights_from_disk
    success, message = self.model_runner.update_weights_from_disk(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 1091, in update_weights_from_disk
    self.model = model_load_weights(self.model, iter)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 1073, in model_load_weights
    loader.load_weights_and_postprocess(model, iter, target_device)
  File "/sgl-workspace/sglang/python/sglang/srt/model_loader/loader.py", line 692, in load_weights_and_postprocess
    quant_method.process_weights_after_loading(module)
  File "/sgl-workspace/sglang/python/sglang/srt/layers/quantization/modelopt_quant.py", line 1510, in process_weights_after_loading
    ("w13", layer.w13_weight_scale),
            ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1964, in __getattr__
    raise AttributeError(
AttributeError: 'FusedMoE' object has no attribute 'w13_weight_scale'. Did you mean: 'w13_weight_scale_2'?

[2026-02-02 00:13:21] SIGQUIT received. signum=None, frame=None. It usually means one child failed.
@github-actions github-actions bot added the quant LLM Quantization label Feb 2, 2026
@gemini-code-assist
Contributor

Summary of Changes

Hello @zianglih, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a critical issue preventing the nvfp4 /update_weights_from_disk endpoint from functioning correctly. The core of the fix involves a refined approach to updating model parameters, ensuring that their underlying memory and identity remain consistent across updates. This is achieved by introducing a new helper function that intelligently copies or rebinds tensor data, which is vital for maintaining the integrity of CUDA graphs and enabling seamless hot reloading of quantized model weights. The changes enhance the robustness and reliability of online weight updates for nvfp4 models.

Highlights

  • Introduced _copy_or_rebind_param function: A new utility function was added to ensure stable parameter identities during weight updates, crucial for CUDA graph reuse and hot reloading.
  • Refactored parameter updates: Replaced direct Parameter assignments with calls to _copy_or_rebind_param for various weight and scale tensors (alpha, input_scale_inv, weight_scale_interleaved, g1_alphas, g2_alphas, w13_input_scale_quant, w2_input_scale_quant, gemm1_weights_fp4_shuffled, gemm2_weights_fp4_shuffled, gemm1_scales_fp4_shuffled, gemm2_scales_fp4_shuffled, g1_scale_c, w13_blockscale_swizzled, w13_weight, w2_weight, w2_weight_scale, w2_blockscale_swizzled).
  • Preserved original weights for hot reload: Removed del statements for w2_weight, w2_weight_scale, w13_weight, and w13_weight_scale to ensure these original weights are retained, enabling the update_weights_from_disk functionality.
  • Optimized CutlassMoEParams initialization: Modified the initialization of cutlass_moe_params to reuse existing parameters if their properties match, enhancing stability and avoiding unnecessary re-initialization.
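The parameter-reuse point in the highlights matters because a CUDA graph capture records raw device pointers. An illustrative sketch with plain CPU tensors (not the PR's code): an in-place `copy_` keeps the recorded pointer valid, while rebinding allocates new storage and invalidates it.

```python
import torch
from torch import nn

layer = nn.Linear(8, 8, bias=False)
captured_ptr = layer.weight.data_ptr()  # what a CUDA graph capture would record

# In-place copy: same storage, so the captured pointer still sees the new values.
with torch.no_grad():
    layer.weight.copy_(torch.randn(8, 8))
assert layer.weight.data_ptr() == captured_ptr

# Rebinding: fresh storage; the captured pointer now refers to stale memory.
layer.weight = nn.Parameter(torch.randn(8, 8))
assert layer.weight.data_ptr() != captured_ptr
```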


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request effectively addresses the issue with updating nvfp4 weights. The introduction of the _copy_or_rebind_param helper function is a solid approach to manage parameter updates in a way that is friendly to both CUDA graphs and hot reloading. Its consistent application across the codebase for replacing direct Parameter assignments is well-executed. A crucial part of the fix is the removal of del statements that previously discarded original weights and scales, which is essential for enabling weight updates. The changes also improve robustness by handling different checkpoint formats and reusing existing parameter objects to enhance stability. The code quality is high, and the solution is well-aligned with the stated goal of the PR.

@zianglih
Copy link
Copy Markdown
Contributor Author

zianglih commented Feb 9, 2026

After recent refactoring, weight sync remains broken in main. After merging main into this branch, the fix still works as expected:

BEFORE

# Download official NVIDIA nvfp4 checkpoint
hf download nvidia/Qwen3-30B-A3B-NVFP4 --local-dir /data/models/Qwen3-30B-A3B-NVFP4

python -m sglang.launch_server --kv-cache-dtype bf16 --model /data/models/Qwen3-30B-A3B-NVFP4 &
curl -sS http://localhost:30000/update_weights_from_disk \
  -H 'Content-Type: application/json' \
  -d '{
    "model_path": "/data/models/Qwen3-30B-A3B-NVFP4",
    "flush_cache": true,
    "abort_all_requests": false
  }'
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:07<00:00,  1.89s/it]

[2026-02-09 04:47:53] Scheduler hit an exception: Traceback (most recent call last):
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 1092, in update_weights_from_disk
    model = model_load_weights(self.model, iter)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 1082, in model_load_weights
    loader.load_weights_and_postprocess(model, iter, target_device)
  File "/sgl-workspace/sglang/python/sglang/srt/model_loader/loader.py", line 695, in load_weights_and_postprocess
    quant_method.process_weights_after_loading(module)
  File "/sgl-workspace/sglang/python/sglang/srt/layers/quantization/modelopt_quant.py", line 1481, in process_weights_after_loading
    layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
    ~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^
IndexError: too many indices for tensor of dimension 1

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 3091, in run_scheduler_process
    scheduler.event_loop_overlap()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1108, in event_loop_overlap
    self.process_input_requests(recv_reqs)
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1327, in process_input_requests
    output = self._request_dispatcher(recv_req)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/utils.py", line 525, in __call__
    return fn(obj)
           ^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler_update_weights_mixin.py", line 50, in update_weights_from_disk
    success, message = self.tp_worker.update_weights_from_disk(recv_req)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 93, in update_weights_from_disk
    success, message = self.model_runner.update_weights_from_disk(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 1100, in update_weights_from_disk
    self.model = model_load_weights(self.model, iter)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 1082, in model_load_weights
    loader.load_weights_and_postprocess(model, iter, target_device)
  File "/sgl-workspace/sglang/python/sglang/srt/model_loader/loader.py", line 695, in load_weights_and_postprocess
    quant_method.process_weights_after_loading(module)
  File "/sgl-workspace/sglang/python/sglang/srt/layers/quantization/modelopt_quant.py", line 1481, in process_weights_after_loading
    layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
    ~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^
IndexError: too many indices for tensor of dimension 1

[2026-02-09 04:47:53] SIGQUIT received. signum=None, frame=None. It usually means one child failed.
curl: (52) Empty reply from server

AFTER

git clone -b nvfp4-weight-sync https://github.com/zianglih/sglang.git
cd sglang
pip install -e "python"
python -m sglang.launch_server --kv-cache-dtype bf16 --model /data/models/Qwen3-30B-A3B-NVFP4 &
curl -sS http://localhost:30000/update_weights_from_disk \
  -H 'Content-Type: application/json' \
  -d '{
    "model_path": "/data/models/Qwen3-30B-A3B-NVFP4",
    "flush_cache": true,
    "abort_all_requests": false
  }'
[2026-02-09 04:52:47] Update weights end.
[2026-02-09 04:52:47] Cache flushed successfully!
[2026-02-09 04:52:47] INFO:     127.0.0.1:33176 - "POST /update_weights_from_disk HTTP/1.1" 200 OK
{"success":true,"message":"Succeeded to update model weights.","num_paused_requests":0}
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
Accuracy: 0.938
Invalid: 0.000
Latency: 15.219 s
Output throughput: 9659.017 token/s
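The IndexError in the BEFORE log comes from indexing a 1-D scale tensor as if it were 2-D. A hedged sketch of format-tolerant handling (the function name is my own; the two-layout assumption is inferred from the traceback and the review's note about handling different checkpoint formats):

```python
import torch


def split_w13_scale_2(w13_weight_scale_2: torch.Tensor):
    """Return (w1_scale_2, w3_scale_2) for both checkpoint layouts.

    Some nvfp4 checkpoints appear to store a per-expert [num_experts, 2]
    tensor, others a 1-D tensor; indexing [:, 0] on the 1-D layout raises
    'IndexError: too many indices for tensor of dimension 1'.
    """
    if w13_weight_scale_2.dim() == 2:
        return w13_weight_scale_2[:, 0], w13_weight_scale_2[:, 1]
    # 1-D layout: the same scale applies to both projections.
    return w13_weight_scale_2, w13_weight_scale_2
```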

@b8zhong
Collaborator

b8zhong commented Feb 12, 2026

/tag-and-rerun-ci again

@zianglih
Contributor Author

Explicitly testing --moe-runner-backend cutlass and --moe-runner-backend flashinfer_trtllm --quantization modelopt_fp4 as requested:

  • CUTLASS, WITH this PR:
root@nb-slc-102:/sgl-workspace/sglang# python -m sglang.launch_server --kv-cache-dtype bf16 --model /data/models/Qwen3-30B-A3B-NVFP4 --moe-runner-backend cutlass &
root@nb-slc-102:/sgl-workspace/sglang# curl -sS http://localhost:30000/update_weights_from_disk \
  -H 'Content-Type: application/json' \
  -d '{
    "model_path": "/data/models/Qwen3-30B-A3B-NVFP4",
    "flush_cache": true,
    "abort_all_requests": false
  }'
[2026-02-13 00:07:07] Start update_weights. Load format=auto
[2026-02-13 00:07:07] Update engine weights online from disk begin. avail mem=22.72 GB
[2026-02-13 00:07:07] Using ModelOptModelLoader due to ModelOpt quantization config.
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:01<00:03,  1.20s/it]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:03<00:03,  1.78s/it]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:05<00:01,  1.98s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:07<00:00,  2.05s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:07<00:00,  1.94s/it]

[2026-02-13 00:07:19] Update weights end.
[2026-02-13 00:07:19] Cache flushed successfully!
[2026-02-13 00:07:19] INFO:     127.0.0.1:55394 - "POST /update_weights_from_disk HTTP/1.1" 200 OK
{"success":true,"message":"Succeeded to update model weights.","num_paused_requests":0}
root@nb-slc-102:/sgl-workspace/sglang# python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
Accuracy: 0.938
Invalid: 0.000
Latency: 15.151 s
Output throughput: 9702.472 token/s
  • trtllm-gen, WITHOUT this PR:
root@nb-slc-102:/sgl-workspace/sglang# python -m sglang.launch_server --kv-cache-dtype bf16 --model /data/models/Qwen3-30B-A3B-NVFP4 --moe-runner-backend flashinfer_trtllm &
[1] 35
root@nb-slc-102:/sgl-workspace/sglang# [2026-02-13 00:16:19] WARNING model_config.py:918: DeepGemm is enabled but the scale_fmt of checkpoint is not ue8m0. This might cause accuracy degradation on Blackwell.
[2026-02-13 00:16:19] WARNING model_config.py:936: modelopt quantization is not fully optimized yet. The speed can be slower than non-quantized models.
[2026-02-13 00:16:19] INFO server_args.py:1806: Attention backend not specified. Use trtllm_mha backend by default.
[2026-02-13 00:16:19] WARNING server_args.py:1876: TensorRT-LLM MHA only supports page_size of 16, 32 or 64, changing page_size from None to 64.
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/sgl-workspace/sglang/python/sglang/launch_server.py", line 32, in <module>
    server_args = prepare_server_args(sys.argv[1:])
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/server_args.py", line 5502, in prepare_server_args
    return ServerArgs.from_cli_args(raw_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/server_args.py", line 4988, in from_cli_args
    return cls(**{attr: getattr(args, attr) for attr in attrs})
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<string>", line 329, in __init__
  File "/sgl-workspace/sglang/python/sglang/srt/server_args.py", line 749, in __post_init__
    self._handle_moe_kernel_config()
  File "/sgl-workspace/sglang/python/sglang/srt/server_args.py", line 2077, in _handle_moe_kernel_config
    assert self.quantization in [
           ^^^^^^^^^^^^^^^^^^^^^^
AssertionError: Invalid quantization 'modelopt'. 
FlashInfer TRTLLM MOE supports only: 'modelopt_fp4', 'fp8', 'modelopt_fp8', 'compressed-tensors', or bfloat16 (None).

root@nb-slc-102:/sgl-workspace/sglang# python -m sglang.launch_server --kv-cache-dtype bf16 --model /data/models/Qwen3-30B-A3B-NVFP4 --moe-runner-backend flashinfer_trtllm --quantization modelopt_fp4 &
root@nb-slc-102:/sgl-workspace/sglang# python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
Accuracy: 0.936
Invalid: 0.000
Latency: 14.858 s
Output throughput: 9879.286 token/s
root@nb-slc-102:/sgl-workspace/sglang# curl -sS http://localhost:30000/update_weights_from_disk \
  -H 'Content-Type: application/json' \
  -d '{
    "model_path": "/data/models/Qwen3-30B-A3B-NVFP4",
    "flush_cache": true,
    "abort_all_requests": false
  }'
[2026-02-13 00:19:21] Start update_weights. Load format=auto
[2026-02-13 00:19:21] Update engine weights online from disk begin. avail mem=22.75 GB
[2026-02-13 00:19:21] Using ModelOptModelLoader due to ModelOpt quantization config.
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:01<00:04,  1.46s/it]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:03<00:04,  2.00s/it]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:06<00:02,  2.37s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:09<00:00,  2.51s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:09<00:00,  2.34s/it]

Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:01<00:04,  1.51s/it]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:03<00:04,  2.05s/it]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:06<00:02,  2.19s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:08<00:00,  2.21s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:08<00:00,  2.14s/it]

[2026-02-13 00:19:39] Scheduler hit an exception: Traceback (most recent call last):
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 1138, in update_weights_from_disk
    model = model_load_weights(self.model, iter)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 1128, in model_load_weights
    loader.load_weights_and_postprocess(model, iter, target_device)
  File "/sgl-workspace/sglang/python/sglang/srt/model_loader/loader.py", line 697, in load_weights_and_postprocess
    quant_method.process_weights_after_loading(module)
  File "/sgl-workspace/sglang/python/sglang/srt/layers/quantization/modelopt_quant.py", line 1481, in process_weights_after_loading
    layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
    ~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^
IndexError: too many indices for tensor of dimension 1

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 3077, in run_scheduler_process
    scheduler.event_loop_overlap()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1107, in event_loop_overlap
    self.process_input_requests(recv_reqs)
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1326, in process_input_requests
    output = self._request_dispatcher(recv_req)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/utils.py", line 525, in __call__
    return fn(obj)
           ^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler_update_weights_mixin.py", line 50, in update_weights_from_disk
    success, message = self.tp_worker.update_weights_from_disk(recv_req)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 93, in update_weights_from_disk
    success, message = self.model_runner.update_weights_from_disk(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 1146, in update_weights_from_disk
    self.model = model_load_weights(self.model, iter)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 1128, in model_load_weights
    loader.load_weights_and_postprocess(model, iter, target_device)
  File "/sgl-workspace/sglang/python/sglang/srt/model_loader/loader.py", line 697, in load_weights_and_postprocess
    quant_method.process_weights_after_loading(module)
  File "/sgl-workspace/sglang/python/sglang/srt/layers/quantization/modelopt_quant.py", line 1481, in process_weights_after_loading
    layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
    ~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^
IndexError: too many indices for tensor of dimension 1

[2026-02-13 00:19:39] SIGQUIT received. signum=None, frame=None. It usually means one child failed.
curl: (52) Empty reply from server
  • trtllm-gen, WITH this PR:
root@nb-slc-102:/sgl-workspace/sglang# python -m sglang.launch_server --kv-cache-dtype bf16 --model /data/models/Qwen3-30B-A3B-NVFP4 --moe-runner-backend flashinfer_trtllm --quantization modelopt_fp4 &
root@nb-slc-102:/sgl-workspace/sglang# python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
Accuracy: 0.935
Invalid: 0.000
Latency: 13.528 s
Output throughput: 10880.799 token/s
root@nb-slc-102:/sgl-workspace/sglang# curl -sS http://localhost:30000/update_weights_from_disk   -H 'Content-Type: application/json'   -d '{
    "model_path": "/data/models/Qwen3-30B-A3B-NVFP4",
    "flush_cache": true,
    "abort_all_requests": false
  }'
[2026-02-13 00:24:05] Start update_weights. Load format=auto
[2026-02-13 00:24:05] Update engine weights online from disk begin. avail mem=22.72 GB
[2026-02-13 00:24:05] Using ModelOptModelLoader due to ModelOpt quantization config.
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:01<00:03,  1.17s/it]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:03<00:03,  1.80s/it]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:05<00:02,  2.04s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:07<00:00,  2.12s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:07<00:00,  1.99s/it]

[2026-02-13 00:24:14] Update weights end.
[2026-02-13 00:24:14] Cache flushed successfully!
[2026-02-13 00:24:14] INFO:     127.0.0.1:51104 - "POST /update_weights_from_disk HTTP/1.1" 200 OK
{"success":true,"message":"Succeeded to update model weights.","num_paused_requests":0}
root@nb-slc-102:/sgl-workspace/sglang# python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
Accuracy: 0.935
Invalid: 0.000
Latency: 12.660 s
Output throughput: 11626.977 token/s

This PR fixes both backends.

@b8zhong
Collaborator

b8zhong commented Feb 13, 2026

Thanks for the data.

Btw, flashinfer_cutlass and cutlass are actually the same original kernel; the cutlass one was just ported before it existed in flashinfer, I believe. (I'm not sure whether the cutlass path has been maintained recently, so its performance may or may not be worse than going through the flashinfer API.) Just an FYI.

@zhaochenyang20
Collaborator

@guapisolo Jiajun will help to review this PR.

@zhaochenyang20
Collaborator

/rerun-failed-ci

@guapisolo
Contributor

@guapisolo Jiajun will help to review this PR.

I've already reviewed this PR, but I'm not sure whether it's a good implementation, so I asked Brayden for help.

@Fridge003
Collaborator

@zianglih Please check the CI failures, thanks

auto-merge was automatically disabled February 25, 2026 01:55

Head branch was pushed to by a user without write access

@zianglih
Contributor Author

zianglih commented Feb 25, 2026

@Fridge003 the previous fp4 CI OOM is fixed by 1d00042 (#18085 (comment)). Thanks!

@Fridge003 Fridge003 merged commit 9469ad0 into sgl-project:main Feb 27, 2026
178 of 193 checks passed
zianglih pushed a commit to zianglih/sglang that referenced this pull request Feb 28, 2026
magicYang1573 pushed a commit to magicYang1573/sglang that referenced this pull request Mar 9, 2026
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
@zianglih zianglih deleted the nvfp4-weight-sync branch April 6, 2026 08:08
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026