[RL] Refactor NVFP4 shuffling/swizzling to in-place replacement #22204
zianglih wants to merge 6 commits into sgl-project:main
Code Review
This pull request refactors weight management for FlashInfer TRT-LLM FP4 MoE by renaming shuffled weight parameters to standard names like w13_weight and utilizing in-place replacement utilities. Additionally, it refactors the weight update test suite into a base class to support both MXFP8 and NVFP4 models. Review feedback suggests using a more reliable attribute check to ensure weights are correctly prepared before access and removing redundant .contiguous() calls on tensors that are already contiguous.
- # backend is flashinfer_trtllm
- if hasattr(layer, "gemm1_weights_fp4_shuffled"):
+ # FlashInfer TRTLLM FP4 path
+ if self.enable_flashinfer_trtllm_moe:
Changing the condition from hasattr(layer, "gemm1_weights_fp4_shuffled") to self.enable_flashinfer_trtllm_moe is potentially unsafe. If the weight alignment/shuffling process was skipped during process_weights_after_loading (e.g., due to missing FlashInfer dependencies), layer.w13_weight will contain unshuffled weights and layer.g1_scale_c will not be present. This will lead to an AttributeError at line 1942 and incorrect kernel execution. It is safer to check for a FlashInfer-specific attribute like g1_scale_c to ensure the layer is properly prepared.
- if self.enable_flashinfer_trtllm_moe:
+ if hasattr(layer, "g1_scale_c"):
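To make the difference between the two guards concrete, here is a minimal sketch, not the actual SGLang code. The `SimpleNamespace` objects are hypothetical stand-ins for a layer, and `flag_guard` / `attr_guard` are illustrative names. It assumes, as the comment above does, that `g1_scale_c` only exists once weight preparation has run.

```python
from types import SimpleNamespace


def flag_guard(layer, enable_flashinfer_trtllm_moe):
    # Guard used in the PR: trusts the backend flag alone.
    return enable_flashinfer_trtllm_moe


def attr_guard(layer, enable_flashinfer_trtllm_moe):
    # Guard suggested in the review: checks for an attribute that is
    # assumed to be created only during weight preparation.
    return hasattr(layer, "g1_scale_c")


# Hypothetical layers: one fully prepared, one where preparation was skipped.
prepared = SimpleNamespace(w13_weight="shuffled", g1_scale_c=1.0)
unprepared = SimpleNamespace(w13_weight="unshuffled")

for name, layer in [("prepared", prepared), ("unprepared", unprepared)]:
    print(name, flag_guard(layer, True), attr_guard(layer, True))
```

With the flag enabled, only the attribute check rejects the unprepared layer. The flag check would let it through to a code path that reads `layer.g1_scale_c`.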
877bb0c to 9816488
5c6944b to f7635bc
f7635bc to d638682
Motivation
@HumansAnd
#18085 is an earlier fix for nvfp4 weight update, but it did not fix the trtllm backend, because the trtllm backend used *_weights_fp4_shuffled tensors, which require broader refactoring for in-place replacement.
This PR replaces all *_weights_fp4_shuffled tensors with the original weight tensor and performs swizzling/shuffling with in-place replacement.
Modifications
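As a rough illustration of what "in-place replacement" means here, the sketch below uses plain Python lists as stand-ins for tensors. The `Layer` class, the `shuffle_in_place` helper, and the permutation are hypothetical and do not reflect the real NVFP4 shuffle layout or SGLang's utilities. The point is only that the shuffled data overwrites the existing `w13_weight` buffer rather than being stored under a separate `*_weights_fp4_shuffled` attribute.

```python
class Layer:
    """Hypothetical stand-in for an MoE layer; a list stands in for a tensor."""

    def __init__(self, w13_weight):
        self.w13_weight = w13_weight


def shuffle_in_place(layer, perm):
    # Slice assignment overwrites the contents of the existing buffer,
    # so every holder of a reference to it sees the shuffled data.
    src = layer.w13_weight
    src[:] = [src[i] for i in perm]


layer = Layer([10, 11, 12, 13])
alias = layer.w13_weight  # e.g. a reference held elsewhere by name

shuffle_in_place(layer, [2, 0, 3, 1])  # illustrative permutation only

print(layer.w13_weight)           # [12, 10, 13, 11]
print(alias is layer.w13_weight)  # True
```

Because the buffer's identity is preserved, anything that addresses the weight by its original name keeps pointing at the data the kernel will read, and no second copy has to be kept in sync.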
test_update_weights_from_disk_blackwell.py now covers both mxfp8 and nvfp4.
Replaces all *_weights_fp4_shuffled tensors with the original weight tensor and performs swizzling/shuffling with in-place replacement.
Accuracy Tests
gsm8k
after the fix
before the fix
Speed Tests and Profiling