Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions python/sglang/srt/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,13 +761,10 @@ def cp_all_gather_into_tensor_async(
stream is not None
), f"Invalid params stream ({stream}, Please specify the stream to use when calling cp_all_gather_into_tensor_async.)"
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None:
pynccl_comm.cp_all_gather_into_tensor(output, input, stream=stream)
if pynccl_comm is None or pynccl_comm.disabled:
self.all_gather_into_tensor(output, input)
else:
logger.warning("not all_gather_into_tensor_async")
torch.ops.sglang.reg_all_gather_into_tensor(
output, input, group_name=self.unique_name
)
pynccl_comm.cp_all_gather_into_tensor(output, input, stream=stream)

def all_gather(
self,
Expand Down
133 changes: 111 additions & 22 deletions python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
is_mla_preprocess_enabled,
)
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.nsa.utils import is_nsa_enable_prefill_cp
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.radix_attention import AttentionType
Expand Down Expand Up @@ -42,6 +43,7 @@ class ForwardMetadata:
seq_lens_list_cumsum: Optional[List[int]] = None
seq_lens: Optional[torch.Tensor] = None
actual_seq_lengths_q: Optional[torch.Tensor] = None
actual_seq_lengths_kv: Optional[torch.Tensor] = None

# prefix cache
prefix_lens: Optional[torch.Tensor] = None
Expand Down Expand Up @@ -267,7 +269,6 @@ def update_verify_buffers_to_fill_after_draft(

def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Init the metadata for a forward pass."""
tp_size = get_attention_tp_size()
self.forward_metadata = ForwardMetadata()
seq_lens_max = forward_batch.seq_lens.max()
if forward_batch.forward_mode.is_target_verify():
Expand Down Expand Up @@ -411,6 +412,72 @@ def init_forward_metadata_replay_cuda_graph(
def get_cuda_graph_seq_len_fill_value(self):
return 0

def do_cp_balance_attn(
self,
q_nope,
k_nope,
q_pe,
k_pe,
topk_indices,
layer,
actual_seq_qlen,
actual_seq_lengths_kv,
):
seq_len = q_nope.shape[0]
split_len = (seq_len + 1) // 2
q_nope_prev, q_nope_next = torch.split(q_nope, split_len, dim=0)
q_rope_prev, q_rope_next = torch.split(q_pe, split_len, dim=0)
q_nope_prev = q_nope_prev.contiguous()
q_nope_next = q_nope_next.contiguous()
q_rope_prev = q_rope_prev.contiguous()
q_rope_next = q_rope_next.contiguous()
topk_indices_prev, topk_indices_next = topk_indices

actual_seq_qlen_prev, actual_seq_qlen_next = actual_seq_qlen
actual_seq_lengths_kv_prev, actual_seq_lengths_kv_next = actual_seq_lengths_kv

attn_out_prev = torch.ops.custom.npu_sparse_flash_attention(
query=q_nope_prev,
key=k_nope,
value=k_nope,
query_rope=q_rope_prev,
key_rope=k_pe,
sparse_indices=topk_indices_prev,
scale_value=layer.scaling,
actual_seq_lengths_query=actual_seq_qlen_prev.to(
device=q_nope.device, dtype=torch.int32
),
actual_seq_lengths_kv=actual_seq_lengths_kv_prev.to(
device=q_nope.device, dtype=torch.int32
),
block_table=self.forward_metadata.block_tables,
sparse_block_size=1,
layout_query="TND",
layout_kv="PA_BSND",
sparse_mode=3,
)
attn_out_next = torch.ops.custom.npu_sparse_flash_attention(
query=q_nope_next,
key=k_nope,
value=k_nope,
query_rope=q_rope_next,
key_rope=k_pe,
sparse_indices=topk_indices_next,
scale_value=layer.scaling,
actual_seq_lengths_query=actual_seq_qlen_next.to(
device=q_nope.device, dtype=torch.int32
),
actual_seq_lengths_kv=actual_seq_lengths_kv_next.to(
device=q_nope.device, dtype=torch.int32
),
block_table=self.forward_metadata.block_tables,
sparse_block_size=1,
layout_query="TND",
layout_kv="PA_BSND",
sparse_mode=3,
)
return torch.cat([attn_out_prev, attn_out_next], dim=0)

def forward_sparse(
self,
q: torch.Tensor,
Expand Down Expand Up @@ -440,9 +507,12 @@ def forward_sparse(
)
q_nope, q_pe = q, q_rope
k_nope, k_pe = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
block_table = self.forward_metadata.block_tables

if is_prefill:
actual_seq_qlen = torch.cumsum(forward_batch.seq_lens, dim=0)
if self.forward_metadata.actual_seq_lengths_q is not None:
actual_seq_qlen = self.forward_metadata.actual_seq_lengths_q
else:
actual_seq_qlen = torch.cumsum(forward_batch.seq_lens, dim=0)
else:
if self.forward_metadata.actual_seq_lengths_q is None:
if (
Expand Down Expand Up @@ -471,27 +541,46 @@ def forward_sparse(
)
else:
actual_seq_qlen = self.forward_metadata.actual_seq_lengths_q
if self.forward_metadata.seq_lens_cpu_int is None:
actual_seq_lengths_kv = self.forward_metadata.seq_lens
else:

if self.forward_metadata.actual_seq_lengths_kv is not None:
actual_seq_lengths_kv = self.forward_metadata.actual_seq_lengths_kv
elif self.forward_metadata.seq_lens_cpu_int is not None:
actual_seq_lengths_kv = self.forward_metadata.seq_lens_cpu_int
else:
actual_seq_lengths_kv = self.forward_metadata.seq_lens

attn_out = torch.ops.custom.npu_sparse_flash_attention(
query=q_nope,
key=k_nope,
value=k_nope,
query_rope=q_pe,
key_rope=k_pe,
sparse_indices=topk_indices,
scale_value=layer.scaling,
actual_seq_lengths_query=actual_seq_qlen.to(torch.int32),
actual_seq_lengths_kv=actual_seq_lengths_kv.to(q.device),
block_table=block_table,
sparse_block_size=1,
layout_query="TND",
layout_kv="PA_BSND",
sparse_mode=3,
)
if (
is_prefill
and is_nsa_enable_prefill_cp()
and forward_batch.nsa_cp_metadata is not None
):
attn_out = self.do_cp_balance_attn(
q_nope,
k_nope,
q_pe,
k_pe,
topk_indices,
layer,
actual_seq_qlen,
actual_seq_lengths_kv,
)
else:
attn_out = torch.ops.custom.npu_sparse_flash_attention(
query=q_nope,
key=k_nope,
value=k_nope,
query_rope=q_pe,
key_rope=k_pe,
sparse_indices=topk_indices,
scale_value=layer.scaling,
actual_seq_lengths_query=actual_seq_qlen,
actual_seq_lengths_kv=actual_seq_lengths_kv.to(q.device),
block_table=self.forward_metadata.block_tables,
sparse_block_size=1,
layout_query="TND",
layout_kv="PA_BSND",
sparse_mode=3,
)

return attn_out

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,17 @@ def forward_dsa_prepare_npu(

q_nope_out = q_nope_out.transpose(0, 1)

if enable_prefill_cp(forward_batch, m.nsa_enable_prefill_cp):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The result of enable_prefill_cp is used again on line 319. Consider caching it in a local variable to avoid re-computation and improve readability. For example: use_prefill_cp = enable_prefill_cp(...) and then use use_prefill_cp in the if conditions.

positions = cp_split_and_rebuild_position(forward_batch, positions)

q_pe, k_pe = m.rotary_emb(positions, q_pe, k_pe)

if enable_prefill_cp(forward_batch, m.nsa_enable_prefill_cp):
# support allgather + rearrange
k_nope, k_pe = m.rebuild_cp_kv_cache(
latent_cache, forward_batch, k_nope, k_pe
)

topk_indices = m.indexer(
hidden_states, q_lora, positions, forward_batch, m.layer_id
)
Expand Down
8 changes: 8 additions & 0 deletions python/sglang/srt/hardware_backend/npu/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@ def init_npu_backend():
assert _is_npu, "NPU backend initialization called on non-NPU device."

import sgl_kernel_npu # noqa: F401

try:
import custom_ops # noqa: F401
except ImportError:
logger.warning(
f"custom_ops not found, dsv3.2 requires this package, which includes the npu_lightning_indexer and npu_sparse_flash_attention operators."
)

import torch_npu
from torch_npu.contrib import transfer_to_npu # noqa: F401

Expand Down
Loading
Loading