diff --git a/python/sglang/srt/disaggregation/common/conn.py b/python/sglang/srt/disaggregation/common/conn.py index 3569e68054c6..ef93e18509b3 100644 --- a/python/sglang/srt/disaggregation/common/conn.py +++ b/python/sglang/srt/disaggregation/common/conn.py @@ -304,12 +304,13 @@ def _connect(self, endpoint: str, is_ipv6: bool = False): return socket def get_mha_kv_ptrs_with_pp( - self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int] + self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int], dst_non_draft_kv_data_lens: int, ) -> Tuple[List[int], List[int], List[int], List[int], int]: start_layer = self.kv_args.prefill_start_layer num_kv_layers = len(src_kv_ptrs) // 2 end_layer = start_layer + num_kv_layers dst_num_total_layers = len(dst_kv_ptrs) // 2 + dst_non_draft_num_total_layers = dst_non_draft_kv_data_lens // 2 src_k_ptrs = src_kv_ptrs[:num_kv_layers] src_v_ptrs = src_kv_ptrs[num_kv_layers:] if num_kv_layers == dst_num_total_layers: @@ -318,6 +319,7 @@ def get_mha_kv_ptrs_with_pp( elif ( num_kv_layers < dst_num_total_layers and dst_num_total_layers % num_kv_layers != 0 + and self.kv_args.prefill_pp_size == 1 ): # Case: Decode has draft model KV while Prefill is deployed without speculative decoding # dst_kv_ptrs layout: [K_main..., V_main..., draft_K..., draft_V...] @@ -331,7 +333,7 @@ def get_mha_kv_ptrs_with_pp( # Decode pp size should be equal to prefill pp size or 1 dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer] dst_v_ptrs = dst_kv_ptrs[ - dst_num_total_layers + start_layer : dst_num_total_layers + end_layer + dst_non_draft_num_total_layers + start_layer : dst_non_draft_num_total_layers + end_layer ] layers_current_pp_stage = len(src_k_ptrs) return src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 7f2de88cb7fd..d2afa5113ad3 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -297,6 +297,7 @@ def _init_kv_manager(self) -> CommonKVManager: kv_data_ptrs, kv_data_lens, kv_item_lens = ( self.token_to_kv_pool.get_contiguous_buf_infos() ) + kv_args.non_draft_kv_data_lens = len(kv_data_ptrs) if self.draft_token_to_kv_pool is not None: # We should also transfer draft model kv cache. The indices are # always shared with a target model. diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index a223a2578ea2..91e6e7266d7c 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -107,6 +107,7 @@ class KVArgsRegisterInfo: dst_port: int mooncake_session_id: str dst_kv_ptrs: list[int] + dst_non_draft_kv_data_lens: int dst_aux_ptrs: list[int] dst_state_data_ptrs: list[int] dst_tp_rank: int @@ -139,6 +140,7 @@ def from_zmq(cls, msg: List[bytes]): if len(msg) > 11 and len(msg[11]) > 0 else [] ), + dst_non_draft_kv_data_lens=int(msg[12].decode("ascii")), ) @@ -247,6 +249,7 @@ def _send_kvcache_generic( mooncake_session_id: str, src_data_ptrs: list[int], dst_data_ptrs: list[int], + dst_non_draft_kv_data_lens : int, item_lens: list[int], prefill_data_indices: npt.NDArray[np.int32], dst_data_indices: npt.NDArray[np.int32], @@ -278,7 +281,7 @@ def _send_kvcache_generic( ] else: src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = ( - self.get_mha_kv_ptrs_with_pp(src_data_ptrs, dst_data_ptrs) + self.get_mha_kv_ptrs_with_pp(src_data_ptrs, dst_data_ptrs, dst_non_draft_kv_data_lens) ) # item_lens structure: [k_layer0, k_layer1, ..., k_layerN, v_layer0, v_layer1, ..., v_layerN] # Use correct item lengths for K and V separately @@ -355,6 +358,7 @@ def send_kvcache( mooncake_session_id: str, prefill_kv_indices: npt.NDArray[np.int32], dst_kv_ptrs: list[int], + dst_non_draft_kv_data_lens : int, dst_kv_indices: npt.NDArray[np.int32], executor: concurrent.futures.ThreadPoolExecutor, ): @@ -362,6 +366,7 @@ def send_kvcache( mooncake_session_id=mooncake_session_id, src_data_ptrs=self.kv_args.kv_data_ptrs, dst_data_ptrs=dst_kv_ptrs, + dst_non_draft_kv_data_lens=dst_non_draft_kv_data_lens, item_lens=self.kv_args.kv_item_lens, prefill_data_indices=prefill_kv_indices, dst_data_indices=dst_kv_indices, @@ -373,6 +378,7 @@ def send_kvcache_slice( mooncake_session_id: str, prefill_kv_indices: npt.NDArray[np.int32], dst_kv_ptrs: list[int], + dst_non_draft_kv_data_lens : int, dst_kv_indices: npt.NDArray[np.int32], dst_tp_rank: int, dst_attn_tp_size: int, @@ -426,7 +432,7 @@ def send_kvcache_slice( dst_head_start_offset = 0 src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = ( - self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs) + self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs, dst_non_draft_kv_data_lens) ) # Calculate precise byte offset and length for the sub-slice within the token @@ -589,6 +595,7 @@ def maybe_send_extra( req: TransferInfo, prefill_state_indices: list[int], dst_state_data_ptrs: list[int], + dst_non_draft_kv_data_lens: int, executor: concurrent.futures.ThreadPoolExecutor, target_rank_registration_info: Optional[KVArgsRegisterInfo] = None, ): @@ -597,6 +604,41 @@ def maybe_send_extra( if state_type == "mamba": # Check if we need slice transfer for different TP sizes + prefill_state_data_ptrs = self.kv_args.state_data_ptrs + prefill_state_item_lens = self.kv_args.state_item_lens + src_state_dim_per_tensor = getattr(self.kv_args, "state_dim_per_tensor", []) + dst_state_item_lens = ( + target_rank_registration_info.dst_state_item_lens + if target_rank_registration_info is not None + else [] + ) + dst_state_dim_per_tensor = ( + target_rank_registration_info.dst_state_dim_per_tensor + if target_rank_registration_info is not None + else [] + ) + mamba_layer_ids = self.kv_args.total_mamba_layer_ids + layer_indices = self.kv_args.mamba_layer_ids + total_layers = len(mamba_layer_ids) + num_tensors = len(prefill_state_data_ptrs) // total_layers + if num_tensors * total_layers == len(prefill_state_data_ptrs): + indices = [ + base + idx + for base in range(0, total_layers * num_tensors, total_layers) + for idx in layer_indices + ] + def slice_list(values): + if not values: + return [] + return [values[i] for i in indices] + + prefill_state_data_ptrs = slice_list(prefill_state_data_ptrs) + prefill_state_item_lens = slice_list(prefill_state_item_lens) + src_state_dim_per_tensor = slice_list(src_state_dim_per_tensor) + dst_state_data_ptrs = slice_list(dst_state_data_ptrs) + dst_state_item_lens = slice_list(dst_state_item_lens) + dst_state_dim_per_tensor = slice_list(dst_state_dim_per_tensor) + if ( target_rank_registration_info is not None and self.attn_tp_size != target_rank_registration_info.dst_attn_tp_size @@ -605,16 +647,21 @@ def maybe_send_extra( req, prefill_state_indices, dst_state_data_ptrs, - target_rank_registration_info.dst_state_item_lens, - target_rank_registration_info.dst_state_dim_per_tensor, + dst_state_item_lens, + dst_state_dim_per_tensor, target_rank_registration_info.dst_tp_rank, target_rank_registration_info.dst_attn_tp_size, + prefill_state_data_ptrs=prefill_state_data_ptrs, + prefill_state_item_lens=prefill_state_item_lens, + src_state_dim_per_tensor=src_state_dim_per_tensor, ) else: return self._send_mamba_state( req, prefill_state_indices, dst_state_data_ptrs, + prefill_state_data_ptrs=prefill_state_data_ptrs, + prefill_state_item_lens=prefill_state_item_lens, ) elif state_type in ["swa", "nsa"]: # SWA and NSA hybrid models do not support different TP sizes yet @@ -640,6 +687,7 @@ def maybe_send_extra( mooncake_session_id=req.mooncake_session_id, src_data_ptrs=self.kv_args.state_data_ptrs, dst_data_ptrs=dst_state_data_ptrs, + dst_non_draft_kv_data_lens=dst_non_draft_kv_data_lens, item_lens=self.kv_args.state_item_lens, prefill_data_indices=prefill_state_indices, dst_data_indices=dst_state_indices, @@ -653,13 +701,17 @@ def _send_mamba_state( req: TransferInfo, prefill_mamba_index: list[int], dst_state_data_ptrs: list[int], + prefill_state_data_ptrs: Optional[list[int]] = None, + prefill_state_item_lens: Optional[list[int]] = None, ): """Transfer Mamba states.""" assert len(prefill_mamba_index) == 1, "Mamba should have single state index" transfer_blocks = [] - prefill_state_data_ptrs = self.kv_args.state_data_ptrs - prefill_state_item_lens = self.kv_args.state_item_lens + if prefill_state_data_ptrs is None: + prefill_state_data_ptrs = self.kv_args.state_data_ptrs + if prefill_state_item_lens is None: + prefill_state_item_lens = self.kv_args.state_item_lens for i, dst_state_ptr in enumerate(dst_state_data_ptrs): length = prefill_state_item_lens[i] @@ -678,6 +730,9 @@ def _send_mamba_state_slice( dst_state_dim_per_tensor: list[int], dst_tp_rank: int, dst_attn_tp_size: int, + prefill_state_data_ptrs: Optional[list[int]] = None, + prefill_state_item_lens: Optional[list[int]] = None, + src_state_dim_per_tensor: Optional[list[int]] = None, ): """Transfer Mamba states with TP slice support. @@ -696,9 +751,12 @@ def _send_mamba_state_slice( assert len(prefill_mamba_index) == 1, "Mamba should have single state index" transfer_blocks = [] - prefill_state_data_ptrs = self.kv_args.state_data_ptrs - prefill_state_item_lens = self.kv_args.state_item_lens - src_state_dim_per_tensor = getattr(self.kv_args, "state_dim_per_tensor", []) + if prefill_state_data_ptrs is None: + prefill_state_data_ptrs = self.kv_args.state_data_ptrs + if prefill_state_item_lens is None: + prefill_state_item_lens = self.kv_args.state_item_lens + if src_state_dim_per_tensor is None: + src_state_dim_per_tensor = getattr(self.kv_args, "state_dim_per_tensor", []) # If no dimension info available, fall back to regular transfer if not src_state_dim_per_tensor or not dst_state_dim_per_tensor: @@ -824,6 +882,7 @@ def transfer_worker( req.mooncake_session_id, kv_chunk.prefill_kv_indices, target_rank_registration_info.dst_kv_ptrs, + target_rank_registration_info.dst_non_draft_kv_data_lens, chunked_dst_kv_indice, executor, ) @@ -832,6 +891,7 @@ def transfer_worker( req.mooncake_session_id, kv_chunk.prefill_kv_indices, target_rank_registration_info.dst_kv_ptrs, + target_rank_registration_info.dst_non_draft_kv_data_lens, chunked_dst_kv_indice, target_rank_registration_info.dst_tp_rank, target_rank_registration_info.dst_attn_tp_size, @@ -867,6 +927,7 @@ def transfer_worker( req, kv_chunk.state_indices, target_rank_registration_info.dst_state_data_ptrs, + target_rank_registration_info.dst_non_draft_kv_data_lens, executor, target_rank_registration_info, ) @@ -1256,6 +1317,7 @@ def _register_kv_args(self): dst_tp_rank = str(tp_rank).encode("ascii") dst_attn_tp_size = str(self.kv_mgr.attn_tp_size).encode("ascii") dst_kv_item_len = str(kv_item_len).encode("ascii") + non_draft_kv_data_lens = str(self.kv_mgr.kv_args.non_draft_kv_data_lens).encode("ascii") sock, lock = self._connect_to_bootstrap_server(bootstrap_info) with lock: @@ -1273,6 +1335,7 @@ def _register_kv_args(self): dst_kv_item_len, packed_state_item_lens, packed_state_dim_per_tensor, + non_draft_kv_data_lens, ] ) diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index e8595ad6c469..629cd079108a 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -138,6 +138,8 @@ def _init_kv_manager(self) -> CommonKVManager: kv_args.decode_tp_size = self.decode_tp_size // self.decode_dp_size kv_args.prefill_pp_size = self.pp_size kv_args.prefill_start_layer = self.token_to_kv_pool.start_layer + kv_args.total_mamba_layer_ids = self.token_to_kv_pool.total_mamba_layer_ids + kv_args.mamba_layer_ids = self.token_to_kv_pool.mamba_layer_ids kv_data_ptrs, kv_data_lens, kv_item_lens = ( self.token_to_kv_pool.get_contiguous_buf_infos() ) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=512,N=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=512,N=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..5436cca64206 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=512,N=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 11c45ff1cc8f..0170b2338e93 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -1223,6 +1223,8 @@ def __init__( head_num: int, head_dim: int, full_attention_layer_ids: List[int], + total_mamba_layer_ids: List[int], + mamba_layer_ids: List[int], enable_kvcache_transpose: bool, device: str, mamba_pool: MambaPool, @@ -1231,14 +1233,18 @@ def __init__( use_mla: bool = False, kv_lora_rank: int = None, qk_rope_head_dim: int = None, + start_layer: int = None, + end_layer: int = None, ): self.size = size self.dtype = dtype self.device = device self.full_layer_nums = len(full_attention_layer_ids) + self.total_mamba_layer_ids = total_mamba_layer_ids + self.mamba_layer_ids = mamba_layer_ids self.page_size = page_size - # TODO support pp? - self.start_layer = 0 + self.start_layer = start_layer + self.end_layer = end_layer self.head_num = head_num self.head_dim = head_dim self.mamba_pool = mamba_pool diff --git a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py index 2af60158f7cd..a9b9a3c3dac7 100644 --- a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py +++ b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING import torch +import bisect from sglang.srt.configs.model_config import get_nsa_index_head_dim, is_deepseek_nsa from sglang.srt.distributed.parallel_state import get_world_group @@ -618,6 +619,7 @@ def init_memory_pool(self: ModelRunner, total_gpu_memory: int): "kv_lora_rank": self.model_config.kv_lora_rank, "qk_rope_head_dim": self.model_config.qk_rope_head_dim, } + total_mamba_layer_ids = list(config.mamba2_cache_params.layers) self.token_to_kv_pool = HybridLinearKVPool( page_size=self.page_size, size=self.max_total_num_tokens, @@ -636,11 +638,23 @@ def init_memory_pool(self: ModelRunner, total_gpu_memory: int): if self.start_layer <= i < self.end_layer ] ), + total_mamba_layer_ids=total_mamba_layer_ids, + mamba_layer_ids=( + [] + if self.is_draft_worker + else [ + i + for i, layer_id in enumerate(total_mamba_layer_ids) + if self.start_layer <= layer_id < self.end_layer + ] + ), enable_kvcache_transpose=False, device=self.device, mamba_pool=self.req_to_token_pool.mamba_pool, enable_memory_saver=self.server_args.enable_memory_saver, use_mla=self.use_mla_backend, + start_layer=bisect.bisect_left(config.full_attention_layer_ids, self.start_layer), + end_layer=bisect.bisect_left(config.full_attention_layer_ids, self.end_layer), **extra_args, ) else: diff --git a/python/sglang/srt/models/qwen3_5.py b/python/sglang/srt/models/qwen3_5.py index 306adbccb453..14f81dea9e2f 100644 --- a/python/sglang/srt/models/qwen3_5.py +++ b/python/sglang/srt/models/qwen3_5.py @@ -72,6 +72,7 @@ from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration # Utils +from sglang.srt.layers.utils import PPMissingLayer, get_layer_id from sglang.srt.utils import add_prefix, is_cuda, is_npu, make_layers, set_weight_attrs from sglang.srt.utils.hf_transformers_utils import get_processor @@ -680,6 +681,8 @@ def __init__( org_num_embeddings=config.vocab_size, enable_tp=not is_dp_attention_enabled(), ) + else: + self.embed_tokens = PPMissingLayer() # Decoder layers def get_layer(idx: int, prefix: str): @@ -697,15 +700,19 @@ def get_layer(idx: int, prefix: str): alt_stream=alt_stream, ) - self.layers = make_layers( + self.layers, self.start_layer, self.end_layer = make_layers( config.num_hidden_layers, get_layer, + pp_rank=self.pp_group.rank_in_group, + pp_size=self.pp_group.world_size, prefix=f"{prefix}.layers", ) # Final normalization if self.pp_group.is_last_rank: self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer(return_tuple=True) def get_input_embeddings(self) -> nn.Embedding: return self.embed_tokens @@ -733,7 +740,7 @@ def forward( residual = pp_proxy_tensors["residual"] # Pass through decoder layers - for layer_idx in range(len(self.layers)): + for layer_idx in range(self.start_layer, self.end_layer): layer = self.layers[layer_idx] with get_global_expert_distribution_recorder().with_current_layer( layer_idx @@ -797,6 +804,17 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): name = name.replace(r"model.language_model.", r"model.") if ".self_attn." in name: name = name.replace(".self_attn", "") + layer_id = get_layer_id(name) + if ( + layer_id is not None + and hasattr(self, "start_layer") + and (layer_id < self.start_layer or layer_id >= self.end_layer) + ): + continue + if "embed_tokens" in name and not self.pp_group.is_first_rank: + continue + if "lm_head" in name and not self.pp_group.is_last_rank: + continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: @@ -918,6 +936,17 @@ def load_fused_expert_weights( name = name.replace(r"model.language_model.", r"model.") if ".self_attn." in name: name = name.replace(".self_attn", "") + layer_id = get_layer_id(name) + if ( + layer_id is not None + and hasattr(self, "start_layer") + and (layer_id < self.start_layer or layer_id >= self.end_layer) + ): + continue + if "embed_tokens" in name and not self.pp_group.is_first_rank: + continue + if "lm_head" in name and not self.pp_group.is_last_rank: + continue for param_name, weight_name, shard_id in stacked_params_mapping: if "experts.gate_up_proj" in name or "experts.down_proj" in name: @@ -1077,6 +1106,17 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): name = name.replace(r"model.language_model.", r"model.") if ".self_attn." in name: name = name.replace(".self_attn", "") + layer_id = get_layer_id(name) + if ( + layer_id is not None + and hasattr(self.model, "start_layer") + and (layer_id < self.model.start_layer or layer_id >= self.model.end_layer) + ): + continue + if "embed_tokens" in name and not self.pp_group.is_first_rank: + continue + if "lm_head" in name and not self.pp_group.is_last_rank: + continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: @@ -1220,6 +1260,17 @@ def load_fused_expert_weights( name = name.replace(r"model.language_model.", r"model.") if ".self_attn." in name: name = name.replace(".self_attn", "") + layer_id = get_layer_id(name) + if ( + layer_id is not None + and hasattr(self.model, "start_layer") + and (layer_id < self.model.start_layer or layer_id >= self.model.end_layer) + ): + continue + if "embed_tokens" in name and not self.pp_group.is_first_rank: + continue + if "lm_head" in name and not self.pp_group.is_last_rank: + continue for param_name, weight_name, shard_id in stacked_params_mapping: if name.endswith("experts.gate_up_proj") or name.endswith( diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py index 09c443b894c1..e02d9b4b333a 100644 --- a/python/sglang/srt/models/qwen3_vl.py +++ b/python/sglang/srt/models/qwen3_vl.py @@ -1091,6 +1091,19 @@ def separate_deepstack_embeds(self, embedding): input_deepstack_embeds = embedding[:, separate_index:] return input_embeds, input_deepstack_embeds + @property + def start_layer(self) -> int: + return getattr(getattr(self, "model", None), "start_layer", 0) + + @property + def end_layer(self) -> int: + model = getattr(self, "model", None) + end_layer = getattr(model, "end_layer", None) + if end_layer is not None: + return end_layer + cfg = getattr(model, "config", None) + return int(getattr(cfg, "num_hidden_layers", 0)) + def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs): pattern = MultiModalityDataPaddingPatternMultimodalTokens() return pattern.pad_input_tokens(input_ids, mm_inputs)