
fix(nvfp4): make process_weights_after_loading hot-reload-safe via alias-when-same-shape #25190

Merged
ch-wan merged 7 commits into sgl-project:main from ch-wan:cw/nvfp4-postprocessed-flag
May 13, 2026

Conversation

@ch-wan
Collaborator

@ch-wan ch-wan commented May 13, 2026

Motivation

PR #25107 freed unused NVFP4 source scales (input_scale / weight_scale / weight_scale_2 and their MoE counterparts) inside process_weights_after_loading to reclaim ~15 GiB/rank on Kimi-K2.5 NVFP4 TP=4. That broke test_update_weights_from_disk_blackwell.py::TestServerUpdateWeightsFromDiskNVFP4:

  • 1st invocation: derive alpha / input_scale_inv / weight_scale_interleaved, mutate layer.weight in place (TRTLLM shuffle_matrix_a permutes rows, CUTLASS pad_nvfp4_weight pads N/K), then del the source scales.
  • /update_weights_from_disk → model.load_weights(...):
    • Refills layer.weight from safetensors with raw bytes (overwriting the previously-shuffled/padded representation).
    • Can't refill the deleted source-scale slots.
  • 2nd process_weights_after_loading call: layer.input_scale.max() → AttributeError. Even with a guard, the new raw layer.weight mismatches the cached shuffled-layout weight_scale_interleaved / alpha, so the GEMM produces garbage tokens.
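The sequence above can be reproduced in miniature. This is a toy sketch, not sglang code: ToyLayer, the list-based "tensors", and the simplified load_weights are all stand-ins invented for illustration. It only shows why deleting a source slot makes the second pass fail.

```python
class ToyLayer:
    """Hypothetical stand-in for an NVFP4 linear layer; lists replace tensors."""

    def __init__(self):
        # Slots that a checkpoint loader knows how to fill.
        self.weight = [0.0, 0.0]
        self.input_scale = [0.0]


def load_weights(layer, checkpoint):
    # Assumed loader behavior: copy into slots that still exist on the layer
    # and silently skip names that have no slot.
    for name, value in checkpoint.items():
        if hasattr(layer, name):
            setattr(layer, name, list(value))


def process_weights_after_loading(layer):
    # Derive a runtime value from the source scale, then free the source.
    layer.input_scale_inv = 1.0 / max(layer.input_scale)
    del layer.input_scale


checkpoint = {"weight": [1.0, 2.0], "input_scale": [4.0]}
layer = ToyLayer()

load_weights(layer, checkpoint)
process_weights_after_loading(layer)      # first pass succeeds

load_weights(layer, checkpoint)           # hot reload: input_scale slot is gone
try:
    process_weights_after_loading(layer)  # second pass reads a deleted slot
    second_pass_error = None
except AttributeError as exc:
    second_pass_error = exc

print(type(second_pass_error).__name__)
```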

Neither a full revert (loses ~15 GiB/rank) nor a _weights_postprocessed skip-on-second-pass flag (silently broken outputs after hot reload) is acceptable.
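Why the skip-on-second-pass flag is unsafe can be shown with a toy model. In this sketch a row reversal stands in for the real kernel-specific shuffle, and kernel_matvec is a hypothetical "kernel" that assumes shuffled input; none of these names come from sglang.

```python
def shuffle_rows(rows):
    # Toy stand-in for a kernel-specific row permutation.
    return rows[::-1]


def kernel_matvec(shuffled_rows, x):
    # The toy kernel assumes shuffled rows and undoes the permutation on output.
    out = [sum(w * v for w, v in zip(row, x)) for row in shuffled_rows]
    return out[::-1]


class ToyLayer:
    def __init__(self):
        self.weight = None
        self._weights_postprocessed = False


def load_weights(layer, raw_rows):
    # The loader always writes raw (unshuffled) rows.
    layer.weight = [list(r) for r in raw_rows]


def process_weights_after_loading(layer):
    if layer._weights_postprocessed:
        return  # flag-based early return: skips the re-shuffle
    layer._weights_postprocessed = True
    layer.weight = shuffle_rows(layer.weight)


raw = [[1.0, 0.0], [0.0, 2.0]]
x = [3.0, 5.0]
reference = [3.0, 10.0]  # raw @ x

layer = ToyLayer()
load_weights(layer, raw)
process_weights_after_loading(layer)
first = kernel_matvec(layer.weight, x)

load_weights(layer, raw)               # hot reload with identical weights
process_weights_after_loading(layer)   # no-op because of the flag
second = kernel_matvec(layer.weight, x)

print(first, second)
```

No exception is raised on the second pass; the output is simply wrong, which is the "silently broken" case described above.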

Approach: alias-when-same-shape

Drop the dels. Let process_weights_after_loading run on every call. Introduce alias_or_bind_derived_param in python/sglang/srt/layers/utils/common.py: when the derived tensor has the same shape & dtype as the source Parameter, write the derived bytes into the source's storage in place and register the derived attribute name as an alias of the source Parameter. Two attribute names then share one underlying buffer.

For each call:

  • apply() reads via the derived name (weight_scale_interleaved, *_blockscale_swizzled) and sees the post-processed bytes.
  • update_weights_from_disk refills the source slot (still registered under its original name) — the next process_weights_after_loading overwrites the same buffer with the re-derived bytes.
  • Peak GPU memory is the source size, not source + derived.

When shapes diverge (genuine padding required for non-aligned dims), fall back to allocating a separate Parameter via copy_or_rebind_param — no memory savings in that case, but correctness preserved.
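The alias path and the fallback can be sketched as follows. This is a simplified illustration, not the helper in common.py: Buf is a hypothetical toy parameter, swizzle is an invented derivation, and the fallback here binds the derived value directly rather than calling copy_or_rebind_param. Only the function signature is taken from the PR text.

```python
from types import SimpleNamespace


class Buf:
    """Toy parameter: a shape, a dtype tag, and flat list storage."""

    def __init__(self, shape, dtype, data):
        self.shape, self.dtype, self.data = tuple(shape), dtype, list(data)

    def copy_(self, other):
        # Overwrite existing storage in place (object identity is kept).
        self.data[:] = other.data


def alias_or_bind_derived_param(module, source_name, derived_name, derived_value):
    source = getattr(module, source_name)
    if source.shape == derived_value.shape and source.dtype == derived_value.dtype:
        source.copy_(derived_value)            # reuse the source storage
        setattr(module, derived_name, source)  # second name, same buffer
    else:
        # Shapes diverge: bind a separate buffer (no memory saving).
        setattr(module, derived_name, derived_value)


def swizzle(buf, pad_to=None):
    # Toy derivation: reverse the elements, optionally zero-padding.
    data = buf.data[::-1]
    if pad_to is not None:
        data = data + [0] * (pad_to - len(data))
    return Buf((len(data),), buf.dtype, data)


layer = SimpleNamespace(weight_scale=Buf((4,), "fp8", [1, 2, 3, 4]))

# Startup pass: derived name becomes an alias of the source buffer.
alias_or_bind_derived_param(
    layer, "weight_scale", "weight_scale_interleaved", swizzle(layer.weight_scale)
)
aliased = layer.weight_scale_interleaved is layer.weight_scale
first_values = list(layer.weight_scale_interleaved.data)

# Hot reload: the loader refills the source slot, then the pass runs again.
layer.weight_scale.copy_(Buf((4,), "fp8", [5, 6, 7, 8]))
alias_or_bind_derived_param(
    layer, "weight_scale", "weight_scale_interleaved", swizzle(layer.weight_scale)
)
second_values = list(layer.weight_scale_interleaved.data)

# Fallback: padding changes the shape, so a separate buffer is bound.
odd = SimpleNamespace(weight_scale=Buf((3,), "fp8", [1, 2, 3]))
alias_or_bind_derived_param(
    odd, "weight_scale", "weight_scale_interleaved", swizzle(odd.weight_scale, pad_to=4)
)
separate = odd.weight_scale_interleaved is not odd.weight_scale

print(aliased, first_values, second_values, separate)
```

Because the source name stays registered, a loader that writes by name keeps working, and each pass re-derives from whatever the loader last wrote.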

For Kimi-K2.5 NVFP4 dims (H=7168, moe_intermediate=2048, group_size=16, etc.), the relevant pairs are all byte-for-byte same-size:

  • Linear: weight_scale ↔ weight_scale_interleaved (CUTLASS + TRTLLM paths)
  • MoE: w13_weight_scale ↔ w13_blockscale_swizzled, w2_weight_scale ↔ w2_blockscale_swizzled
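A quick way to see when source and derived scales are the same size is to compute both shapes. The sketch below assumes the swizzled layout pads scale rows to a multiple of 128 and scale columns to a multiple of 4; those tile sizes, the helper names, and the use of unsharded dims are assumptions for illustration, not values stated in this PR.

```python
def round_up(value, multiple):
    """Smallest multiple of `multiple` that is >= value."""
    return -(-value // multiple) * multiple


def scale_shapes(n, k, group_size=16, row_tile=128, col_tile=4):
    """Return (source_shape, padded_shape) of the block scales for an [n, k] weight.

    row_tile and col_tile are assumed padding granularities.
    """
    source = (n, k // group_size)
    padded = (round_up(source[0], row_tile), round_up(source[1], col_tile))
    return source, padded


def can_alias(n, k):
    source, padded = scale_shapes(n, k)
    return source == padded


# H=7168 and moe_intermediate=2048 from the text, plus a non-aligned example.
results = {dims: can_alias(*dims) for dims in [(7168, 2048), (2048, 7168), (100, 48)]}
print(results)
```

Under these assumptions the two Kimi-K2.5 combinations need no padding, while the (100, 48) case pads (100, 3) up to (128, 4) and would take the fallback path.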

layer.weight is left untouched semantically (copy_or_rebind_param continues to do in-place copy when shape matches), so the in-place shuffle/pad runs on every hot reload too.

The TRTLLM-branch MoE additionally aliases the unused *_blockscale_swizzled placeholders to the in-place-shuffled w*_weight_scale after align_fp4_moe_weights_for_flashinfer_trtllm, freeing the placeholders allocated in create_weights.

What stayed scalar (not aliased)

alpha and input_scale_inv (Linear) and g{1,2}_alphas / w{13,2}_input_scale_quant (MoE) remain plain copy_or_rebind_param Parameters. An earlier attempt to alias them into the [N_partitions] source slot broke fused-QKV linears: a downstream call tries to view the scale as [1] and hits RuntimeError: shape '[1]' is invalid for input of size 3. Per-linear savings would have been negligible anyway (~bytes); reverted in commit bf0b436.
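The fused-QKV failure comes down to an element-count mismatch. The sketch below mimics the check with plain lists; view is a hypothetical helper written to reproduce the quoted error text, and the partition count of 3 (q, k, v) is taken from the "input of size 3" in that message.

```python
from math import prod


def view(flat_values, shape):
    # Mimics the element-count check a tensor view performs.
    if prod(shape) != len(flat_values):
        raise RuntimeError(
            f"shape '{list(shape)}' is invalid for input of size {len(flat_values)}"
        )
    return list(flat_values)


alpha_value = 0.5

# Separate buffer: the derived scalar keeps its own 1-element storage.
separate_alpha = [alpha_value]

# Aliased: the derived name points at the 3-element [N_partitions] source slot,
# filled with the replicated value.
aliased_alpha = [alpha_value] * 3

ok = view(separate_alpha, (1,))
try:
    view(aliased_alpha, (1,))
    message = None
except RuntimeError as exc:
    message = str(exc)

print(ok, message)
```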

Modifications

python/sglang/srt/layers/utils/common.py: add alias_or_bind_derived_param(module, source_name, derived_name, derived_value).

python/sglang/srt/layers/quantization/modelopt_quant.py:

  • ModelOptFp4LinearMethod.process_weights_after_loading: replace copy_or_rebind_param(layer, 'weight_scale_interleaved', ...) + del layer.weight_scale with alias_or_bind_derived_param(layer, 'weight_scale', 'weight_scale_interleaved', ...) on both the TRTLLM and CUTLASS branches. Remove the del layer.input_scale, layer.weight_scale_2 line.
  • ModelOptNvFp4FusedMoEMethod.process_weights_after_loading:
    • Remove del layer.w13_input_scale, layer.w2_input_scale.
    • TRTLLM branch: replace del layer.w13_blockscale_swizzled, layer.w2_blockscale_swizzled with two lines aliasing the swizzled-name slots to w13_weight_scale / w2_weight_scale.
    • non-TRTLLM (CUTLASS / CuteDSL) branch: replace the two copy_or_rebind_param(layer, '*_blockscale_swizzled', ...) + trailing del layer.w13_weight_scale, layer.w2_weight_scale with alias_or_bind_derived_param(layer, 'w13_weight_scale', 'w13_blockscale_swizzled', ...) / same for w2.

ModelOptFp8LinearMethod and ModelOptFp8MoEMethod are unchanged.

Accuracy / hot-reload tests

test/registered/rl/test_update_weights_from_disk_blackwell.py::TestServerUpdateWeightsFromDiskNVFP4:

  • Reloads nvidia/Qwen3-30B-A3B-NVFP4 from disk twice with TP=4 + FlashInfer-TRTLLM FP4 GEMM + FlashInfer-TRTLLM-routed MoE.
  • Asserts decode logprobs unchanged within atol=1e-4.
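The tolerance check can be illustrated with a small helper. This is a sketch of an absolute-tolerance comparison, not the assertion used in the test file; logprobs_unchanged and all sample values are hypothetical.

```python
import math


def logprobs_unchanged(before, after, atol=1e-4):
    """True when both sequences have equal length and every pair is within atol."""
    return len(before) == len(after) and all(
        math.isclose(a, b, rel_tol=0.0, abs_tol=atol) for a, b in zip(before, after)
    )


# Hypothetical decode logprobs before and after a reload.
baseline = [-0.1052, -2.3026, -0.6931]
after_reload = [-0.10521, -2.30259, -0.69312]  # differences of about 1e-5
garbage = [-7.9, -0.2, -4.4]                   # what a layout mismatch might look like

print(logprobs_unchanged(baseline, after_reload))
print(logprobs_unchanged(baseline, garbage))
```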

Verified on a 4× GB300 node:

Ran 1 test in 290.006s

OK

Speed / memory

No latency impact — the process_weights_after_loading path is one-shot at startup (and once per update_weights_from_disk call). The same ~15 GiB/rank peak-memory win from PR #25107 is preserved on Kimi-K2.5 NVFP4 TP=4 (all relevant pairs hit the alias path; no separate buffer is allocated for the derived view).

Checklist

🤖 Generated with Claude Code

ch-wan and others added 3 commits May 13, 2026 14:47
PR sgl-project#25107 introduced del's of input_scale / weight_scale /
weight_scale_2 (and MoE counterparts) after deriving alpha,
input_scale_inv, weight_scale_interleaved, etc. That broke
test_update_weights_from_disk_blackwell.py: on the second invocation
from update_weights_from_disk, the deleted source scales cannot be
re-read by process_weights_after_loading.

Add a per-layer _weights_postprocessed flag, initialized False in
create_weights and set True at every exit of
process_weights_after_loading. The function early-returns when the
flag is True, so update_weights_from_disk -> load_weights ->
process_weights_after_loading is a no-op on the second pass.

Contract: a post-processed layer's derived tensors (weight,
weight_scale_interleaved, alpha, blockscale_swizzled, g*_alphas, ...)
are the source of truth. To refresh them, the caller must either
rebuild the layer or write directly into those derived slots --
re-loading raw source scales via update_weights_from_disk silently
drops them (no param slot to fill) and is only correct when the new
checkpoint shares the original's quant scales (typical for RL weight-
only updates).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…d guard

The flag name and commit/PR message already explain the contract.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Moving the flag-set to immediately after the early-return guard covers
all early-return paths inside process_weights_after_loading uniformly
(TRTLLM branch returns; CUTLASS branch falls through) without
duplicating the assignment.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@gemini-code-assist
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@ch-wan
Collaborator Author

ch-wan commented May 13, 2026

/tag-and-rerun-ci

@alisonshao
Collaborator

/rerun-test test_update_weights_from_disk_blackwell.py

@github-actions
Contributor

github-actions Bot commented May 13, 2026

🚀 4-gpu-b200 (1 test): ❌ View workflow run

cd test/ && python3 registered/rl/test_update_weights_from_disk_blackwell.py

The flag-based fix (early-return on second process_weights_after_loading)
prevented the AttributeError crash but left the model producing garbage
after update_weights_from_disk: the loader refills layer.weight with raw
bytes, the early-return skips re-padding/re-shuffling, and the GEMM ends
up reading raw rows through a kernel expecting shuffled rows.

Instead, drop the flag and the del's entirely; let
process_weights_after_loading re-derive every call. To preserve the
memory win, introduce alias_or_bind_derived_param: when the derived
tensor has the same shape & dtype as the source Parameter, write the
derived data into the source's storage in place and register the
derived attribute name as an alias of the source Parameter. The two
names share one buffer, so:

  - apply() reads via the derived name (weight_scale_interleaved /
    *_blockscale_swizzled) and gets the post-processed bytes.
  - update_weights_from_disk refills the source slot (weight_scale /
    w*_weight_scale), and the next process_weights_after_loading call
    re-derives in place.
  - Peak GPU memory matches the source size, recovering the ~15 GiB/rank
    Kimi-K2.5 NVFP4 TP=4 savings without breaking hot reload.

When shapes diverge (genuinely-needed padding), fall back to allocating
a separate Parameter for the derived name -- the savings are 0 in that
case, but correctness is preserved on any NVFP4 model.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
ch-wan and others added 3 commits May 13, 2026 16:12
…lars

input_scale, weight_scale_2 (Linear) and their MoE per-expert
counterparts are small but non-empty source params; the derived
alpha / input_scale_inv / g{1,2}_alphas / w{13,2}_input_scale_quant
are typically scalars or smaller-rank tensors. broadcast_to lets
us fill the source storage with the broadcast-replicated derived
value and alias the derived name, sharing one buffer per pair.

When the shapes are not broadcast-compatible (e.g. MoE g_alphas
[num_local_experts] vs gated weight_scale_2 [num_local_experts, 2]),
the helper falls through to copy_or_rebind_param so correctness is
preserved.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
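The broadcast-compatibility rule that commit message relies on can be sketched as below. broadcastable_to is a hypothetical helper following the usual NumPy/PyTorch convention of comparing trailing dimensions, and num_local_experts = 8 is an arbitrary example value.

```python
def broadcastable_to(src_shape, dst_shape):
    """Can src_shape be broadcast to dst_shape under trailing-dimension rules?"""
    if len(src_shape) > len(dst_shape):
        return False
    # Walk dimensions from the right; a source dim must be 1 or match exactly.
    for s, d in zip(reversed(src_shape), reversed(dst_shape)):
        if s != 1 and s != d:
            return False
    return True


num_local_experts = 8  # arbitrary example value

cases = {
    "scalar into [3]": broadcastable_to((), (3,)),
    "[1] into [3]": broadcastable_to((1,), (3,)),
    "[E] into [E, 2]": broadcastable_to((num_local_experts,), (num_local_experts, 2)),
    "[E, 1] into [E, 2]": broadcastable_to((num_local_experts, 1), (num_local_experts, 2)),
}
print(cases)
```

The [E] into [E, 2] case fails because the trailing dimensions (E versus 2) disagree, which is the situation where the commit falls through to copy_or_rebind_param.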
alias_or_bind_derived_param's docstring already describes the
source/derived/loader contract; remove the duplicate paragraphs at
each call site.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…shape

Aliasing alpha/input_scale_inv into the [N_partitions] source slot
broke fused-QKV linears: a downstream call tries to view the scale as
[1] and hits "shape '[1]' is invalid for input of size 3". The per-tensor
savings are negligible (~bytes/linear), so revert these scalar pairs
to copy_or_rebind_param. Large-tensor aliasing (weight_scale,
*_blockscale_swizzled) stays.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@ch-wan ch-wan changed the title from "fix(nvfp4): make process_weights_after_loading idempotent via _weights_postprocessed flag" to "fix(nvfp4): make process_weights_after_loading hot-reload-safe via alias-when-same-shape" on May 13, 2026
@ch-wan ch-wan merged commit 6c0633b into sgl-project:main May 13, 2026
77 of 104 checks passed
Fridge003 pushed a commit that referenced this pull request May 14, 2026
…ias-when-same-shape (#25190)

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

2 participants