diff --git a/python/sglang/srt/models/qwen3_5.py b/python/sglang/srt/models/qwen3_5.py
index 45b55fa3bd09..107b73378c72 100644
--- a/python/sglang/srt/models/qwen3_5.py
+++ b/python/sglang/srt/models/qwen3_5.py
@@ -20,11 +20,6 @@
 import torch
 import torch.nn as nn
-import triton
-
-from sglang.jit_kernel.triton.gdn_fused_proj import (
-    fused_qkvzba_split_reshape_cat_contiguous,
-)
 
 # Configs
 from sglang.srt.configs.qwen3_5 import (
@@ -59,10 +54,6 @@
     RowParallelLinear,
 )
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-from sglang.srt.layers.parameter import (
-    BlockQuantScaleParameter,
-    PerTensorScaleParameter,
-)
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.radix_linear_attention import RadixLinearAttention
@@ -79,14 +70,11 @@
 
 # Models
 from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration
-from sglang.srt.server_args import get_global_server_args
 
 # Utils
 from sglang.srt.utils import (
     LazyValue,
     add_prefix,
-    cpu_has_amx_support,
-    is_cpu,
     is_cuda,
     is_npu,
     make_layers,
@@ -97,9 +85,6 @@
 logger = logging.getLogger(__name__)
 
 _is_cuda = is_cuda()
 _is_npu = is_npu()
-_is_cpu = is_cpu()
-_is_amx_available = cpu_has_amx_support()
-
 cached_get_processor = lru_cache(get_processor)
@@ -144,47 +129,63 @@ def __init__(
         )
         self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
 
-        # projection of the input hidden states
-        self.in_proj_qkvz = self.create_qkvz_proj(
-            hidden_size=self.hidden_size,
-            key_dim=self.key_dim,
-            value_dim=self.value_dim,
+        # Split projection layers (following vLLM's implementation)
+        # Instead of fused in_proj_qkvz and in_proj_ba, use separate layers
+        self.in_proj_qkv = MergedColumnParallelLinear(
+            input_size=self.hidden_size,
+            output_sizes=[self.key_dim, self.key_dim, self.value_dim],
+            bias=False,
             quant_config=quant_config,
-            prefix=add_prefix("in_proj_qkvz", prefix),
             tp_rank=self.attn_tp_rank,
             tp_size=self.attn_tp_size,
+            prefix=add_prefix("in_proj_qkv", prefix),
         )
-
-        self.in_proj_ba = self.create_ba_proj(
-            hidden_size=self.hidden_size,
-            num_v_heads=self.num_v_heads,
+        self.in_proj_z = ColumnParallelLinear(
+            input_size=self.hidden_size,
+            output_size=self.value_dim,
+            bias=False,
             quant_config=quant_config,
-            prefix=add_prefix("in_proj_ba", prefix),
             tp_rank=self.attn_tp_rank,
             tp_size=self.attn_tp_size,
+            prefix=add_prefix("in_proj_z", prefix),
+        )
+        self.in_proj_b = ColumnParallelLinear(
+            input_size=self.hidden_size,
+            output_size=self.num_v_heads,
+            bias=False,
+            quant_config=quant_config,
+            tp_rank=self.attn_tp_rank,
+            tp_size=self.attn_tp_size,
+            prefix=add_prefix("in_proj_b", prefix),
+        )
+        self.in_proj_a = ColumnParallelLinear(
+            input_size=self.hidden_size,
+            output_size=self.num_v_heads,
+            bias=False,
+            quant_config=quant_config,
+            tp_rank=self.attn_tp_rank,
+            tp_size=self.attn_tp_size,
+            prefix=add_prefix("in_proj_a", prefix),
         )
-
-        # Override weight loaders for packed checkpoint format.
-        # Important: for FP8, this must cover not only `.weight` but also
-        # `weight_scale_inv` / `weight_scale` / `input_scale` if present.
-        self._bind_packed_weight_loaders(self.in_proj_qkvz)
-        self._bind_packed_weight_loaders(self.in_proj_ba)
 
         # Conv1d weight loader setup
         query_key_settings = (self.key_dim, 0, False)
         value_settings = (self.value_dim, 0, False)
 
-        self._override_weight_loader(
+        delattr(self.conv1d.weight, "weight_loader")
+        set_weight_attrs(
             self.conv1d.weight,
-            mamba_v2_sharded_weight_loader(
-                [
-                    query_key_settings,
-                    query_key_settings,
-                    value_settings,
-                ],
-                self.attn_tp_size,
-                self.attn_tp_rank,
-            ),
+            {
+                "weight_loader": mamba_v2_sharded_weight_loader(
+                    [
+                        query_key_settings,
+                        query_key_settings,
+                        value_settings,
+                    ],
+                    self.attn_tp_size,
+                    self.attn_tp_rank,
+                )
+            },
         )
 
         # State parameters
@@ -201,6 +202,7 @@ def __init__(
         conv_weights = self.conv1d.weight.view(
             self.conv1d.weight.size(0), self.conv1d.weight.size(2)
         )
+        # RadixLinearAttention layer
         self.attn = RadixLinearAttention(
            layer_id=layer_id,
            num_q_heads=self.num_k_heads // self.attn_tp_size,
@@ -216,6 +218,7 @@ def __init__(
             dt_bias=self.dt_bias,
         )
 
+        # Normalization layer
         self.norm = RMSNormGated(
             self.head_v_dim,
             eps=self.layer_norm_epsilon,
@@ -225,6 +228,7 @@ def __init__(
             dtype=config.torch_dtype,
         )
 
+        # Output projection
         self.out_proj = RowParallelLinear(
             self.value_dim,
             self.hidden_size,
@@ -237,190 +241,16 @@ def __init__(
             prefix=add_prefix("out_proj", prefix),
         )
 
-    @staticmethod
-    def _override_weight_loader(param, loader):
-        """Robustly override loader for:
-        1) BasevLLMParameter subclasses: real storage is `_weight_loader`
-        2) regular Parameters that already have mutable `weight_loader`
-        3) regular Parameters without `weight_loader` yet
-        """
-        if hasattr(param, "_weight_loader"):
-            # FP8 / quantized BasevLLMParameter path
-            param._weight_loader = loader
-            return
-
-        if hasattr(param, "weight_loader"):
-            # Regular parameter/tensor that already has a mutable attr.
-            # Do NOT call set_weight_attrs here, because it asserts when
-            # overwriting an existing attribute.
-            param.weight_loader = loader
-            return
-
-        # Fresh attribute on a normal tensor/Parameter
-        set_weight_attrs(param, {"weight_loader": loader})
-
-    def _bind_packed_weight_loaders(self, module):
-        """Bind packed-checkpoint-aware loaders to all relevant params of a merged module."""
-        for attr_name in ("weight", "weight_scale_inv", "weight_scale", "input_scale"):
-            param = getattr(module, attr_name, None)
-            if param is None:
-                continue
-            original_loader = getattr(param, "weight_loader", None)
-            if original_loader is None:
-                continue
-            wrapped_loader = self._make_packed_weight_loader(module, original_loader)
-            self._override_weight_loader(param, wrapped_loader)
-
-    @staticmethod
-    def _get_split_sizes_for_param(module, param, loaded_shard_id):
-        """Return checkpoint-side split sizes for this param type."""
-        if isinstance(param, BlockQuantScaleParameter):
-            # Split by output blocks, not raw output sizes.
-            block_n, _ = module.quant_method.quant_config.weight_block_size
-            block_n = 1 if getattr(param, "format_ue8m0", False) else block_n
-            return [
-                (module.output_sizes[idx] + block_n - 1) // block_n
-                for idx in loaded_shard_id
-            ]
-
-        if isinstance(param, PerTensorScaleParameter):
-            # One logical scale per logical shard.
-            return [1 for _ in loaded_shard_id]
-
-        # Normal weight / non-block quant tensor
-        return [module.output_sizes[idx] for idx in loaded_shard_id]
-
-    @classmethod
-    def _make_packed_weight_loader(cls, module, original_weight_loader):
-        """Wrap the param's original loader so split checkpoints:
-        - in_proj_qkv + in_proj_z -> merged in_proj_qkvz
-        - in_proj_b + in_proj_a -> merged in_proj_ba
-        can load correctly for both normal and FP8 params.
-        """
-
-        def weight_loader(param, loaded_weight, loaded_shard_id=None):
-            # Only intercept split-checkpoint tuple shards.
-            # int shard_id and None should preserve original behavior.
-            if isinstance(loaded_shard_id, tuple):
-                split_sizes = cls._get_split_sizes_for_param(
-                    module, param, loaded_shard_id
-                )
-
-                if len(loaded_weight.shape) == 0:
-                    # Scalar only makes sense for a single logical shard.
-                    assert len(split_sizes) == 1 and split_sizes[0] == 1, (
-                        f"Unexpected scalar for tuple shard load: "
-                        f"{loaded_shard_id=}, {split_sizes=}"
-                    )
-                    chunks = [loaded_weight.reshape(1)]
-                else:
-                    split_dim = getattr(param, "output_dim", 0)
-                    chunks = loaded_weight.split(split_sizes, dim=split_dim)
-
-                assert len(chunks) == len(loaded_shard_id), (
-                    f"Chunk/shard mismatch: {len(chunks)=}, "
-                    f"{len(loaded_shard_id)=}, {split_sizes=}"
-                )
-
-                for idx, chunk in zip(loaded_shard_id, chunks):
-                    # Delegate each chunk to the param's original int-shard loader.
-                    original_weight_loader(param, chunk, idx)
-                return
-
-            return original_weight_loader(param, loaded_weight, loaded_shard_id)
-
-        return weight_loader
-
-    def create_qkvz_proj(
-        self,
-        hidden_size: int,
-        key_dim: int,
-        value_dim: int,
-        quant_config: QuantizationConfig | None,
-        prefix: str,
-        tp_rank: Optional[int] = None,
-        tp_size: Optional[int] = None,
-    ) -> MergedColumnParallelLinear:
-        return MergedColumnParallelLinear(
-            input_size=hidden_size,
-            output_sizes=[key_dim, key_dim, value_dim, value_dim],
-            bias=False,
-            quant_config=quant_config,
-            prefix=prefix,
-            tp_rank=tp_rank,
-            tp_size=tp_size,
-        )
-
-    def create_ba_proj(
-        self,
-        hidden_size: int,
-        num_v_heads: int,
-        quant_config: QuantizationConfig | None,
-        prefix: str,
-        tp_rank: Optional[int] = None,
-        tp_size: Optional[int] = None,
-    ) -> MergedColumnParallelLinear:
-        # Qwen3.5 has separate in_proj_b and in_proj_a weights in the
-        # checkpoint, which are loaded into the fused in_proj_ba parameter
-        # via stacked_params_mapping with shard_id 0 and 1 respectively.
-        return MergedColumnParallelLinear(
-            input_size=hidden_size,
-            output_sizes=[num_v_heads, num_v_heads],
-            bias=False,
-            quant_config=quant_config,
-            prefix=prefix,
-            tp_rank=tp_rank,
-            tp_size=tp_size,
-        )
-
     def fix_query_key_value_ordering(
         self,
-        mixed_qkvz: torch.Tensor,
-        mixed_ba: torch.Tensor,
+        mixed_qkv,
+        z,
+        b,
+        a,
     ):
-        """
-        Derives `query`, `key` and `value` tensors from `mixed_qkvzba`.
-        """
-        k_tp = self.key_dim // self.attn_tp_size
-        v_tp = self.value_dim // self.attn_tp_size
-        nv_tp = self.num_v_heads // self.attn_tp_size
-
-        # Directly split, no head group reshape
-        query, key, value, z = mixed_qkvz.split([k_tp, k_tp, v_tp, v_tp], dim=-1)
-        b, a = mixed_ba.split([nv_tp, nv_tp], dim=-1)
-
-        # value / z reshape to (seq, num_v_heads/tp, head_v_dim)
-        value = value.reshape(value.size(0), -1, self.head_v_dim)
-        z = z.reshape(z.size(0), -1, self.head_v_dim)
-
-        return query, key, value, z, b, a
-
-    def _forward_input_proj(self, hidden_states: torch.Tensor):
-        if (
-            _is_cpu
-            or _is_npu
-            or not get_global_server_args().disable_piecewise_cuda_graph
-        ):
-            DUAL_STREAM_TOKEN_THRESHOLD = 0
-        else:
-            DUAL_STREAM_TOKEN_THRESHOLD = 1024
-
-        seq_len, _ = hidden_states.shape
-        if (
-            self.alt_stream is not None
-            and get_is_capture_mode()
-            and seq_len < DUAL_STREAM_TOKEN_THRESHOLD
-        ):
-            current_stream = torch.cuda.current_stream()
-            self.alt_stream.wait_stream(current_stream)
-            projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
-            with torch.cuda.stream(self.alt_stream):
-                projected_states_ba, _ = self.in_proj_ba(hidden_states)
-            current_stream.wait_stream(self.alt_stream)
-        else:
-            projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
-            projected_states_ba, _ = self.in_proj_ba(hidden_states)
-        return projected_states_qkvz, projected_states_ba
+        raise NotImplementedError(
+            "Qwen3.5 series doesn't need to fix query/key/value ordering"
+        )
 
     def forward(
         self,
@@ -433,60 +263,30 @@ def forward(
         2. Core attention (custom op)
         3. Output projection
         """
-        projected_states_qkvz, projected_states_ba = self._forward_input_proj(
-            hidden_states
-        )
+        seq_len, _ = hidden_states.shape
+
+        mixed_qkv, _ = self.in_proj_qkv(hidden_states)
+        z, _ = self.in_proj_z(hidden_states)
+        z = z.reshape(z.size(0), -1, self.head_v_dim)
+        b, _ = self.in_proj_b(hidden_states)
+        a, _ = self.in_proj_a(hidden_states)
+
+        b = b.contiguous()
+        a = a.contiguous()
 
-        if self.num_v_heads // self.num_k_heads in [1, 2, 4] and not _is_cpu:
-            mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat_contiguous(
-                projected_states_qkvz,
-                projected_states_ba,
-                triton.cdiv(self.num_k_heads, self.attn_tp_size),
-                triton.cdiv(self.num_v_heads, self.attn_tp_size),
-                self.head_k_dim,
-                self.head_v_dim,
-            )
-        elif _is_cpu and _is_amx_available:
-            mixed_qkv, z, b, a = (
-                torch.ops.sgl_kernel.fused_qkvzba_split_reshape_cat_cpu(
-                    projected_states_qkvz,
-                    projected_states_ba,
-                    self.num_k_heads // self.attn_tp_size,
-                    self.num_v_heads // self.attn_tp_size,
-                    self.head_k_dim,
-                    self.head_v_dim,
-                )
-            )
-        else:
-            query, key, value, z, b, a = self.fix_query_key_value_ordering(
-                projected_states_qkvz, projected_states_ba
-            )
-            query, key, value = map(
-                lambda x: x.reshape(x.shape[0], -1), (query, key, value)
-            )
-            mixed_qkv = torch.cat((query, key, value), dim=-1)
         core_attn_out = self.attn(
-            forward_batch,
+            forward_batch=forward_batch,
            mixed_qkv=mixed_qkv,
            a=a,
            b=b,
         )
 
         z_shape_og = z.shape
-        # reshape input data into 2D tensor
         core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
         z = z.reshape(-1, z.shape[-1])
-
-        # Add padding for DP-Attn
-        if core_attn_out.shape != z.shape:
-            core_attn_out_pad = torch.zeros_like(z)
-            core_attn_out_pad[: core_attn_out.shape[0], :] = core_attn_out
-            core_attn_out = core_attn_out_pad
-
         core_attn_out = self.norm(core_attn_out, z)
         core_attn_out = core_attn_out.reshape(z_shape_og)
-        core_attn_out = core_attn_out.reshape(*core_attn_out.shape[:-2], -1)
-
+        core_attn_out = core_attn_out.flatten(-2)  # ... h d -> ... (h d)
         output, _ = self.out_proj(core_attn_out)
         return output
@@ -1018,11 +818,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("qkv_proj", "v_proj", "v"),
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
-            # GDN
-            ("in_proj_qkvz.", "in_proj_qkv.", (0, 1, 2)),
-            ("in_proj_qkvz.", "in_proj_z.", 3),
-            ("in_proj_ba.", "in_proj_b.", 0),
-            ("in_proj_ba.", "in_proj_a.", 1),
         ]
 
         loaded_params: Set[str] = set()
@@ -1099,11 +894,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("qkv_proj", "v_proj", "v"),
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
-            # GDN
-            ("in_proj_qkvz.", "in_proj_qkv.", (0, 1, 2)),
-            ("in_proj_qkvz.", "in_proj_z.", 3),
-            ("in_proj_ba.", "in_proj_b.", 0),
-            ("in_proj_ba.", "in_proj_a.", 1),
         ]
 
         # Params for weights, fp8 weight scales, fp8 activation scales
@@ -1337,11 +1127,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("qkv_proj", "v_proj", "v"),
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
-            # GDN fused projections
-            ("in_proj_qkvz.", "in_proj_qkv.", (0, 1, 2)),
-            ("in_proj_qkvz.", "in_proj_z.", 3),
-            ("in_proj_ba.", "in_proj_b.", 0),
-            ("in_proj_ba.", "in_proj_a.", 1),
         ]
 
         loaded_params: Set[str] = set()
@@ -1438,11 +1223,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("qkv_proj", "v_proj", "v"),
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
-            # GDN fused projections
-            ("in_proj_qkvz.", "in_proj_qkv.", (0, 1, 2)),
-            ("in_proj_qkvz.", "in_proj_z.", 3),
-            ("in_proj_ba.", "in_proj_b.", 0),
-            ("in_proj_ba.", "in_proj_a.", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
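
Reviewer note: a minimal sketch of what the new split projections compute per token, mirroring the forward() changes above. It uses plain PyTorch with no tensor parallelism or quantization; the helper name and the nn.Linear stand-ins are illustrative and not part of this patch.

import torch
import torch.nn as nn


def split_gdn_projections(
    hidden_states: torch.Tensor,  # (seq_len, hidden_size)
    in_proj_qkv: nn.Linear,  # hidden_size -> key_dim + key_dim + value_dim
    in_proj_z: nn.Linear,  # hidden_size -> value_dim
    in_proj_b: nn.Linear,  # hidden_size -> num_v_heads
    in_proj_a: nn.Linear,  # hidden_size -> num_v_heads
    head_v_dim: int,
):
    # q, k, v stay packed along the last dimension, like mixed_qkv above.
    mixed_qkv = in_proj_qkv(hidden_states)
    # z is reshaped to (seq_len, num_v_heads, head_v_dim) before the gated norm.
    z = in_proj_z(hidden_states).reshape(hidden_states.size(0), -1, head_v_dim)
    b = in_proj_b(hidden_states).contiguous()
    a = in_proj_a(hidden_states).contiguous()
    return mixed_qkv, z, b, a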