
[RL] Refactor NVFP4 shuffling/swizzling to in-place replacement#22204

Open
zianglih wants to merge 6 commits into sgl-project:main from zianglih:nvfp4-shuffle-refactor

Conversation


@zianglih zianglih commented Apr 6, 2026

Motivation


#18085 is an earlier fix for the nvfp4 weight update, but it did not fix the trtllm backend, since that backend uses *_weights_fp4_shuffled tensors, which require a broader refactor to support in-place replacement.

This PR replaces every *_weights_fp4_shuffled tensor with the original weight tensor and performs the swizzling/shuffling via in-place replacement.

Modifications

  • Rename and expand test_update_weights_from_disk_blackwell.py; it now covers both mxfp8 and nvfp4
  • Replace all *_weights_fp4_shuffled tensors with the original weight tensors and perform swizzling/shuffling via in-place replacement
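The core idea of the refactor can be sketched in plain Python (all names here are illustrative stand-ins, not the actual sglang code): a list plays the role of a torch weight tensor, and slice assignment plays the role of `Tensor.copy_()`. Writing the shuffled layout back into the same buffer keeps every external reference to that tensor valid, which is what update_weights_from_disk relies on; registering a separate shuffled tensor would leave those references pointing at stale data.

```python
def shuffle(rows):
    # stand-in for the NVFP4 swizzle: here, simply reverse the rows
    return list(reversed(rows))

class Layer:
    def __init__(self, weight):
        self.w13_weight = weight  # the buffer the weight loader writes into

def process_weights_in_place(layer):
    # write the shuffled layout back into the SAME buffer (like copy_()),
    # so references held elsewhere (e.g. the weight updater) stay valid
    layer.w13_weight[:] = shuffle(layer.w13_weight)

layer = Layer([1, 2, 3])
loader_view = layer.w13_weight        # reference held by the weight updater
process_weights_in_place(layer)
print(loader_view)                    # [3, 2, 1] — same object, new contents
```

With rebinding (`layer.w13_weight = shuffle(...)`) instead, `loader_view` would still hold the old unshuffled list, which is the class of bug this PR removes.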

Accuracy Tests

gsm8k

python3 -m sglang.launch_server --kv-cache-dtype bf16 --model nvidia/Qwen3-30B-A3B-NVFP4
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
Accuracy: 0.945
Invalid: 0.001
Latency: 8.517 s
Output throughput: 17180.418 token/s
curl -sS http://localhost:30000/update_weights_from_disk \
  -H 'Content-Type: application/json' \
  -d '{
    "model_path": "nvidia/Qwen3-30B-A3B-NVFP4",
    "flush_cache": true,
    "abort_all_requests": false
  }'
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
Accuracy: 0.942
Invalid: 0.001
Latency: 7.789 s
Output throughput: 18788.124 token/s

after the fix

python3 -m pytest -s -q test/registered/rl/test_update_weights_from_disk_mxfp8.py

============================= warnings summary =============================
../../usr/local/lib/python3.12/dist-packages/_pytest/config/__init__.py:1428
  /usr/local/lib/python3.12/dist-packages/_pytest/config/__init__.py:1428: PytestConfigWarning: Unknown config option: asyncio_mode
  
    self._warn_or_fail_if_strict(f"Unknown config option: {key}\n")

<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
2 passed, 3 warnings, 6 subtests passed in 204.60s (0:03:24)
sys:1: DeprecationWarning: builtin type swigvarlink has no __module__ attribute

before the fix

root@B200-51:/sgl-workspace/sglang# python3 -m pytest -s -q test/registered/rl/test_update_weights_from_disk_mxfp8.py -k NVFP4
====================================== warnings summary ======================================
../../usr/local/lib/python3.12/dist-packages/_pytest/config/__init__.py:1428
  /usr/local/lib/python3.12/dist-packages/_pytest/config/__init__.py:1428: PytestConfigWarning: Unknown config option: asyncio_mode
  
    self._warn_or_fail_if_strict(f"Unknown config option: {key}\n")

<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================== short test summary info ===================================
SUBFAILED(case_name='flashinfer_trtllm_nvfp4', flush_cache=True, abort_all_requests=False, model='nvidia/Qwen3-30B-A3B-NVFP4') test/registered/rl/test_update_weights_from_disk_mxfp8.py::TestServerUpdateWeightsFromDiskNVFP4::test_parameterized_update_weights_from_disk - requests.exceptions.ConnectionError: ('Connection aborted.', RemoteDisconnected('Remote e...
SUBFAILED(case_name='flashinfer_trtllm_nvfp4', flush_cache=False, abort_all_requests=False, model='nvidia/Qwen3-30B-A3B-NVFP4') test/registered/rl/test_update_weights_from_disk_mxfp8.py::TestServerUpdateWeightsFromDiskNVFP4::test_parameterized_update_weights_from_disk - requests.exceptions.ConnectionError: HTTPConnectionPool(host='127.0.0.1', port=21000): Ma...
2 failed, 1 passed, 1 deselected, 3 warnings in 66.06s (0:01:06)
sys:1: DeprecationWarning: builtin type swigvarlink has no __module__ attribute

Speed Tests and Profiling

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors weight management for FlashInfer TRT-LLM FP4 MoE by renaming shuffled weight parameters to standard names like w13_weight and utilizing in-place replacement utilities. Additionally, it refactors the weight update test suite into a base class to support both MXFP8 and NVFP4 models. Review feedback suggests using a more reliable attribute check to ensure weights are correctly prepared before access and removing redundant .contiguous() calls on tensors that are already contiguous.
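On the redundant .contiguous() point: in PyTorch, calling .contiguous() on an already-contiguous tensor returns the tensor itself, so the call is a no-op. A minimal dependency-free stand-in (FakeTensor is hypothetical, modeling only this one behavior) makes the semantics concrete:

```python
class FakeTensor:
    """Minimal stand-in modeling PyTorch's .contiguous() semantics."""

    def __init__(self, contiguous=True):
        self._contig = contiguous

    def is_contiguous(self):
        return self._contig

    def contiguous(self):
        # PyTorch returns self when the tensor is already contiguous;
        # only a non-contiguous tensor triggers a copy
        return self if self._contig else FakeTensor(True)

t = FakeTensor()
print(t.contiguous() is t)  # True — the extra call adds nothing
```

This is why the review asks to drop .contiguous() on tensors already known to be contiguous: it costs nothing at runtime but obscures the code's intent.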

- # backend is flashinfer_trtllm
- if hasattr(layer, "gemm1_weights_fp4_shuffled"):
+ # FlashInfer TRTLLM FP4 path
+ if self.enable_flashinfer_trtllm_moe:

Severity: high

Changing the condition from hasattr(layer, "gemm1_weights_fp4_shuffled") to self.enable_flashinfer_trtllm_moe is potentially unsafe. If the weight alignment/shuffling process was skipped during process_weights_after_loading (e.g., due to missing FlashInfer dependencies), layer.w13_weight will contain unshuffled weights and layer.g1_scale_c will not be present. This will lead to an AttributeError at line 1942 and incorrect kernel execution. It is safer to check for a FlashInfer-specific attribute like g1_scale_c to ensure the layer is properly prepared.

Suggested change:
- if self.enable_flashinfer_trtllm_moe:
+ if hasattr(layer, "g1_scale_c"):
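The reviewer's pattern can be sketched as follows (only g1_scale_c comes from the review; the function and return values are hypothetical illustrations): gate the kernel choice on an attribute that only exists after the FlashInfer preparation ran, so a layer whose shuffle was skipped falls through safely instead of raising an AttributeError later.

```python
class MoELayer:
    pass

def select_fp4_path(layer):
    # g1_scale_c is only attached once process_weights_after_loading
    # completed the FlashInfer TRT-LLM FP4 preparation, so its presence
    # proves the weights are actually in the shuffled layout
    if hasattr(layer, "g1_scale_c"):
        return "flashinfer_trtllm_fp4"
    return "fallback"

prepared = MoELayer()
prepared.g1_scale_c = object()
unprepared = MoELayer()

print(select_fp4_path(prepared))    # flashinfer_trtllm_fp4
print(select_fp4_path(unprepared))  # fallback, even if the config flag is on
```

Checking the capability (the attribute) rather than the intent (the config flag) makes the dispatch robust to partially-initialized layers.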


zianglih commented Apr 6, 2026

cc @yueming-yuan @Zhichenzzz

@ziang-and ziang-and force-pushed the nvfp4-shuffle-refactor branch from 877bb0c to 9816488 Compare April 7, 2026 23:42
@github-actions github-actions bot added the npu label Apr 7, 2026
@ziang-and ziang-and force-pushed the nvfp4-shuffle-refactor branch from 5c6944b to f7635bc Compare April 8, 2026 22:39
@ziang-and ziang-and force-pushed the nvfp4-shuffle-refactor branch from f7635bc to d638682 Compare April 10, 2026 05:18

Labels

blackwell SM100/SM120 npu quant LLM Quantization
