diff --git a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py
index ccbeeb453a7b..4eb4fc923b78 100644
--- a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py
+++ b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py
@@ -21,6 +21,7 @@
 )
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.nsa.utils import is_nsa_enable_prefill_cp
+from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.spec_info import SpecInput
@@ -210,6 +211,7 @@ def __init__(self, model_runner: ModelRunner):
         self.forward_metadata = None
         self.device = model_runner.device
         self.page_size = model_runner.page_size
+        self.model_dtype = model_runner.model_config.dtype
         self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
         if self.use_mla:
             self.kv_lora_rank = model_runner.model_config.kv_lora_rank
@@ -253,6 +255,18 @@ def __init__(self, model_runner: ModelRunner):
         if self.use_mla:
             self.ringmla_mask = self.ascend_attn_mask_builder.ringmla_mask

+        # head num padding
+        self.padding_size_list = [1, 2, 4, 8, 16, 32, 64, 128]
+        self.q_head_num_padding = None
+        if hasattr(model_runner.model_config, "num_attention_heads") and self.use_mla:
+            self.tp_q_head_num = (
+                model_runner.model_config.num_attention_heads // get_attention_tp_size()
+            )
+            for num in self.padding_size_list:
+                if num >= self.tp_q_head_num:
+                    self.q_head_num_padding = num
+                    break
+
         # dllm model config
         self.dllm_config = DllmConfig.from_server_args(model_runner.server_args)
         self.is_dllm_model = False
@@ -406,6 +420,37 @@ def init_forward_metadata_capture_cuda_graph(
                 torch.cumsum(extend_seq_lens_cpu_int, dim=0).int().tolist()
             )

+        if (
+            self.q_head_num_padding is not None
+            and self.q_head_num_padding > self.tp_q_head_num
+        ):
+            # In the MLA architecture, the FIA kernel requires the head count to be a power of 2.
+            # Therefore, we pad the head dimension accordingly and initialize empty tensors for the padded heads.
+            metadata.nope_padding = torch.empty(
+                [
+                    bs,
+                    1,
+                    self.q_head_num_padding - self.tp_q_head_num,
+                    self.kv_lora_rank,
+                ],
+                dtype=(
+                    self.model_dtype if self.model_dtype is not None else torch.bfloat16
+                ),
+                device=seq_lens.device,
+            )
+            metadata.rope_padding = torch.empty(
+                [
+                    bs,
+                    1,
+                    self.q_head_num_padding - self.tp_q_head_num,
+                    self.qk_rope_head_dim,
+                ],
+                dtype=(
+                    self.model_dtype if self.model_dtype is not None else torch.bfloat16
+                ),
+                device=seq_lens.device,
+            )
+
         self.graph_metadata[bs] = metadata
         self.forward_metadata = metadata

@@ -946,110 +991,212 @@ def forward_extend(
                 -1, layer.tp_q_head_num * layer.v_head_dim
             )
         elif sum(forward_batch.extend_prefix_lens_cpu) > 0:
-            num_token_padding = q.shape[0]
-            q, k, v = [
-                data[: forward_batch.num_token_non_padded_cpu] for data in [q, k, v]
-            ]
-            q_nope, q_rope = q.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1)
-            k_nope, k_rope = k.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1)
-
-            # 1st, compute extend tokens to get attn_output and attn_lse
-            num_tokens = q_nope.size(0)
-            attn_output = torch.zeros(
-                num_tokens,
-                layer.tp_q_head_num,
-                layer.v_head_dim,
-                dtype=q_nope.dtype,
-                device=q_nope.device,
-            )
-            attn_lse = torch.zeros(
-                layer.tp_q_head_num,
-                num_tokens,
-                dtype=torch.float32,
-                device=q_nope.device,
-            )
-            torch_npu.atb.npu_ring_mla(
-                q_nope=q_nope,
-                q_rope=q_rope,
-                k_nope=k_nope,
-                k_rope=k_rope,
-                value=v,
-                mask=self.ringmla_mask,
-                seqlen=self.forward_metadata.extend_seq_lens_cpu_int,
-                head_num=layer.tp_q_head_num,
-                kv_head_num=layer.tp_k_head_num,
-                pre_out=None,
-                prev_lse=None,
-                qk_scale=layer.scaling,
-                kernel_type="kernel_type_high_precision",
-                mask_type="mask_type_triu",
-                calc_type="calc_type_first_ring",
-                output=attn_output,
-                softmax_lse=attn_lse,
-            )
+            # This branch adds prefix cache support for GLM-4.7-Flash.
+            # In the MLA architecture, when the qk head dim equals the v head dim and the head count is not a power of 2,
+            # we use the FIA kernel for this computation.
+            if layer.qk_head_dim == layer.v_head_dim:
+                q = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim)

-            # 2nd, load history kvcache(kv_a and k_pe) and calculate k_nope
-            k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-            v_buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
-            kv_cached = torch.index_select(
-                k_buffer, 0, self.forward_metadata.flatten_prefix_block_tables
-            )
-            k_rope_cached = torch.index_select(
-                v_buffer, 0, self.forward_metadata.flatten_prefix_block_tables
-            ).flatten(0, 1)
+                k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+                v_buffer = forward_batch.token_to_kv_pool.get_value_buffer(
+                    layer.layer_id
+                )
+                kv_cached = torch.index_select(
+                    k_buffer, 0, self.forward_metadata.flatten_prefix_block_tables
+                )
+                k_rope_cached = torch.index_select(
+                    v_buffer, 0, self.forward_metadata.flatten_prefix_block_tables
+                ).flatten(0, 1)

-            assert layer.kv_b_proj is not None
-            kv = layer.kv_b_proj(kv_cached)[0].view(
-                -1, layer.tp_k_head_num, self.qk_nope_head_dim + layer.v_head_dim
-            )
-            k_nope, v = kv.split([self.qk_nope_head_dim, layer.v_head_dim], dim=-1)
+                assert layer.kv_b_proj is not None
+                kv = layer.kv_b_proj(kv_cached)[0].view(
+                    -1, layer.tp_k_head_num, self.qk_nope_head_dim + layer.v_head_dim
+                )
+                k_nope, v_pre = kv.split(
+                    [self.qk_nope_head_dim, layer.v_head_dim], dim=-1
+                )

-            # 3rd, compute history kv to attn_out
-            k_rope = k_rope_cached.expand(-1, layer.tp_k_head_num, -1)
-            seq_len = torch.stack(
-                [
+                k_rope = k_rope_cached.expand(-1, layer.tp_k_head_num, -1)
+                k_pre = torch.cat([k_nope, k_rope], dim=-1)
+
+                attn_output = torch.empty(
+                    (q.size(0), layer.tp_q_head_num, layer.v_head_dim),
+                    device=q.device,
+                    dtype=q.dtype,
+                )
+                q_len_offset = 0
+                prefix_len_offset = 0
+                for q_len, prefix_len in zip(
                     self.forward_metadata.extend_seq_lens_cpu_int,
                     self.forward_metadata.prefix_lens,
+                ):
+                    k_cur_slice = k[None, q_len_offset : q_len_offset + q_len]
+                    v_cur_slice = v[None, q_len_offset : q_len_offset + q_len]
+                    k_pre_slice = k_pre[
+                        None, prefix_len_offset : prefix_len_offset + prefix_len
+                    ]
+                    v_pre_slice = v_pre[
+                        None, prefix_len_offset : prefix_len_offset + prefix_len
+                    ]
+
+                    k_full = torch.cat([k_pre_slice, k_cur_slice], dim=1)
+                    v_full = torch.cat([v_pre_slice, v_cur_slice], dim=1)
+
+                    attn_output[q_len_offset : q_len_offset + q_len] = (
+                        torch.ops.npu.npu_fused_infer_attention_score(
+                            q[None, q_len_offset : q_len_offset + q_len],
+                            k_full,
+                            v_full,
+                            num_heads=layer.tp_q_head_num,
+                            num_key_value_heads=layer.tp_k_head_num,
+                            input_layout="BSND",  # TODO: TND does not support q_heads != k_heads
+                            atten_mask=self.fia_mask,
+                            sparse_mode=3,
+                            scale=layer.scaling,
+                            next_tokens=0,
+                        )[0]
+                    )
+                    q_len_offset += q_len
+                    prefix_len_offset += prefix_len
+                attn_output = attn_output.view(
+                    -1, layer.tp_q_head_num * layer.v_head_dim
+                )
+            else:
+                num_token_padding = q.shape[0]
+                q, k, v = [
+                    data[: forward_batch.num_token_non_padded_cpu] for data in [q, k, v]
                 ]
-            )
-            torch_npu.atb.npu_ring_mla(
-                q_nope=q_nope,
-                q_rope=q_rope,
-                k_nope=k_nope,
-                k_rope=k_rope,
-                value=v,
-                mask=self.ringmla_mask,
-                seqlen=seq_len,
-                head_num=layer.tp_q_head_num,
-                kv_head_num=layer.tp_k_head_num,
-                pre_out=attn_output,
-                prev_lse=attn_lse,
-                qk_scale=layer.scaling,
-                kernel_type="kernel_type_high_precision",
-                mask_type="no_mask",
-                calc_type="calc_type_default",
-                output=attn_output,
-                softmax_lse=attn_lse,
-            )
-            attn_output = attn_output.reshape(
-                [-1, layer.tp_q_head_num, layer.v_head_dim]
-            )
-            if num_token_padding != forward_batch.num_token_non_padded_cpu:
-                attn_output = torch.cat(
-                    [
-                        attn_output,
-                        attn_output.new_zeros(
-                            num_token_padding - attn_output.shape[0],
-                            *attn_output.shape[1:],
-                        ),
-                    ],
-                    dim=0,
-                )
+                q_nope, q_rope = q.split(
+                    [layer.v_head_dim, self.qk_rope_head_dim], dim=-1
+                )
+                k_nope, k_rope = k.split(
+                    [layer.v_head_dim, self.qk_rope_head_dim], dim=-1
+                )
+
+                # 1st, compute extend tokens to get attn_output and attn_lse
+                num_tokens = q_nope.size(0)
+                attn_output = torch.zeros(
+                    num_tokens,
+                    layer.tp_q_head_num,
+                    layer.v_head_dim,
+                    dtype=q_nope.dtype,
+                    device=q_nope.device,
+                )
+                attn_lse = torch.zeros(
+                    layer.tp_q_head_num,
+                    num_tokens,
+                    dtype=torch.float32,
+                    device=q_nope.device,
+                )
+                torch_npu.atb.npu_ring_mla(
+                    q_nope=q_nope,
+                    q_rope=q_rope,
+                    k_nope=k_nope,
+                    k_rope=k_rope,
+                    value=v,
+                    mask=self.ringmla_mask,
+                    seqlen=self.forward_metadata.extend_seq_lens_cpu_int,
+                    head_num=layer.tp_q_head_num,
+                    kv_head_num=layer.tp_k_head_num,
+                    pre_out=None,
+                    prev_lse=None,
+                    qk_scale=layer.scaling,
+                    kernel_type="kernel_type_high_precision",
+                    mask_type="mask_type_triu",
+                    calc_type="calc_type_first_ring",
+                    output=attn_output,
+                    softmax_lse=attn_lse,
+                )
+
+                # 2nd, load history kvcache(kv_a and k_pe) and calculate k_nope
+                k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+                v_buffer = forward_batch.token_to_kv_pool.get_value_buffer(
+                    layer.layer_id
+                )
+                kv_cached = torch.index_select(
+                    k_buffer, 0, self.forward_metadata.flatten_prefix_block_tables
+                )
+                k_rope_cached = torch.index_select(
+                    v_buffer, 0, self.forward_metadata.flatten_prefix_block_tables
+                ).flatten(0, 1)
+
+                assert layer.kv_b_proj is not None
+                kv = layer.kv_b_proj(kv_cached)[0].view(
+                    -1, layer.tp_k_head_num, self.qk_nope_head_dim + layer.v_head_dim
+                )
+                k_nope, v = kv.split([self.qk_nope_head_dim, layer.v_head_dim], dim=-1)
+
+                # 3rd, compute history kv to attn_out
+                k_rope = k_rope_cached.expand(-1, layer.tp_k_head_num, -1)
+                seq_len = torch.stack(
+                    [
+                        self.forward_metadata.extend_seq_lens_cpu_int,
+                        self.forward_metadata.prefix_lens,
+                    ]
+                )
+                torch_npu.atb.npu_ring_mla(
+                    q_nope=q_nope,
+                    q_rope=q_rope,
+                    k_nope=k_nope,
+                    k_rope=k_rope,
+                    value=v,
+                    mask=self.ringmla_mask,
+                    seqlen=seq_len,
+                    head_num=layer.tp_q_head_num,
+                    kv_head_num=layer.tp_k_head_num,
+                    pre_out=attn_output,
+                    prev_lse=attn_lse,
+                    qk_scale=layer.scaling,
+                    kernel_type="kernel_type_high_precision",
+                    mask_type="no_mask",
+                    calc_type="calc_type_default",
+                    output=attn_output,
+                    softmax_lse=attn_lse,
+                )
+                attn_output = attn_output.reshape(
+                    [-1, layer.tp_q_head_num, layer.v_head_dim]
+                )
+                if num_token_padding != forward_batch.num_token_non_padded_cpu:
+                    attn_output = torch.cat(
+                        [
+                            attn_output,
+                            attn_output.new_zeros(
+                                num_token_padding - attn_output.shape[0],
+                                *attn_output.shape[1:],
+                            ),
+                        ],
+                        dim=0,
+                    )
         else:
-            assert (
-                layer.qk_head_dim != layer.v_head_dim
-            ), "FIA only supports qk_head_dim != v_head_dim"
-            if layer.v_head_dim in [256]:
+            if layer.qk_head_dim == layer.v_head_dim:
+                """FIA will support multi-batch in a later version of CANN"""
+                q = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim)
+                attn_output = torch.empty(
+                    (q.size(0), layer.tp_q_head_num, layer.v_head_dim),
+                    device=q.device,
+                    dtype=q.dtype,
+                )
+                q_len_offset = 0
+                for q_len in forward_batch.extend_seq_lens_cpu:
+                    attn_output[q_len_offset : q_len_offset + q_len] = (
+                        torch.ops.npu.npu_fused_infer_attention_score(
+                            q[None, q_len_offset : q_len_offset + q_len],
+                            k[None, q_len_offset : q_len_offset + q_len],
+                            v[None, q_len_offset : q_len_offset + q_len],
+                            num_heads=layer.tp_q_head_num,
+                            num_key_value_heads=layer.tp_k_head_num,
+                            input_layout="BSND",  # TODO: TND does not support q_heads != k_heads
+                            atten_mask=self.fia_mask.unsqueeze(0),
+                            sparse_mode=3 if q_len != 1 else 0,
+                            scale=layer.scaling,
+                            next_tokens=0,
+                        )[0]
+                    )
+                    q_len_offset += q_len
+                attn_output = attn_output.view(
+                    -1, layer.tp_q_head_num * layer.v_head_dim
+                )
+            elif layer.v_head_dim in [256]:
                 """Currently, in NO_QUANT situation, qk_nope_head_dim == v_head_dim,
                 and rope exists, v_head_dim only support 512 and 128"""
                 kv_lora_rank = k.shape[-1] - self.qk_rope_head_dim
                 kv_c, k_rope = k.split([kv_lora_rank, self.qk_rope_head_dim], dim=-1)
@@ -1477,6 +1624,24 @@ def forward_decode_graph(
             q_nope = q.view(-1, 1, layer.tp_q_head_num, self.kv_lora_rank).contiguous()
             q_rope = q_rope.view(-1, 1, layer.tp_q_head_num, self.qk_rope_head_dim)

+            assert (
+                self.q_head_num_padding is None
+                or self.q_head_num_padding >= layer.tp_q_head_num
+            )
+
+            if (
+                self.q_head_num_padding is not None
+                and self.q_head_num_padding > layer.tp_q_head_num
+            ):
+                # The FIA kernel only supports head counts that are powers of 2.
+                # Therefore, we pad the head dimension when the head count is not a power of 2.
+                q_nope = torch.cat(
+                    [q_nope, self.forward_metadata.nope_padding], dim=2
+                ).contiguous()
+                q_rope = torch.cat(
+                    [q_rope, self.forward_metadata.rope_padding], dim=2
+                ).contiguous()
+
             if self.forward_metadata.seq_lens_cpu_int is None:
                 actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
             else:
@@ -1490,7 +1655,7 @@ def forward_decode_graph(
                 c_kv_cache,
                 query_rope=q_rope,
                 key_rope=k_rope_cache,
-                num_heads=layer.tp_q_head_num,
+                num_heads=self.q_head_num_padding,
                 num_key_value_heads=layer.tp_k_head_num,
                 block_table=self.forward_metadata.block_tables,
                 block_size=self.page_size,
@@ -1510,7 +1675,7 @@ def forward_decode_graph(
                 c_kv_cache,
                 query_rope=q_rope,
                 key_rope=k_rope_cache,
-                num_heads=layer.tp_q_head_num,
+                num_heads=self.q_head_num_padding,
                 num_key_value_heads=layer.tp_k_head_num,
                 block_table=self.forward_metadata.block_tables,
                 block_size=self.page_size,
@@ -1523,6 +1688,8 @@ def forward_decode_graph(
                 workspace=workspace,
                 out=[output, softmax_lse],
             )
+
+        output = output[:, :, : layer.tp_q_head_num, :]
         return output.view(-1, layer.tp_q_head_num * self.kv_lora_rank)

     def forward_decode(
diff --git a/python/sglang/srt/layers/rotary_embedding/base.py b/python/sglang/srt/layers/rotary_embedding/base.py
index 943fe8558f4f..0f2e7c6c402c 100644
--- a/python/sglang/srt/layers/rotary_embedding/base.py
+++ b/python/sglang/srt/layers/rotary_embedding/base.py
@@ -242,7 +242,16 @@ def forward_npu(
             rotary_mode = "half"
         else:
             rotary_mode = "interleave"
+        mrope_section = [0, 0, 0]

+        # The npu_mrope kernel only supports 1D or 2D query and key tensors.
+        # Therefore, when query and key have more than two dimensions, we flatten them to 2D before the kernel call
+        # and restore the original shapes afterwards.
+        query_shape = query.shape
+        key_shape = key.shape
+        query = query.reshape(query.shape[0], -1)
+        key = key.reshape(key.shape[0], -1)
+
         query_out, key_out = torch_npu.npu_mrope(
             positions,
             query,
@@ -252,6 +261,9 @@ def forward_npu(
             mrope_section=mrope_section,
             rotary_mode=rotary_mode,
         )
+
+        query_out = query_out.reshape(query_shape)
+        key_out = key_out.reshape(key_shape)
         return query_out, key_out

     def forward_cpu(
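Note on the head-count padding used by this patch: the decode path picks the smallest entry of the fixed candidate list [1, 2, 4, 8, 16, 32, 64, 128] that covers the per-rank query head count, appends empty filler heads before the FIA call, and slices the real heads back out afterwards (output[:, :, : layer.tp_q_head_num, :]). The sketch below is a minimal, self-contained illustration of that scheme; it is not part of the patch and the helper names are hypothetical.

import torch

PADDING_SIZES = [1, 2, 4, 8, 16, 32, 64, 128]


def pick_padded_head_num(tp_q_head_num: int):
    # Smallest power of two from the candidate list that covers tp_q_head_num.
    # None means more than 128 heads per rank; in that case no padding is chosen.
    for num in PADDING_SIZES:
        if num >= tp_q_head_num:
            return num
    return None


def pad_query_heads(q: torch.Tensor, padded_heads: int) -> torch.Tensor:
    # q: [bs, 1, tp_q_head_num, head_dim]; append uninitialized filler heads on dim 2.
    bs, s, h, d = q.shape
    if padded_heads == h:
        return q
    filler = torch.empty(bs, s, padded_heads - h, d, dtype=q.dtype, device=q.device)
    return torch.cat([q, filler], dim=2).contiguous()


if __name__ == "__main__":
    q = torch.randn(2, 1, 20, 64)              # 20 query heads per rank
    padded = pick_padded_head_num(q.shape[2])  # -> 32
    q_padded = pad_query_heads(q, padded)
    out = q_padded[:, :, : q.shape[2], :]      # slice the real heads back out
    assert out.shape == q.shape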
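For reference, the new prefix-cache branch in forward_extend concatenates the reconstructed prefix K/V with the current chunk's K/V per request and runs attention where the chunk tokens see the whole prefix plus a causal mask over the chunk. The sketch below illustrates that composition with torch.nn.functional.scaled_dot_product_attention standing in for torch.ops.npu.npu_fused_infer_attention_score; the function name, shapes, and the CPU stand-in kernel are assumptions for illustration only (the scale keyword needs PyTorch >= 2.1).

import torch
import torch.nn.functional as F


def attend_with_prefix(q, k_cur, v_cur, k_pre, v_pre, scale):
    # q, k_cur, v_cur: [q_len, heads, dim]; k_pre, v_pre: [prefix_len, heads, dim].
    k_full = torch.cat([k_pre, k_cur], dim=0)
    v_full = torch.cat([v_pre, v_cur], dim=0)
    q_len, kv_len = q.shape[0], k_full.shape[0]
    # Boolean mask: every chunk position attends to the full prefix,
    # plus a lower-triangular (causal) pattern over the current chunk.
    mask = torch.ones(q_len, kv_len, dtype=torch.bool)
    mask[:, k_pre.shape[0] :] = torch.tril(torch.ones(q_len, q_len, dtype=torch.bool))
    out = F.scaled_dot_product_attention(
        q.transpose(0, 1),       # [heads, q_len, dim]
        k_full.transpose(0, 1),  # [heads, kv_len, dim]
        v_full.transpose(0, 1),
        attn_mask=mask,
        scale=scale,
    )
    return out.transpose(0, 1)   # [q_len, heads, dim]


if __name__ == "__main__":
    h, d = 4, 32
    q = torch.randn(3, h, d)
    k_cur, v_cur = torch.randn(3, h, d), torch.randn(3, h, d)
    k_pre, v_pre = torch.randn(7, h, d), torch.randn(7, h, d)
    print(attend_with_prefix(q, k_cur, v_cur, k_pre, v_pre, scale=d**-0.5).shape)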
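The rotary_embedding/base.py change is a plain flatten-to-2D / restore-shape wrapper around a kernel that only accepts 1-D or 2-D query and key tensors (torch_npu.npu_mrope in the real code). A minimal sketch of that pattern, with the kernel passed in as a callable and the wrapper name being hypothetical:

import torch


def apply_2d_only_rope(kernel, positions, query, key, **kwargs):
    # Remember the original shapes, flatten to [num_tokens, heads * head_dim],
    # run the 2-D-only kernel, then restore the original shapes.
    query_shape, key_shape = query.shape, key.shape
    q2d = query.reshape(query.shape[0], -1)
    k2d = key.reshape(key.shape[0], -1)
    q_out, k_out = kernel(positions, q2d, k2d, **kwargs)
    return q_out.reshape(query_shape), k_out.reshape(key_shape)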