diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 9d60ea675724..6a004c0ffc1f 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -773,6 +773,8 @@ def _parse_quant_hf_config(self): return quant_cfg def _find_quant_modelslim_config(self): + if self.is_draft_model: + return None quant_config_file = Path(self.model_path, "quant_model_description.json") quant_cfg = None if quant_config_file.is_file(): diff --git a/python/sglang/srt/hardware_backend/npu/modules/deepseek_v2_attention_mla_npu.py b/python/sglang/srt/hardware_backend/npu/modules/deepseek_v2_attention_mla_npu.py index 16dc169ab0b9..689c8c95f111 100644 --- a/python/sglang/srt/hardware_backend/npu/modules/deepseek_v2_attention_mla_npu.py +++ b/python/sglang/srt/hardware_backend/npu/modules/deepseek_v2_attention_mla_npu.py @@ -3,6 +3,7 @@ import torch import torch_npu +from sgl_kernel_npu.norm.fused_split_qk_norm import fused_split_qk_norm from sglang.srt.environ import envs from sglang.srt.hardware_backend.npu.attention.mla_preprocess import ( @@ -323,39 +324,63 @@ def forward_dsa_prepare_npu( ) else: fused_qkv_a_proj_out = m.fused_qkv_a_proj_with_mqa(hidden_states)[0] - q, latent_cache = fused_qkv_a_proj_out.split( - [m.q_lora_rank, m.kv_lora_rank + m.qk_rope_head_dim], dim=-1 - ) - - # overlap qk norm - q = m.q_a_layernorm(q) - if ( - _use_ag_after_qlora - and layer_scatter_modes.layer_input_mode == ScatterMode.SCATTERED - and layer_scatter_modes.attn_mode == ScatterMode.TP_ATTN_FULL - ): - q = scattered_to_tp_attn_full(q, forward_batch) - latent_cache = scattered_to_tp_attn_full(latent_cache, forward_batch) - q_lora = q.clone() # required for topk_indices - - q_event = None - if m.alt_stream is not None: - m.alt_stream.wait_stream(torch.npu.current_stream()) - with torch.npu.stream(m.alt_stream): + if m.rotary_emb.is_neox_style: + q, latent_cache = fused_qkv_a_proj_out.split( + [m.q_lora_rank, m.kv_lora_rank + m.qk_rope_head_dim], dim=-1 + ) + # overlap qk norm + q = m.q_a_layernorm(q) + if ( + _use_ag_after_qlora + and layer_scatter_modes.layer_input_mode == ScatterMode.SCATTERED + and layer_scatter_modes.attn_mode == ScatterMode.TP_ATTN_FULL + ): + q = scattered_to_tp_attn_full(q, forward_batch) + latent_cache = scattered_to_tp_attn_full(latent_cache, forward_batch) + q_lora = q.clone() # required for topk_indices + + q_event = None + if m.alt_stream is not None: + m.alt_stream.wait_stream(torch.npu.current_stream()) + with torch.npu.stream(m.alt_stream): + q = m.q_b_proj(q_lora)[0].view(-1, m.num_local_heads, m.qk_head_dim) + # record q to ensure memory space will not be released + q.record_stream(m.alt_stream) + q_event = m.alt_stream.record_event() + else: q = m.q_b_proj(q_lora)[0].view(-1, m.num_local_heads, m.qk_head_dim) - # record q to ensure memory space will not be released - q.record_stream(m.alt_stream) - q_event = m.alt_stream.record_event() + + k_nope, k_pe = latent_cache.unsqueeze(1).split( + [m.kv_lora_rank, m.qk_rope_head_dim], dim=-1 + ) + k_nope = m.kv_a_layernorm(k_nope) + # main stream waits for the completion of the event on the alt stream to ensure data dependency is complete + if q_event is not None: + torch.npu.current_stream().wait_event(q_event) else: - q = m.q_b_proj(q_lora)[0].view(-1, m.num_local_heads, m.qk_head_dim) + if fused_qkv_a_proj_out.shape[0] < 65535: + q_lora, k_nope, k_pe = fused_split_qk_norm( + fused_qkv_a_proj_out, + m.q_a_layernorm, + m.kv_a_layernorm, + m.q_lora_rank, + m.kv_lora_rank, + m.qk_rope_head_dim, + eps=m.q_a_layernorm.variance_epsilon, + ) + else: + q, latent_cache = fused_qkv_a_proj_out.split( + [m.q_lora_rank, m.kv_lora_rank + m.qk_rope_head_dim], dim=-1 + ) + # overlap qk norm + q = m.q_a_layernorm(q) - k_nope, k_pe = latent_cache.unsqueeze(1).split( - [m.kv_lora_rank, m.qk_rope_head_dim], dim=-1 - ) - k_nope = m.kv_a_layernorm(k_nope) - # main stream waits for the completion of the event on the alt stream to ensure data dependency is complete - if q_event is not None: - torch.npu.current_stream().wait_event(q_event) + q_lora = q.clone() # required for topk_indices + k_nope, k_pe = latent_cache.unsqueeze(1).split( + [m.kv_lora_rank, m.qk_rope_head_dim], dim=-1 + ) + k_nope = m.kv_a_layernorm(k_nope) + q = m.q_b_proj(q_lora)[0].view(-1, m.num_local_heads, m.qk_head_dim) q_nope, q_pe = q.split([m.qk_nope_head_dim, m.qk_rope_head_dim], dim=-1) @@ -363,6 +388,11 @@ def forward_dsa_prepare_npu( q_nope_out = q_nope_out.transpose(0, 1) + if m.layer_id == 0: + m.rotary_emb.sin_cos_cache = m.rotary_emb.cos_sin_cache.index_select( + 0, positions + ) + q_pe, k_pe = m.rotary_emb(positions, q_pe, k_pe) if nsa_use_prefill_cp(forward_batch): diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py index 7d1511963191..02ef4e2440cd 100644 --- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py +++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py @@ -1254,22 +1254,48 @@ def forward_npu( and not forward_batch.forward_mode.is_draft_extend() ) - cos_sin = self.rotary_emb.cos_sin_cache[positions] - cos, sin = cos_sin.chunk(2, dim=-1) - cos = cos.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim) - sin = sin.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim) - bs = q_lora.shape[0] - if self.alt_stream is not None: - self.alt_stream.wait_stream(torch.npu.current_stream()) - with torch.npu.stream(self.alt_stream): + + if self.rotary_emb.is_neox_style: + if not hasattr(forward_batch, "npu_indexer_sin_cos_cache"): + cos_sin = self.rotary_emb.cos_sin_cache[positions] + cos, sin = cos_sin.chunk(2, dim=-1) + cos = cos.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim) + sin = sin.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim) + forward_batch.npu_indexer_sin_cos_cache = (sin, cos) + else: + sin, cos = forward_batch.npu_indexer_sin_cos_cache + + if self.alt_stream is not None: + self.alt_stream.wait_stream(torch.npu.current_stream()) + with torch.npu.stream(self.alt_stream): + q_lora = ( + (q_lora, dynamic_scale) if dynamic_scale is not None else q_lora + ) + q = self.wq_b(q_lora)[ + 0 + ] # [bs, 1536] @ [1536, 64 * 128] = [bs, 64 * 128] + wq_b_event = self.alt_stream.record_event() + q = q.view(bs, self.n_heads, self.head_dim) # [bs, 64, 128] + q_pe, q_nope = torch.split( + q, + [self.rope_head_dim, self.head_dim - self.rope_head_dim], + dim=-1, + ) # [bs, 64, 64 + 64] + q_pe = q_pe.view(bs, self.n_heads, 1, self.rope_head_dim) + q_pe = torch_npu.npu_rotary_mul(q_pe, cos, sin).view( + bs, self.n_heads, self.rope_head_dim + ) # [bs, n, d] + q = torch.cat([q_pe, q_nope], dim=-1) + q.record_stream(self.alt_stream) + q_rope_event = self.alt_stream.record_event() + else: q_lora = ( (q_lora, dynamic_scale) if dynamic_scale is not None else q_lora ) q = self.wq_b(q_lora)[ 0 ] # [bs, 1536] @ [1536, 64 * 128] = [bs, 64 * 128] - wq_b_event = self.alt_stream.record_event() q = q.view(bs, self.n_heads, self.head_dim) # [bs, 64, 128] q_pe, q_nope = torch.split( q, @@ -1281,9 +1307,52 @@ def forward_npu( bs, self.n_heads, self.rope_head_dim ) # [bs, n, d] q = torch.cat([q_pe, q_nope], dim=-1) - q.record_stream(self.alt_stream) - q_rope_event = self.alt_stream.record_event() + + if envs.SGLANG_NPU_USE_MULTI_STREAM.get(): + indexer_weight_stream = get_indexer_weight_stream() + indexer_weight_stream.wait_stream(torch.npu.current_stream()) + with torch.npu.stream(indexer_weight_stream): + x = x.view(-1, self.hidden_size) + weights = self.weights_proj(x.float())[0].to(torch.bfloat16) + weights.record_stream(indexer_weight_stream) + weights_event = indexer_weight_stream.record_event() + else: + x = x.view(-1, self.hidden_size) + weights = self.weights_proj(x.float())[0].to(torch.bfloat16) + + k_proj = self.wk(x)[0] # [b, s, 7168] @ [7168, 128] = [b, s, 128] + k = self.k_norm(k_proj) + if ( + _use_ag_after_qlora + and layer_scatter_modes.layer_input_mode == ScatterMode.SCATTERED + and layer_scatter_modes.attn_mode == ScatterMode.TP_ATTN_FULL + ): + k = scattered_to_tp_attn_full(k, forward_batch) + k_pe, k_nope = torch.split( + k, + [self.rope_head_dim, self.head_dim - self.rope_head_dim], + dim=-1, + ) # [bs, 64 + 64] + + k_pe = k_pe.view(-1, 1, 1, self.rope_head_dim) + k_pe = torch.ops.npu.npu_rotary_mul(k_pe, cos, sin).view( + bs, 1, self.rope_head_dim + ) # [bs, 1, d] + k = torch.cat([k_pe, k_nope.unsqueeze(1)], dim=-1) # [bs, 1, 128] + else: + if envs.SGLANG_NPU_USE_MULTI_STREAM.get(): + indexer_weight_stream = get_indexer_weight_stream() + indexer_weight_stream.wait_stream(torch.npu.current_stream()) + with torch.npu.stream(indexer_weight_stream): + x = x.view(-1, self.hidden_size) + weights = self.weights_proj(x.float())[0].to(torch.bfloat16) + weights.record_stream(indexer_weight_stream) + weights_event = indexer_weight_stream.record_event() + else: + x = x.view(-1, self.hidden_size) + weights = self.weights_proj(x.float())[0].to(torch.bfloat16) + q_lora = (q_lora, dynamic_scale) if dynamic_scale is not None else q_lora q = self.wq_b(q_lora)[0] # [bs, 1536] @ [1536, 64 * 128] = [bs, 64 * 128] q = q.view(bs, self.n_heads, self.head_dim) # [bs, 64, 128] @@ -1292,43 +1361,26 @@ def forward_npu( [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1, ) # [bs, 64, 64 + 64] - q_pe = q_pe.view(bs, self.n_heads, 1, self.rope_head_dim) - q_pe = torch_npu.npu_rotary_mul(q_pe, cos, sin).view( - bs, self.n_heads, self.rope_head_dim - ) # [bs, n, d] - q = torch.cat([q_pe, q_nope], dim=-1) - if envs.SGLANG_NPU_USE_MULTI_STREAM.get(): - indexer_weight_stream = get_indexer_weight_stream() - indexer_weight_stream.wait_stream(torch.npu.current_stream()) - with torch.npu.stream(indexer_weight_stream): - x = x.view(-1, self.hidden_size) - weights = self.weights_proj(x.float())[0].to(torch.bfloat16) - weights.record_stream(indexer_weight_stream) - weights_event = indexer_weight_stream.record_event() - else: - x = x.view(-1, self.hidden_size) - weights = self.weights_proj(x.float())[0].to(torch.bfloat16) + k_proj = self.wk(x)[0] # [b, s, 7168] @ [7168, 128] = [b, s, 128] + k = self.k_norm(k_proj) + k_pe, k_nope = torch.split( + k, + [self.rope_head_dim, self.head_dim - self.rope_head_dim], + dim=-1, + ) # [bs, 64 + 64] - k_proj = self.wk(x)[0] # [b, s, 7168] @ [7168, 128] = [b, s, 128] - k = self.k_norm(k_proj) - if ( - _use_ag_after_qlora - and layer_scatter_modes.layer_input_mode == ScatterMode.SCATTERED - and layer_scatter_modes.attn_mode == ScatterMode.TP_ATTN_FULL - ): - k = scattered_to_tp_attn_full(k, forward_batch) - k_pe, k_nope = torch.split( - k, - [self.rope_head_dim, self.head_dim - self.rope_head_dim], - dim=-1, - ) # [bs, 64 + 64] - - k_pe = k_pe.view(-1, 1, 1, self.rope_head_dim) - k_pe = torch.ops.npu.npu_rotary_mul(k_pe, cos, sin).view( - bs, 1, self.rope_head_dim - ) # [bs, 1, d] - k = torch.cat([k_pe, k_nope.unsqueeze(1)], dim=-1) # [bs, 1, 128] + k_pe = k_pe.unsqueeze(1) + + if layer_id == 0: + self.rotary_emb.sin_cos_cache = ( + self.rotary_emb.cos_sin_cache.index_select(0, positions) + ) + + q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) + k_pe = k_pe.squeeze(1) + q = torch.cat([q_pe, q_nope], dim=-1) + k = torch.cat([k_pe, k_nope], dim=-1) if ( is_prefill @@ -1394,7 +1446,7 @@ def forward_npu( past_key_states = forward_batch.token_to_kv_pool.get_index_k_buffer(layer_id) - if self.alt_stream is not None: + if self.rotary_emb.is_neox_style and self.alt_stream is not None: torch.npu.current_stream().wait_event(q_rope_event) if envs.SGLANG_NPU_USE_MULTI_STREAM.get(): torch.npu.current_stream().wait_event(weights_event) diff --git a/python/sglang/srt/layers/rotary_embedding/base.py b/python/sglang/srt/layers/rotary_embedding/base.py index 943fe8558f4f..a2cd99375924 100644 --- a/python/sglang/srt/layers/rotary_embedding/base.py +++ b/python/sglang/srt/layers/rotary_embedding/base.py @@ -39,6 +39,7 @@ if _is_npu: import torch_npu + from sgl_kernel_npu.norm.fused_rope_qk_mqa import fused_rope_qk_mqa class RotaryEmbedding(MultiPlatformOp): @@ -202,9 +203,14 @@ def forward_native( if offsets is not None: positions = positions + offsets + positions = positions.flatten() num_tokens = positions.shape[0] - cos_sin = self.cos_sin_cache.index_select(0, positions) + + if hasattr(self, "sin_cos_cache"): + cos_sin = self.sin_cos_cache + else: + cos_sin = self.cos_sin_cache.index_select(0, positions) cos, sin = cos_sin.chunk(2, dim=-1) query_shape = query.shape @@ -236,8 +242,26 @@ def forward_npu( assert ( fused_set_kv_buffer_arg is None ), "fused_set_kv_buffer_arg is not supported for npu implementation" - if query.dtype == torch.bfloat16 and self.cos_sin_cache.dtype == torch.float: - return self.forward_native(positions, query, key, offsets) + if ( + query.dtype == torch.bfloat16 + and self.cos_sin_cache.dtype == torch.float + or key.ndim == 3 + ): + if hasattr(self, "sin_cos_cache"): + cos_sin = self.sin_cos_cache + else: + cos_sin = self.cos_sin_cache.index_select(0, positions) + + if query.shape[0] * query.shape[1] < 65535: + return fused_rope_qk_mqa( + query, + key, + cos_sin, + self.rotary_dim, + self.is_neox_style, + ) + else: + return self.forward_native(positions, query, key, offsets) if self.is_neox_style: rotary_mode = "half" else: diff --git a/python/sglang/srt/models/deepseek_nextn.py b/python/sglang/srt/models/deepseek_nextn.py index d57eb882296c..28029a0c75e9 100644 --- a/python/sglang/srt/models/deepseek_nextn.py +++ b/python/sglang/srt/models/deepseek_nextn.py @@ -15,9 +15,11 @@ """Inference-only DeepSeek NextN Speculative Decoding.""" import logging +import os from typing import Iterable, Optional, Tuple import torch +from safetensors.torch import load_file from torch import nn from transformers import PretrainedConfig @@ -99,6 +101,13 @@ def __init__( self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False) + self.rot_weight = None + if _is_npu: + rot_weight_path = get_global_server_args().model_path + "/rot.safetensors" + if os.path.isfile(rot_weight_path): + self.rot_weight = load_file(rot_weight_path) + self.rot_weight = self.rot_weight["rot.weight"].npu() + self.alt_stream = ( torch.cuda.Stream() if _is_cuda or envs.SGLANG_NPU_USE_MULTI_STREAM.get() @@ -112,6 +121,7 @@ def __init__( ): layer_name = "layers." + str(config.num_hidden_layers) + self.quant_config = quant_config self.decoder = DeepseekV2DecoderLayer( config, 0, @@ -137,6 +147,9 @@ def forward( forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: + if _is_npu and self.quant_config is None: + os.environ["SGLANG_DEEPEP_BF16_DISPATCH"] = "1" + os.environ["DEEP_NORMAL_MODE_USE_INT8_QUANT"] = "0" zero_allocator = BumpAllocator( buffer_size=2, dtype=torch.float32, @@ -155,7 +168,13 @@ def forward( torch.cat( ( self.enorm(hidden_states), - self.hnorm(forward_batch.spec_info.hidden_states), + self.hnorm( + forward_batch.spec_info.hidden_states + if self.rot_weight is None + else torch.matmul( + forward_batch.spec_info.hidden_states, self.rot_weight + ) + ), ), dim=-1, ) @@ -189,6 +208,9 @@ def forward( torch.cuda.current_stream(), ) + if _is_npu and self.quant_config is None: + os.environ["SGLANG_DEEPEP_BF16_DISPATCH"] = "0" + os.environ["DEEP_NORMAL_MODE_USE_INT8_QUANT"] = "1" return hidden_states