fix(nvfp4): make process_weights_after_loading hot-reload-safe via alias-when-same-shape#25190
Merged
Merged
Conversation
PR sgl-project#25107 introduced del's of input_scale / weight_scale / weight_scale_2 (and MoE counterparts) after deriving alpha, input_scale_inv, weight_scale_interleaved, etc. That broke test_update_weights_from_disk_blackwell.py: on the second invocation from update_weights_from_disk, the deleted source scales cannot be re-read by process_weights_after_loading. Add a per-layer _weights_postprocessed flag, initialized False in create_weights and set True at every exit of process_weights_after_loading. The function early-returns when the flag is True, so update_weights_from_disk -> load_weights -> process_weights_after_loading is a no-op on the second pass. Contract: a post-processed layer's derived tensors (weight, weight_scale_interleaved, alpha, blockscale_swizzled, g*_alphas, ...) are the source of truth. To refresh them, the caller must either rebuild the layer or write directly into those derived slots -- re-loading raw source scales via update_weights_from_disk silently drops them (no param slot to fill) and is only correct when the new checkpoint shares the original's quant scales (typical for RL weight- only updates). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…d guard The flag name and commit/PR message already explain the contract. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Moving the flag-set to immediately after the early-return guard covers all early-return paths inside process_weights_after_loading uniformly (TRTLLM branch returns; CUTLASS branch falls through) without duplicating the assignment. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Contributor
|
Warning You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again! |
Collaborator
Author
|
/tag-and-rerun-ci |
Collaborator
|
/rerun-test test_update_weights_from_disk_blackwell.py |
Contributor
|
🚀 |
The flag-based fix (early-return on second process_weights_after_loading)
prevented the AttributeError crash but left the model producing garbage
after update_weights_from_disk: the loader refills layer.weight with raw
bytes, the early-return skips re-padding/re-shuffling, and the GEMM ends
up reading raw rows through a kernel expecting shuffled rows.
Instead, drop the flag and the del's entirely; let
process_weights_after_loading re-derive every call. To preserve the
memory win, introduce alias_or_bind_derived_param: when the derived
tensor has the same shape & dtype as the source Parameter, write the
derived data into the source's storage in place and register the
derived attribute name as an alias of the source Parameter. The two
names share one buffer, so:
- apply() reads via the derived name (weight_scale_interleaved /
*_blockscale_swizzled) and gets the post-processed bytes.
- update_weights_from_disk refills the source slot (weight_scale /
w*_weight_scale), and the next process_weights_after_loading call
re-derives in place.
- Peak GPU memory matches the source size, recovering the ~15 GiB/rank
Kimi-K2.5 NVFP4 TP=4 savings without breaking hot reload.
When shapes diverge (genuinely-needed padding), fall back to allocating
a separate Parameter for the derived name -- the savings are 0 in that
case, but correctness is preserved on any NVFP4 model.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…lars
input_scale, weight_scale_2 (Linear) and their MoE per-expert
counterparts are small but non-empty source params; the derived
alpha / input_scale_inv / g{1,2}_alphas / w{13,2}_input_scale_quant
are typically scalars or smaller-rank tensors. broadcast_to lets
us fill the source storage with the broadcast-replicated derived
value and alias the derived name, sharing one buffer per pair.
When the shapes are not broadcast-compatible (e.g. MoE g_alphas
[num_local_experts] vs gated weight_scale_2 [num_local_experts, 2]),
the helper falls through to copy_or_rebind_param so correctness is
preserved.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
alias_or_bind_derived_param's docstring already describes the source/derived/loader contract; remove the duplicate paragraphs at each call site. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…shape Aliasing alpha/input_scale_inv into the [N_partitions] source slot broke fused-QKV linears: a downstream call tries to view the scale as [1] and hits "shape '[1]' is invalid for input of size 3". The per-tensor savings are negligible (~bytes/linear), so revert these scalar pairs to copy_or_rebind_param. Large-tensor aliasing (weight_scale, *_blockscale_swizzled) stays. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
5 tasks
Fridge003
pushed a commit
that referenced
this pull request
May 14, 2026
…ias-when-same-shape (#25190) Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
PR #25107 freed unused NVFP4 source scales (
input_scale/weight_scale/weight_scale_2and their MoE counterparts) insideprocess_weights_after_loadingto reclaim ~15 GiB/rank on Kimi-K2.5 NVFP4 TP=4. That broketest_update_weights_from_disk_blackwell.py::TestServerUpdateWeightsFromDiskNVFP4:alpha/input_scale_inv/weight_scale_interleaved, mutatelayer.weightin place (TRTLLMshuffle_matrix_apermutes rows, CUTLASSpad_nvfp4_weightpads N/K), thendelthe source scales./update_weights_from_disk→model.load_weights(...):layer.weightfrom safetensors with raw bytes (overwriting the previously-shuffled/padded representation).process_weights_after_loadingcall:layer.input_scale.max()→AttributeError. Even with a guard, the new rawlayer.weightmismatches the cached shuffled-layoutweight_scale_interleaved/alpha, so the GEMM produces garbage tokens.Neither a full revert (loses ~15 GiB/rank) nor a
_weights_postprocessedskip-on-second-pass flag (silently broken outputs after hot reload) is acceptable.Approach: alias-when-same-shape
Drop the
dels. Letprocess_weights_after_loadingrun on every call. Introducealias_or_bind_derived_paraminpython/sglang/srt/layers/utils/common.py: when the derived tensor has the same shape & dtype as the sourceParameter, write the derived bytes into the source's storage in place and register the derived attribute name as an alias of the sourceParameter. Two attribute names then share one underlying buffer.For each call:
apply()reads via the derived name (weight_scale_interleaved,*_blockscale_swizzled) and sees the post-processed bytes.update_weights_from_diskrefills the source slot (still registered under its original name) — the nextprocess_weights_after_loadingoverwrites the same buffer with the re-derived bytes.When shapes diverge (genuine padding required for non-aligned dims), fall back to allocating a separate Parameter via
copy_or_rebind_param— no memory savings in that case, but correctness preserved.For Kimi-K2.5 NVFP4 dims (H=7168, moe_intermediate=2048, group_size=16, etc.), the relevant pairs are all byte-for-byte same-size:
weight_scale↔weight_scale_interleaved(CUTLASS + TRTLLM paths)w13_weight_scale↔w13_blockscale_swizzled,w2_weight_scale↔w2_blockscale_swizzledlayer.weightis left untouched semantically (copy_or_rebind_paramcontinues to do in-place copy when shape matches), so the in-place shuffle/pad runs on every hot reload too.The TRTLLM-branch MoE additionally aliases the unused
*_blockscale_swizzledplaceholders to the in-place-shuffledw*_weight_scaleafteralign_fp4_moe_weights_for_flashinfer_trtllm, freeing the placeholders allocated increate_weights.What stayed scalar (not aliased)
alphaandinput_scale_inv(Linear) andg{1,2}_alphas/w{13,2}_input_scale_quant(MoE) remain plaincopy_or_rebind_paramParameters. An earlier attempt to alias them into the[N_partitions]source slot broke fused-QKV linears: a downstream call tries to view the scale as[1]and hitsRuntimeError: shape '[1]' is invalid for input of size 3. Per-linear savings would have been negligible anyway (~bytes); reverted in commitbf0b436.Modifications
python/sglang/srt/layers/utils/common.py: addalias_or_bind_derived_param(module, source_name, derived_name, derived_value).python/sglang/srt/layers/quantization/modelopt_quant.py:ModelOptFp4LinearMethod.process_weights_after_loading: replacecopy_or_rebind_param(layer, 'weight_scale_interleaved', ...)+del layer.weight_scalewithalias_or_bind_derived_param(layer, 'weight_scale', 'weight_scale_interleaved', ...)on both the TRTLLM and CUTLASS branches. Remove thedel layer.input_scale, layer.weight_scale_2line.ModelOptNvFp4FusedMoEMethod.process_weights_after_loading:del layer.w13_input_scale, layer.w2_input_scale.del layer.w13_blockscale_swizzled, layer.w2_blockscale_swizzledwith two lines aliasing the swizzled-name slots tow13_weight_scale/w2_weight_scale.copy_or_rebind_param(layer, '*_blockscale_swizzled', ...)+ trailingdel layer.w13_weight_scale, layer.w2_weight_scalewithalias_or_bind_derived_param(layer, 'w13_weight_scale', 'w13_blockscale_swizzled', ...)/ same for w2.ModelOptFp8LinearMethodandModelOptFp8MoEMethodare unchanged.Accuracy / hot-reload tests
test/registered/rl/test_update_weights_from_disk_blackwell.py::TestServerUpdateWeightsFromDiskNVFP4:nvidia/Qwen3-30B-A3B-NVFP4from disk twice with TP=4 + FlashInfer-TRTLLM FP4 GEMM + FlashInfer-TRTLLM-routed MoE.atol=1e-4.Verified on a 4× GB300 node:
Speed / memory
No latency impact — the
process_weights_after_loadingpath is one-shot at startup (and once perupdate_weights_from_diskcall). The same ~15 GiB/rank peak-memory win from PR #25107 is preserved on Kimi-K2.5 NVFP4 TP=4 (all relevant pairs hit the alias path; no separate buffer is allocated for the derived view).Checklist
test_update_weights_from_disk_blackwell.py.)🤖 Generated with Claude Code