Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
6f98251
reload model weight for fp8 rollout
AniZpZ Aug 26, 2025
fdd69af
fmt
AniZpZ Aug 26, 2025
d44f067
fix
AniZpZ Aug 26, 2025
55ef3d1
fix
AniZpZ Aug 26, 2025
ec05302
fix
AniZpZ Aug 26, 2025
1aebe87
Runnable
Hecate0821 Oct 23, 2025
5fe8cac
Fix
Hecate0821 Oct 23, 2025
22041c9
Add fast path
Hecate0821 Oct 23, 2025
b58f208
Fix
Hecate0821 Nov 4, 2025
d413773
Separate verl and sglang
Hecate0821 Nov 6, 2025
0c0c604
Update
Hecate0821 Nov 6, 2025
e658b17
Use flashRL's approach
Hecate0821 Nov 6, 2025
1e58dc1
Follow flashRL sglang-patch
Hecate0821 Nov 7, 2025
c574672
Fix
Hecate0821 Nov 7, 2025
d92d248
Remove dimension adjustment
Hecate0821 Nov 7, 2025
869e08a
fp8 initial commit
Hecate0821 Nov 7, 2025
fd6a572
Fix
Hecate0821 Nov 8, 2025
7d5eaa0
debug stack
Hecate0821 Nov 8, 2025
908c419
Draft
Hecate0821 Nov 8, 2025
50ffe46
Update
Hecate0821 Nov 8, 2025
5b0b99c
Clean comments
Hecate0821 Nov 8, 2025
771a082
fix: update weight_scale during fp8 reload
eternally-z Nov 19, 2025
67708bc
format
AniZpZ Nov 19, 2025
c57f3fb
upd
AniZpZ Nov 19, 2025
a93f206
Merge branch 'main' into quant_rollout
AniZpZ Nov 19, 2025
96e1771
upd
AniZpZ Nov 19, 2025
2f5ae9c
minor fix
AniZpZ Nov 21, 2025
114072f
minor
AniZpZ Nov 21, 2025
46ee442
Merge branch 'main' into quant_rollout
Wilboludriver Dec 2, 2025
30f0bc2
Merge branch 'main' into quant_rollout
Wilboludriver Dec 3, 2025
412d81e
Merge branch 'main' into quant_rollout
AniZpZ Dec 3, 2025
4e07bd5
minor fix
AniZpZ Dec 5, 2025
a7840d7
minor fix
AniZpZ Dec 5, 2025
408c0ee
minor
AniZpZ Dec 5, 2025
670dd43
minor fix
Wilboludriver Dec 5, 2025
7592a04
Merge branch 'main' into quant_rollout
AniZpZ Dec 8, 2025
7a569a9
fix:fix scale update with tp
eternally-z Dec 8, 2025
87cf419
minor fix
eternally-z Dec 9, 2025
0a2f20d
fmt
AniZpZ Dec 9, 2025
69fafaf
Merge branch 'main' into quant_rollout
ispobock Dec 9, 2025
6d4dae0
Merge branch 'main' into quant_rollout
AniZpZ Dec 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions python/sglang/srt/configs/load_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class LoadFormat(str, enum.Enum):
BITSANDBYTES = "bitsandbytes"
MISTRAL = "mistral"
LAYERED = "layered"
FLASH_RL = "flash_rl" # For RL training with quantized models
JAX = "jax"
REMOTE = "remote"
REMOTE_INSTANCE = "remote_instance"
Expand All @@ -46,6 +47,8 @@ class LoadConfig:
"dummy" will initialize the weights with random values, which is
mainly for profiling.
"bitsandbytes" will load nf4 type weights.
"flash_rl" will load weights with support for RL training
with quantized models, enabling efficient weight reloading.
ignore_patterns: The list of patterns to ignore when loading the model.
Default to "original/**/*" to avoid repeated loading of llama's
checkpoints.
Expand Down Expand Up @@ -78,6 +81,11 @@ class LoadConfig:
# ModelOpt configuration object
modelopt_config: Optional[ModelOptConfig] = None

# QuantizedRL-specific options (for FlashRL-style quantization)
rl_quant_profile: Optional[str] = (
None # Path to rollout quantization profile (e.g., /root/profile.7b.pt)
)

def __post_init__(self):
model_loader_extra_config = self.model_loader_extra_config or {}
if isinstance(model_loader_extra_config, str):
Expand Down
22 changes: 20 additions & 2 deletions python/sglang/srt/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,16 @@ def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
else:
# FIXME: This branch is needed to load deepseek v3 awq.
# However, we should fix this and avoid the branching here.
param.load_column_parallel_weight(loaded_weight)
# After QuantizedRL reload, params might still need tp_rank
try:
param.load_column_parallel_weight(
loaded_weight,
tp_rank=self.tp_rank,
use_presharded_weights=self.use_presharded_weights,
)
except TypeError:
# Fallback for parameters that don't accept additional args
param.load_column_parallel_weight(loaded_weight)

def forward(self, input_):
bias = self.bias if not self.skip_bias_add else None
Expand Down Expand Up @@ -1360,7 +1369,16 @@ def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor
else:
# `params` is defined in `vllm/model_executor/parameter.py`,
# It does not support additional parameters.
param.load_row_parallel_weight(loaded_weight)
# However, after QuantizedRL reload, params might still need tp_rank
try:
param.load_row_parallel_weight(
loaded_weight,
tp_rank=self.tp_rank,
use_presharded_weights=self.use_presharded_weights,
)
except TypeError:
# Fallback for parameters that don't accept additional args
param.load_row_parallel_weight(loaded_weight)

def forward(self, input_, skip_all_reduce=False):
if self.input_is_parallel:
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,7 @@ def load_model(self):
remote_instance_weight_loader_seed_instance_service_port=self.server_args.remote_instance_weight_loader_seed_instance_service_port,
remote_instance_weight_loader_send_weights_group_ports=self.server_args.remote_instance_weight_loader_send_weights_group_ports,
modelopt_config=modelopt_config,
rl_quant_profile=self.server_args.rl_quant_profile,
)
if self.device == "cpu":
self.model_config = adjust_config_with_unaligned_cpu_tp(
Expand Down
Loading
Loading