diff --git a/python/sglang/srt/models/deepseek_common/deepseek_weight_loader.py b/python/sglang/srt/models/deepseek_common/deepseek_weight_loader.py new file mode 100644 index 000000000000..7d83c67138bd --- /dev/null +++ b/python/sglang/srt/models/deepseek_common/deepseek_weight_loader.py @@ -0,0 +1,657 @@ +# Copyright 2026 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import concurrent.futures +import logging +from typing import Iterable, Optional, Tuple + +import torch +import torch.nn as nn +import tqdm +from transformers import PretrainedConfig + +from sglang.srt.distributed.parallel_state import GroupCoordinator +from sglang.srt.environ import envs +from sglang.srt.layers import deep_gemm_wrapper +from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.quantization.fp8_utils import ( + block_quant_dequant, + block_quant_to_tensor_quant, + channel_quant_to_tensor_quant, + inverse_transform_scale_ue8m0, + normalize_e4m3fn_to_e4m3fnuz, + quant_weight_ue8m0, +) +from sglang.srt.layers.quantization.int8_utils import ( + block_dequant as int8_block_dequant, +) +from sglang.srt.layers.utils import get_layer_id +from sglang.srt.model_loader.utils import ( + maybe_executor_submit, + should_async_load, + should_deepgemm_weight_requant_ue8m0, +) +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.deepseek_common.utils import ( + _is_cpu, + _is_cpu_amx_available, + _is_cuda, + _is_fp8_fnuz, + _is_hip, + _is_npu, + _use_aiter_gfx95, + awq_dequantize_func, + enable_nextn_moe_bf16_cast_to_fp8, +) +from sglang.srt.utils import bind_or_assign, get_bool_env_var, log_info_on_rank0 + +if _use_aiter_gfx95: + from sglang.srt.layers.quantization.quark.utils import quark_post_load_weights + +logger = logging.getLogger(__name__) + +# Optional quantization for DeepSeek nvfp4 checkpoint +NVFP4_CKPT_FP8_ATTN_QUANT_MODULES = ["q_b_proj"] + + +class DeepseekV2WeightLoaderMixin: + """Mixin for loading weights in DeepSeek V2/V3 models.""" + + model: nn.Module + config: PretrainedConfig + quant_config: Optional[QuantizationConfig] + pp_group: GroupCoordinator + num_fused_shared_experts: int + + def do_load_weights( + self: nn.Module, + weights: Iterable[Tuple[str, torch.Tensor]], + is_nextn: bool = False, + ): + """Load model weights from checkpoint. + + Args: + weights: Iterable of (weight_name, weight_tensor) pairs + is_nextn: Whether loading NextN speculative decoding weights + """ + if is_nextn: + if hasattr(self.config, "num_nextn_predict_layers"): + num_nextn_layers = self.config.num_nextn_predict_layers + assert num_nextn_layers == 1, "Only 1 nextn layer is supported" + # compatible with old design + nextn_layer_id = ( + 0 + if self.config.num_hidden_layers == 1 + else self.config.num_hidden_layers + ) + else: + raise ValueError("num_nextn_predict_layers is not in the config") + + weights = self._maybe_quant_weights_to_fp8_ue8m0( + weights, NVFP4_CKPT_FP8_ATTN_QUANT_MODULES, is_nextn + ) + + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts + self.num_fused_shared_experts, + ) + # Params for special naming rules in mixed-precision models, for example: + # model.layers.xx.mlp.experts.xx.w1.input_scale. For details, + # see https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/blob/main. + if self.quant_config and self.quant_config.get_name() == "w4afp8": + expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping( + num_experts=self.config.n_routed_experts + ) + + # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None + fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and ( + self.config.q_lora_rank is not None + ) + cached_a_proj = {} if fuse_qkv_a_proj else None + + if is_nextn: + nextn_layer_prefix = f"model.layers.{nextn_layer_id}" + nextn_spec_weight_names = [ + "shared_head.norm", + "eh_proj", + "enorm", + "hnorm", + ] + + if self.num_fused_shared_experts > 0: + assert self.num_fused_shared_experts == 1 + log_info_on_rank0(logger, "Shared experts fusion optimization enabled.") + + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [] + params_dict = dict(self.named_parameters()) + weight_names = [] + for name, loaded_weight in weights: + use_async_loading = should_async_load(loaded_weight) + layer_id = get_layer_id(name) + if ( + layer_id is not None + and hasattr(self.model, "start_layer") + and ( + layer_id < self.model.start_layer + or layer_id >= self.model.end_layer + ) + ): + continue + if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name: + name = name.replace( + "mlp.shared_experts", + f"mlp.experts.{self.config.n_routed_experts}", + ) + + weight_names.append(name) + + if not is_nextn: + if hasattr(self.config, "num_nextn_predict_layers"): + num_nextn_layers = self.config.num_nextn_predict_layers + if num_nextn_layers > 0 and name.startswith("model.layers"): + name_list = name.split(".") + if ( + len(name_list) >= 3 + and int(name_list[2]) >= self.config.num_hidden_layers + ): + continue + else: + if not name.startswith(nextn_layer_prefix): + continue + + # Use shared head and embed weights from target model + if "shared_head.head" in name or "embed_tokens" in name: + continue + + is_decoder = True + # For nextn specific weights + for weight_name in nextn_spec_weight_names: + if weight_name in name: + name = name.replace(nextn_layer_prefix, "model") + is_decoder = False + break + # For decoder layer weights + if is_decoder: + name = name.replace(nextn_layer_prefix, "model.decoder") + + if "rotary_emb.inv_freq" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + if _is_npu: + name = name.replace("weight_packed", "weight") + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if ("mlp.experts." in name) and name not in params_dict: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + maybe_executor_submit( + executor=executor, + futures=futures, + use_async=use_async_loading, + func=weight_loader, + func_args=(param, loaded_weight, shard_id), + ) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + if _is_npu: + name = name.replace("weight_packed", "weight") + name = name.replace(weight_name, param_name) + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + maybe_executor_submit( + executor=executor, + futures=futures, + use_async=use_async_loading, + func=weight_loader, + func_args=( + param, + loaded_weight, + name, + ), + func_kwargs={ + "shard_id": shard_id, + "expert_id": expert_id, + }, + ) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip loading embed_tokens if not first rank in pipeline parallelism + if ".embed_tokens." in name and not self.pp_group.is_first_rank: + continue + # Skip loading norm if not last rank in pipeline parallelism + if ".norm." in name and not self.pp_group.is_last_rank: + continue + if fuse_qkv_a_proj and ( + "q_a_proj" in name or "kv_a_proj_with_mqa" in name + ): + cached_a_proj[name] = loaded_weight + q_a_proj_name = ( + name + if "q_a_proj" in name + else name.replace("kv_a_proj_with_mqa", "q_a_proj") + ) + kv_a_proj_name = ( + name + if "kv_a_proj_with_mqa" in name + else name.replace("q_a_proj", "kv_a_proj_with_mqa") + ) + + # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter + if ( + q_a_proj_name in cached_a_proj + and kv_a_proj_name in cached_a_proj + ): + q_a_proj_weight = cached_a_proj[q_a_proj_name] + kv_a_proj_weight = cached_a_proj[kv_a_proj_name] + + if q_a_proj_weight.shape == torch.Size( + [] + ) and kv_a_proj_weight.shape == torch.Size([]): + fused_weight = q_a_proj_weight + else: + cat_dim = 0 + if self.quant_config is not None and ( + self.quant_config.get_name() == "awq" + or self.quant_config.get_name() == "awq_marlin" + or self.quant_config.get_name() == "moe_wna16" + ): + cat_dim = 1 + + fused_weight = torch.cat( + [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim + ) + + param_name = ( + name.replace( + "q_a_proj", "fused_qkv_a_proj_with_mqa" + ) + if "q_a_proj" in name + else name.replace( + "kv_a_proj_with_mqa", + "fused_qkv_a_proj_with_mqa", + ) + ) + param = params_dict[param_name] + + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + maybe_executor_submit( + executor=executor, + futures=futures, + use_async=use_async_loading, + func=weight_loader, + func_args=(param, fused_weight), + ) + cached_a_proj.pop(q_a_proj_name) + cached_a_proj.pop(kv_a_proj_name) + else: + if ( + "k_scale" in name or "v_scale" in name + ) and name not in params_dict: + # modelopt attn kv scale is named differently + for scale in ["k_scale", "v_scale"]: + if scale in name: + name = name.replace( + f"{scale[0]}_proj", "attn_mqa" + ) + break + if name not in params_dict: + # modelopt ckpt contains not needed weights for MTP module: + # model.decoder.self_attn.attn_mqa.v_scale and + # model.decoder.self_attn.attn_mqa.k_scale + logger.warning(f"{name} not found in params_dict.") + continue + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + maybe_executor_submit( + executor=executor, + futures=futures, + use_async=use_async_loading, + func=weight_loader, + func_args=(param, loaded_weight), + ) + + # Wait for all tasks to complete and raise any exceptions. + for future in concurrent.futures.as_completed(futures): + future.result() + + self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names) + + def post_load_weights( + self: nn.Module, + is_nextn: bool = False, + weight_names: Optional[Iterable[str]] = None, + ) -> None: + """Post-process weights after loading. + + Handles kv_b_proj weight processing including: + - AWQ dequantization + - FP8/INT8 requantization and block-wise to tensor-wise conversion + - Splitting weights into w_kc and w_vc components for MLA + + Args: + is_nextn: Whether processing NextN weights + weight_names: Optional list of loaded weight names to determine which layers to process + """ + if is_nextn: + layer_ids = [self.config.num_hidden_layers] + else: + if weight_names is None: + layer_ids = range(self.model.start_layer, self.model.end_layer) + else: + layer_ids = set() + for name in weight_names: + if "kv_b_proj" in name: + layer_id = int(name.split(".")[2]) + if layer_id < self.config.num_hidden_layers: + layer_ids.add(layer_id) + + for layer_id in layer_ids: + self_attn = ( + self.model.layers[layer_id].self_attn + if not is_nextn + else self.model.decoder.self_attn + ) + + if hasattr(self_attn.kv_b_proj, "qweight"): + # awq compatible, dequantize the weight if supported + awq_dequantize_f = awq_dequantize_func() + if awq_dequantize_f is not None: + w = awq_dequantize_f( + self_attn.kv_b_proj.qweight, + self_attn.kv_b_proj.scales, + self_attn.kv_b_proj.qzeros, + ).T + else: + raise ValueError( + "AWQ dequantize function is not supported for the current device" + ) + else: + w = self_attn.kv_b_proj.weight + + # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`. + # This may affect the accuracy of fp8 model. + # Fix deepseek v3 blockwise bmm by using deep_gemm + use_deep_gemm_bmm = False + + if w.dtype in ( + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + ): + # For mixed quantization (experts int4, linear fp8), use linear_fp8_config + selected_quant_config = getattr( + self.quant_config, "linear_fp8_config", None + ) + if selected_quant_config is None: + selected_quant_config = self.quant_config + weight_block_size = getattr( + selected_quant_config, "weight_block_size", None + ) + if weight_block_size is not None: + assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") or hasattr( + self_attn.kv_b_proj, "weight_scale" + ) + weight_scale = ( + self_attn.kv_b_proj.weight_scale + if hasattr(self_attn.kv_b_proj, "weight_scale") + else self_attn.kv_b_proj.weight_scale_inv + ) + if _is_fp8_fnuz: + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=w, + weight_scale=weight_scale, + input_scale=None, + ) + else: + weight = w + + # In multiple weight loading scenarios (e.g. RL), we need to inverse the scale of the weights after the requantization happened at the first loading. + if ( + should_deepgemm_weight_requant_ue8m0( + weight_block_size=getattr( + self.quant_config, "weight_block_size", None + ) + ) + and weight_scale.format_ue8m0 + ): + weight_scale = inverse_transform_scale_ue8m0( + weight_scale, mn=weight.shape[-2] + ) + + if ( + _is_cuda + and weight_block_size[0] == 128 + and weight_block_size[1] == 128 + ): + if ( + deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM + and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL + and get_bool_env_var("SGL_USE_DEEPGEMM_BMM", "false") + ): + block_scale = weight_scale + use_deep_gemm_bmm = True + else: + w = block_quant_dequant( + weight, + weight_scale, + weight_block_size, + torch.bfloat16, + ) + else: + w, scale = block_quant_to_tensor_quant( + weight, weight_scale, weight_block_size + ) + self_attn.w_scale = scale + else: + if _is_fp8_fnuz: + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=w, + weight_scale=self_attn.kv_b_proj.weight_scale, + input_scale=None, + ) + else: + weight = w + weight_scale = self_attn.kv_b_proj.weight_scale + + w, scale = channel_quant_to_tensor_quant(weight, weight_scale) + self_attn.w_scale = scale + + if w.dtype == torch.int8: + if hasattr(self.quant_config, "weight_block_size"): + # block-wise int8 need it + weight_block_size = self.quant_config.weight_block_size + if weight_block_size is not None: + assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") + weight = w + weight_scale = self_attn.kv_b_proj.weight_scale_inv + w = int8_block_dequant( + weight, weight_scale, weight_block_size + ).to(torch.bfloat16) + else: + # channel-wise int8 need it + w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to( + torch.bfloat16 + ) + + w_kc, w_vc = w.unflatten( + 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim) + ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) + + if ( + _use_aiter_gfx95 + and self.quant_config is not None + and self.quant_config.get_name() == "quark" + ): + w_kc, self_attn.w_scale_k, w_vc, self_attn.w_scale_v = ( + quark_post_load_weights(self_attn, w, "mxfp4") + ) + + if not use_deep_gemm_bmm: + self_attn.w_kc = bind_or_assign( + self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2) + ) + w_vc = w_vc.contiguous().transpose(1, 2) + if _is_npu: + w_vc = w_vc.contiguous() + self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc) + if ( + hasattr(self_attn.kv_b_proj, "weight_scale") + and self_attn.w_scale is None + ): + self_attn.w_scale = bind_or_assign( + self_attn.w_scale, self_attn.kv_b_proj.weight_scale + ) + if _is_hip: + self_attn.w_scale *= 2.0 + # TODO: remove this after adding FP8 support in bmm cpu kernel + if _is_cpu and _is_cpu_amx_available and w.dtype == torch.float8_e4m3fn: + self_attn.w_kc = ( + self_attn.w_kc.to(torch.bfloat16) * self_attn.w_scale + ) + self_attn.w_vc = ( + self_attn.w_vc.to(torch.bfloat16) * self_attn.w_scale + ) + else: + num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1] + num_tiles_n = self_attn.v_head_dim // weight_block_size[0] + ws_kc, ws_vc = block_scale.unflatten( + 0, (-1, (num_tiles_k + num_tiles_n)) + ).split([num_tiles_k, num_tiles_n], dim=1) + self_attn.w_scale_k = bind_or_assign( + self_attn.w_scale_k, ws_kc.transpose(1, 2).contiguous() + ) + self_attn.w_scale_v = bind_or_assign( + self_attn.w_scale_v, ws_vc.contiguous() + ) + self_attn.w_kc = bind_or_assign( + self_attn.w_kc, w_kc.transpose(1, 2).contiguous() + ) + self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous()) + self_attn.use_deep_gemm_bmm = True + + def _maybe_quant_weights_to_fp8_ue8m0( + self, weights, attn_quant_modules, is_nextn=False + ): + """Optionally quantize weights to FP8 UE8M0 format for DeepSeek nvfp4 checkpoints. + + Args: + weights: Iterable of (name, tensor) weight pairs + attn_quant_modules: List of attention module names to quantize + is_nextn: Whether processing NextN weights + + Returns: + List of (name, tensor) pairs with quantized weights + """ + partial_names = [] + nextn_layer_id = ( + 0 if self.config.num_hidden_layers == 1 else self.config.num_hidden_layers + ) + weights_dict = dict(weights) + weight_block_size = [128, 128] + + if envs.SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN.get(): + layer_ids = ( + list(range(self.config.num_hidden_layers)) + if not is_nextn + else [nextn_layer_id] + ) + for layer_id in layer_ids: + for stem in attn_quant_modules: + partial_names.append(f"model.layers.{layer_id}.self_attn.{stem}") + + if is_nextn and enable_nextn_moe_bf16_cast_to_fp8(self.quant_config): + for expert_sub_name in [ + "shared_experts", + *[ + f"experts.{expert_id}" + for expert_id in range(self.config.n_routed_experts) + ], + ]: + for stem in [ + "gate_proj", + "up_proj", + "down_proj", + ]: + partial_names.append( + f"model.layers.{nextn_layer_id}.mlp.{expert_sub_name}.{stem}" + ) + + if len(partial_names) > 0: + for partial_name in tqdm.tqdm( + partial_names, + desc="quant weights to fp8 ue8m0", + ): + original_weight = weights_dict[f"{partial_name}.weight"] + out_w, out_s = quant_weight_ue8m0( + original_weight, weight_block_size=weight_block_size + ) + weights_dict[f"{partial_name}.weight"] = out_w + weights_dict[f"{partial_name}.weight_scale_inv"] = out_s + + if is_nextn and enable_nextn_moe_bf16_cast_to_fp8(self.quant_config): + self._mark_nextn_moe_weights_as_ue8m0() + + return list(weights_dict.items()) + + def _mark_nextn_moe_weights_as_ue8m0(self): + """Mark NextN MoE weight scales as UE8M0 format to avoid requantization.""" + experts = self.model.decoder.mlp.experts + w13_scale = ( + experts.w13_weight_scale_inv + if hasattr(experts, "w13_weight_scale_inv") + else experts.w13_weight_scale + ) + w2_scale = ( + experts.w2_weight_scale_inv + if hasattr(experts, "w2_weight_scale_inv") + else experts.w2_weight_scale + ) + w13_scale.format_ue8m0 = True + w2_scale.format_ue8m0 = True diff --git a/python/sglang/srt/models/deepseek_common/utils.py b/python/sglang/srt/models/deepseek_common/utils.py index 6c78f5683890..514cc025c7f4 100644 --- a/python/sglang/srt/models/deepseek_common/utils.py +++ b/python/sglang/srt/models/deepseek_common/utils.py @@ -1,3 +1,19 @@ +# Copyright 2026 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from sglang.srt.environ import envs +from sglang.srt.layers.moe.fused_moe_triton.layer import get_moe_runner_backend from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz from sglang.srt.utils import ( cpu_has_amx_support, @@ -19,5 +35,41 @@ _is_cpu = is_cpu() _device_sm = get_device_sm() _is_gfx95_supported = is_gfx95_supported() - _use_aiter_gfx95 = _use_aiter and _is_gfx95_supported + + +def awq_dequantize_func(): + """ + Get the AWQ dequantize function for the current device + + Return: + - The AWQ dequantize function for the current device. + - None if the current device is not supported. + """ + if _is_cuda: + from sgl_kernel import awq_dequantize + + return awq_dequantize + elif _is_hip: + from sglang.srt.layers.quantization.awq_triton import ( + awq_dequantize_triton as awq_dequantize, + ) + + return awq_dequantize + elif _is_npu: + from sglang.srt.layers.quantization.awq_triton import ( + awq_dequantize_decomposition as awq_dequantize, + ) + + return awq_dequantize + else: + return None + + +def enable_nextn_moe_bf16_cast_to_fp8(quant_config): + return ( + envs.SGLANG_NVFP4_CKPT_FP8_NEXTN_MOE.get() + and quant_config is not None + and quant_config.get_name() == "modelopt_fp4" + and get_moe_runner_backend().is_deep_gemm() + ) diff --git a/python/sglang/srt/models/deepseek_nextn.py b/python/sglang/srt/models/deepseek_nextn.py index 24a4cd1db253..2fa3d2943858 100644 --- a/python/sglang/srt/models/deepseek_nextn.py +++ b/python/sglang/srt/models/deepseek_nextn.py @@ -46,11 +46,8 @@ VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.models.deepseek_v2 import ( - DeepseekV2DecoderLayer, - DeepseekV3ForCausalLM, - enable_nextn_moe_bf16_cast_to_fp8, -) +from sglang.srt.models.deepseek_common.utils import enable_nextn_moe_bf16_cast_to_fp8 +from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda, is_npu diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 2fb7480774d3..d91bc5024c67 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -17,7 +17,6 @@ """Inference-only DeepseekV2 model.""" from __future__ import annotations -import concurrent.futures import logging import os from contextlib import nullcontext @@ -25,7 +24,6 @@ import torch import torch.nn.functional as F -import tqdm from torch import nn from transformers import PretrainedConfig @@ -111,31 +109,14 @@ per_tensor_quant_mla_fp8, per_token_group_quant_mla_deep_gemm_masked_fp8, ) -from sglang.srt.layers.quantization.fp8_utils import ( - block_quant_dequant, - block_quant_to_tensor_quant, - channel_quant_to_tensor_quant, - inverse_transform_scale_ue8m0, - normalize_e4m3fn_to_e4m3fnuz, - quant_weight_ue8m0, -) -from sglang.srt.layers.quantization.int8_utils import ( - block_dequant as int8_block_dequant, -) from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope_wrapper -from sglang.srt.layers.utils import PPMissingLayer, get_layer_id +from sglang.srt.layers.utils import PPMissingLayer from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors -from sglang.srt.model_loader.utils import ( - maybe_executor_submit, - should_async_load, - should_deepgemm_weight_requant_ue8m0, -) -from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.deepseek_common.attention_backend_handler import ( AttentionBackendRegistry, ) @@ -143,12 +124,14 @@ AttnForwardMethod, DeepseekMHAForwardMixin, ) +from sglang.srt.models.deepseek_common.deepseek_weight_loader import ( + DeepseekV2WeightLoaderMixin, +) from sglang.srt.models.deepseek_common.utils import ( _device_sm, _is_cpu, _is_cpu_amx_available, _is_cuda, - _is_fp8_fnuz, _is_gfx95_supported, _is_hip, _is_npu, @@ -161,7 +144,6 @@ BumpAllocator, LazyValue, add_prefix, - bind_or_assign, get_bool_env_var, is_non_idle_and_non_empty, is_nvidia_cublas_cu12_version_ge_12_9, @@ -180,7 +162,6 @@ fused_rms_fp8_group_quant, ) - from sglang.srt.layers.quantization.quark.utils import quark_post_load_weights from sglang.srt.layers.quantization.rocm_mxfp4_utils import ( batched_gemm_afp4wfp4_pre_quant, fused_flatten_mxfp4_quant, @@ -193,16 +174,13 @@ ) if _is_cuda: - from sgl_kernel import awq_dequantize, bmm_fp8, dsv3_fused_a_gemm, dsv3_router_gemm + from sgl_kernel import bmm_fp8, dsv3_fused_a_gemm, dsv3_router_gemm elif _is_cpu and _is_cpu_amx_available: pass elif _is_hip: from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import ( decode_attention_fwd_grouped_rope, ) - from sglang.srt.layers.quantization.awq_triton import ( - awq_dequantize_triton as awq_dequantize, - ) elif _is_npu: from sglang.srt.hardware_backend.npu.modules.deepseek_v2_attention_mla_npu import ( forward_dsa_core_npu, @@ -212,9 +190,6 @@ forward_mla_core_npu, forward_mla_prepare_npu, ) - from sglang.srt.layers.quantization.awq_triton import ( - awq_dequantize_decomposition as awq_dequantize, - ) else: pass @@ -223,19 +198,6 @@ logger = logging.getLogger(__name__) -# Optional quantization for DeepSeek nvfp4 checkpoint -NVFP4_CKPT_FP8_ATTN_QUANT_MODULES = ["q_b_proj"] - - -def enable_nextn_moe_bf16_cast_to_fp8(quant_config): - return ( - envs.SGLANG_NVFP4_CKPT_FP8_NEXTN_MOE.get() - and quant_config is not None - and quant_config.get_name() == "modelopt_fp4" - and get_moe_runner_backend().is_deep_gemm() - ) - - FORWARD_ABSORB_CORE_ATTENTION_BACKENDS = [ "fa3", "nsa", @@ -2750,7 +2712,7 @@ def forward( return hidden_states, aux_hidden_states -class DeepseekV2ForCausalLM(nn.Module): +class DeepseekV2ForCausalLM(nn.Module, DeepseekV2WeightLoaderMixin): # for quark model load packed_modules_mapping = {} @@ -2761,6 +2723,7 @@ def __init__( prefix: str = "", ) -> None: super().__init__() + # for quark model load # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None self.fuse_qkv_a_proj = ( @@ -2781,6 +2744,7 @@ def __init__( self.model = DeepseekV2Model( config, quant_config, prefix=add_prefix("model", prefix) ) + if self.pp_group.is_last_rank: if self.pp_group.world_size == 1 and config.tie_word_embeddings: self.lm_head = self.model.embed_tokens @@ -2906,486 +2870,8 @@ def start_layer(self): def end_layer(self): return self.model.end_layer - def post_load_weights(self, is_nextn=False, weight_names=None): - - # Perform post-processing after loading weights - if is_nextn: - layer_ids = [self.config.num_hidden_layers] - else: - if weight_names is None: - layer_ids = range(self.model.start_layer, self.model.end_layer) - else: - layer_ids = set() - for name in weight_names: - if "kv_b_proj" in name: - layer_id = int(name.split(".")[2]) - if layer_id < self.config.num_hidden_layers: - layer_ids.add(layer_id) - - for layer_id in layer_ids: - self_attn = ( - self.model.layers[layer_id].self_attn - if not is_nextn - else self.model.decoder.self_attn - ) - if hasattr(self_attn.kv_b_proj, "qweight"): - # AWQ compatible - if _is_cuda or _is_hip or _is_npu: - w = awq_dequantize( - self_attn.kv_b_proj.qweight, - self_attn.kv_b_proj.scales, - self_attn.kv_b_proj.qzeros, - ).T - else: - w = awq_dequantize( - self_attn.kv_b_proj.qweight, - self_attn.kv_b_proj.scales, - self_attn.kv_b_proj.qzeros, - 0, - 0, - 0, - ).T - else: - w = self_attn.kv_b_proj.weight - # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`. - # This may affect the accuracy of fp8 model. - # Fix deepseek v3 blockwise bmm by using deep_gemm - use_deep_gemm_bmm = False - - if w.dtype in ( - torch.float8_e4m3fn, - torch.float8_e4m3fnuz, - ): - # For mixed quantization (experts int4, linear fp8), use linear_fp8_config - selected_quant_config = getattr( - self.quant_config, "linear_fp8_config", None - ) - if selected_quant_config is None: - selected_quant_config = self.quant_config - weight_block_size = getattr( - selected_quant_config, "weight_block_size", None - ) - if weight_block_size is not None: - assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") or hasattr( - self_attn.kv_b_proj, "weight_scale" - ) - weight_scale = ( - self_attn.kv_b_proj.weight_scale - if hasattr(self_attn.kv_b_proj, "weight_scale") - else self_attn.kv_b_proj.weight_scale_inv - ) - if _is_fp8_fnuz: - weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( - weight=w, - weight_scale=weight_scale, - input_scale=None, - ) - else: - weight = w - - # In multiple weight loading scenarios (e.g. RL), we need to inverse the scale of the weights after the requantization happened at the first loading. - if ( - should_deepgemm_weight_requant_ue8m0( - weight_block_size=getattr( - self.quant_config, "weight_block_size", None - ) - ) - and weight_scale.format_ue8m0 - ): - weight_scale = inverse_transform_scale_ue8m0( - weight_scale, mn=weight.shape[-2] - ) - - if ( - _is_cuda - and weight_block_size[0] == 128 - and weight_block_size[1] == 128 - ): - if ( - deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM - and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL - and get_bool_env_var("SGL_USE_DEEPGEMM_BMM", "false") - ): - block_scale = weight_scale - use_deep_gemm_bmm = True - else: - w = block_quant_dequant( - weight, - weight_scale, - weight_block_size, - torch.bfloat16, - ) - else: - w, scale = block_quant_to_tensor_quant( - weight, weight_scale, weight_block_size - ) - self_attn.w_scale = scale - else: - if _is_fp8_fnuz: - weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( - weight=w, - weight_scale=self_attn.kv_b_proj.weight_scale, - input_scale=None, - ) - else: - weight = w - weight_scale = self_attn.kv_b_proj.weight_scale - - w, scale = channel_quant_to_tensor_quant(weight, weight_scale) - self_attn.w_scale = scale - - if w.dtype == torch.int8: - if hasattr(self.quant_config, "weight_block_size"): - # block-wise int8 need it - weight_block_size = self.quant_config.weight_block_size - if weight_block_size is not None: - assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") - weight = w - weight_scale = self_attn.kv_b_proj.weight_scale_inv - w = int8_block_dequant( - weight, weight_scale, weight_block_size - ).to(torch.bfloat16) - else: - # channel-wise int8 need it - w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to( - torch.bfloat16 - ) - - w_kc, w_vc = w.unflatten( - 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim) - ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) - - if ( - _use_aiter_gfx95 - and self.quant_config is not None - and self.quant_config.get_name() == "quark" - ): - w_kc, self_attn.w_scale_k, w_vc, self_attn.w_scale_v = ( - quark_post_load_weights(self_attn, w, "mxfp4") - ) - - if not use_deep_gemm_bmm: - self_attn.w_kc = bind_or_assign( - self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2) - ) - w_vc = w_vc.contiguous().transpose(1, 2) - if _is_npu: - w_vc = w_vc.contiguous() - self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc) - if ( - hasattr(self_attn.kv_b_proj, "weight_scale") - and self_attn.w_scale is None - ): - self_attn.w_scale = bind_or_assign( - self_attn.w_scale, self_attn.kv_b_proj.weight_scale - ) - if _is_hip: - self_attn.w_scale *= 2.0 - # TODO: remove this after adding FP8 support in bmm cpu kernel - if _is_cpu and _is_cpu_amx_available and w.dtype == torch.float8_e4m3fn: - self_attn.w_kc = ( - self_attn.w_kc.to(torch.bfloat16) * self_attn.w_scale - ) - self_attn.w_vc = ( - self_attn.w_vc.to(torch.bfloat16) * self_attn.w_scale - ) - else: - num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1] - num_tiles_n = self_attn.v_head_dim // weight_block_size[0] - ws_kc, ws_vc = block_scale.unflatten( - 0, (-1, (num_tiles_k + num_tiles_n)) - ).split([num_tiles_k, num_tiles_n], dim=1) - self_attn.w_scale_k = bind_or_assign( - self_attn.w_scale_k, ws_kc.transpose(1, 2).contiguous() - ) - self_attn.w_scale_v = bind_or_assign( - self_attn.w_scale_v, ws_vc.contiguous() - ) - self_attn.w_kc = bind_or_assign( - self_attn.w_kc, w_kc.transpose(1, 2).contiguous() - ) - self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous()) - self_attn.use_deep_gemm_bmm = True - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False): - - if is_nextn: - if hasattr(self.config, "num_nextn_predict_layers"): - num_nextn_layers = self.config.num_nextn_predict_layers - assert num_nextn_layers == 1, "Only 1 nextn layer is supported" - # compatible with old design - nextn_layer_id = ( - 0 - if self.config.num_hidden_layers == 1 - else self.config.num_hidden_layers - ) - else: - raise ValueError("num_nextn_predict_layers is not in the config") - - weights = self._maybe_quant_weights_to_fp8_ue8m0( - weights, NVFP4_CKPT_FP8_ATTN_QUANT_MODULES, is_nextn - ) - - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - - # Params for weights, fp8 weight scales, fp8 activation scales - # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = FusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=self.config.n_routed_experts + self.num_fused_shared_experts, - ) - # Params for special naming rules in mixed-precision models, for example: - # model.layers.xx.mlp.experts.xx.w1.input_scale. For details, - # see https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/blob/main. - if self.quant_config and self.quant_config.get_name() == "w4afp8": - expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping( - num_experts=self.config.n_routed_experts - ) - - # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None - fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and ( - self.config.q_lora_rank is not None - ) - cached_a_proj = {} if fuse_qkv_a_proj else None - - if is_nextn: - nextn_layer_prefix = f"model.layers.{nextn_layer_id}" - nextn_spec_weight_names = [ - "shared_head.norm", - "eh_proj", - "enorm", - "hnorm", - ] - - if self.num_fused_shared_experts > 0: - assert self.num_fused_shared_experts == 1 - log_info_on_rank0(logger, "Shared experts fusion optimization enabled.") - - with concurrent.futures.ThreadPoolExecutor() as executor: - futures = [] - params_dict = dict(self.named_parameters()) - weight_names = [] - for name, loaded_weight in weights: - use_async_loading = should_async_load(loaded_weight) - layer_id = get_layer_id(name) - if ( - layer_id is not None - and hasattr(self.model, "start_layer") - and ( - layer_id < self.model.start_layer - or layer_id >= self.model.end_layer - ) - ): - continue - if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name: - name = name.replace( - "mlp.shared_experts", - f"mlp.experts.{self.config.n_routed_experts}", - ) - - weight_names.append(name) - - if not is_nextn: - if hasattr(self.config, "num_nextn_predict_layers"): - num_nextn_layers = self.config.num_nextn_predict_layers - if num_nextn_layers > 0 and name.startswith("model.layers"): - name_list = name.split(".") - if ( - len(name_list) >= 3 - and int(name_list[2]) >= self.config.num_hidden_layers - ): - continue - else: - if not name.startswith(nextn_layer_prefix): - continue - - # Use shared head and embed weights from target model - if "shared_head.head" in name or "embed_tokens" in name: - continue - - is_decoder = True - # For nextn specific weights - for weight_name in nextn_spec_weight_names: - if weight_name in name: - name = name.replace(nextn_layer_prefix, "model") - is_decoder = False - break - # For decoder layer weights - if is_decoder: - name = name.replace(nextn_layer_prefix, "model.decoder") - - if "rotary_emb.inv_freq" in name: - continue - for param_name, weight_name, shard_id in stacked_params_mapping: - # Skip non-stacked layers and experts (experts handled below). - if weight_name not in name: - continue - if _is_npu: - name = name.replace("weight_packed", "weight") - # We have mlp.experts[0].gate_proj in the checkpoint. - # Since we handle the experts below in expert_params_mapping, - # we need to skip here BEFORE we update the name, otherwise - # name will be updated to mlp.experts[0].gate_up_proj, which - # will then be updated below in expert_params_mapping - # for mlp.experts[0].gate_gate_up_proj, which breaks load. - if ("mlp.experts." in name) and name not in params_dict: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - maybe_executor_submit( - executor=executor, - futures=futures, - use_async=use_async_loading, - func=weight_loader, - func_args=(param, loaded_weight, shard_id), - ) - break - else: - for mapping in expert_params_mapping: - param_name, weight_name, expert_id, shard_id = mapping - if weight_name not in name: - continue - if _is_npu: - name = name.replace("weight_packed", "weight") - name = name.replace(weight_name, param_name) - if name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - maybe_executor_submit( - executor=executor, - futures=futures, - use_async=use_async_loading, - func=weight_loader, - func_args=( - param, - loaded_weight, - name, - ), - func_kwargs={ - "shard_id": shard_id, - "expert_id": expert_id, - }, - ) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Skip loading embed_tokens if not first rank in pipeline parallelism - if ".embed_tokens." in name and not self.pp_group.is_first_rank: - continue - # Skip loading norm if not last rank in pipeline parallelism - if ".norm." in name and not self.pp_group.is_last_rank: - continue - if fuse_qkv_a_proj and ( - "q_a_proj" in name or "kv_a_proj_with_mqa" in name - ): - cached_a_proj[name] = loaded_weight - q_a_proj_name = ( - name - if "q_a_proj" in name - else name.replace("kv_a_proj_with_mqa", "q_a_proj") - ) - kv_a_proj_name = ( - name - if "kv_a_proj_with_mqa" in name - else name.replace("q_a_proj", "kv_a_proj_with_mqa") - ) - - # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter - if ( - q_a_proj_name in cached_a_proj - and kv_a_proj_name in cached_a_proj - ): - q_a_proj_weight = cached_a_proj[q_a_proj_name] - kv_a_proj_weight = cached_a_proj[kv_a_proj_name] - - if q_a_proj_weight.shape == torch.Size( - [] - ) and kv_a_proj_weight.shape == torch.Size([]): - fused_weight = q_a_proj_weight - else: - cat_dim = 0 - if self.quant_config is not None and ( - self.quant_config.get_name() == "awq" - or self.quant_config.get_name() == "awq_marlin" - or self.quant_config.get_name() == "moe_wna16" - ): - cat_dim = 1 - - fused_weight = torch.cat( - [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim - ) - - param_name = ( - name.replace( - "q_a_proj", "fused_qkv_a_proj_with_mqa" - ) - if "q_a_proj" in name - else name.replace( - "kv_a_proj_with_mqa", - "fused_qkv_a_proj_with_mqa", - ) - ) - param = params_dict[param_name] - - weight_loader = getattr( - param, "weight_loader", default_weight_loader - ) - maybe_executor_submit( - executor=executor, - futures=futures, - use_async=use_async_loading, - func=weight_loader, - func_args=(param, fused_weight), - ) - cached_a_proj.pop(q_a_proj_name) - cached_a_proj.pop(kv_a_proj_name) - else: - if ( - "k_scale" in name or "v_scale" in name - ) and name not in params_dict: - # modelopt attn kv scale is named differently - for scale in ["k_scale", "v_scale"]: - if scale in name: - name = name.replace( - f"{scale[0]}_proj", "attn_mqa" - ) - break - if name not in params_dict: - # modelopt ckpt contains not needed weights for MTP module: - # model.decoder.self_attn.attn_mqa.v_scale and - # model.decoder.self_attn.attn_mqa.k_scale - logger.warning(f"{name} not found in params_dict.") - continue - param = params_dict[name] - weight_loader = getattr( - param, "weight_loader", default_weight_loader - ) - maybe_executor_submit( - executor=executor, - futures=futures, - use_async=use_async_loading, - func=weight_loader, - func_args=(param, loaded_weight), - ) - - # Wait for all tasks to complete and raise any exceptions. - for future in concurrent.futures.as_completed(futures): - future.result() - - self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names) + self.do_load_weights(weights, is_nextn) def get_embed_and_head(self): return self.model.embed_tokens.weight, self.lm_head.weight @@ -3420,77 +2906,6 @@ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None): # of the (i-1)th layer as aux hidden state self.model.layers_to_capture = [val + 1 for val in layer_ids] - # Mark the ue8m0 flag of nextn moe weights as True to avoid requantization - def _mark_nextn_moe_weights_as_ue8m0(self): - experts = self.model.decoder.mlp.experts - w13_scale = ( - experts.w13_weight_scale_inv - if hasattr(experts, "w13_weight_scale_inv") - else experts.w13_weight_scale - ) - w2_scale = ( - experts.w2_weight_scale_inv - if hasattr(experts, "w2_weight_scale_inv") - else experts.w2_weight_scale - ) - w13_scale.format_ue8m0 = True - w2_scale.format_ue8m0 = True - - def _maybe_quant_weights_to_fp8_ue8m0( - self, weights, attn_quant_modules, is_nextn=False - ): - # Quantize some weights to fp8 ue8m0 for DeepSeek nvfp4 checkpoint - partial_names = [] - nextn_layer_id = ( - 0 if self.config.num_hidden_layers == 1 else self.config.num_hidden_layers - ) - weights_dict = dict(weights) - weight_block_size = [128, 128] - - if envs.SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN.get(): - layer_ids = ( - list(range(self.config.num_hidden_layers)) - if not is_nextn - else [nextn_layer_id] - ) - for layer_id in layer_ids: - for stem in attn_quant_modules: - partial_names.append(f"model.layers.{layer_id}.self_attn.{stem}") - - if is_nextn and enable_nextn_moe_bf16_cast_to_fp8(self.quant_config): - for expert_sub_name in [ - "shared_experts", - *[ - f"experts.{expert_id}" - for expert_id in range(self.config.n_routed_experts) - ], - ]: - for stem in [ - "gate_proj", - "up_proj", - "down_proj", - ]: - partial_names.append( - f"model.layers.{nextn_layer_id}.mlp.{expert_sub_name}.{stem}" - ) - - if len(partial_names) > 0: - for partial_name in tqdm.tqdm( - partial_names, - desc="quant weights to fp8 ue8m0", - ): - original_weight = weights_dict[f"{partial_name}.weight"] - out_w, out_s = quant_weight_ue8m0( - original_weight, weight_block_size=weight_block_size - ) - weights_dict[f"{partial_name}.weight"] = out_w - weights_dict[f"{partial_name}.weight_scale_inv"] = out_s - - if is_nextn and enable_nextn_moe_bf16_cast_to_fp8(self.quant_config): - self._mark_nextn_moe_weights_as_ue8m0() - - return list(weights_dict.items()) - class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): pass