6 changes: 6 additions & 0 deletions docs/backend/attention_backend.md
@@ -9,6 +9,7 @@
| **Triton** | ❌ | ✅ | ✅ | ✅ | ❌ |
| **Torch Native** | ❌ | ❌ | ❌ | ❌ | ❌ |
| **FlashMLA** | ✅ | ✅ | ✅ | ❌ | ❌ |
| **Ascend** | ✅ | ❌ | ❌ | ❌ | ❌ |

Note: Every kernel backend can be used with a page size > 1 by passing an argument such as `--page-size 16`,
because a page size of 16 can be converted to a page size of 1 inside the kernel backend.
@@ -46,3 +47,8 @@ python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --
python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend flashmla --trust-remote-code
python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend flashmla --kv-cache-dtype fp8_e4m3 --trust-remote-code
```

- Ascend
```bash
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend ascend
```
219 changes: 219 additions & 0 deletions python/sglang/srt/layers/attention/ascend_backend.py
@@ -0,0 +1,219 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional

import torch
import torch_npu
from torch.nn.functional import scaled_dot_product_attention

from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.model_executor.forward_batch_info import ForwardBatch

if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner


@dataclass
class ForwardMetadata:

# Precomputed block table mapping each request to its KV-cache pages ([bs, max_seq_len / page_size])
block_tables: Optional[torch.Tensor] = None

# seq len inputs
extend_seq_lens_cpu_int: Optional[torch.Tensor] = None
seq_lens_cpu_int: Optional[torch.Tensor] = None


class AscendAttnBackend(AttentionBackend):

def gen_attention_mask(self, max_seq_len: int, dtype=torch.float16):
mask_flag = torch.tril(
torch.ones((max_seq_len, max_seq_len), dtype=torch.bool)
).view(max_seq_len, max_seq_len)
mask_flag = ~mask_flag
if dtype == torch.float16:
mask_value = torch.finfo(torch.float32).min
else:
mask_value = 1
self.mask = (
torch.masked_fill(
torch.zeros(size=(max_seq_len, max_seq_len)), mask_flag, mask_value
)
.to(dtype)
.to(self.device)
)
self.mask_len = max_seq_len

def __init__(self, model_runner: ModelRunner):
super().__init__()
self.forward_metadata = ForwardMetadata()
self.device = model_runner.device
self.gen_attention_mask(128, model_runner.dtype)
self.page_size = model_runner.page_size
self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
if self.use_mla:
self.kv_lora_rank = model_runner.model_config.kv_lora_rank
self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
self.native_attn = TorchNativeAttnBackend(model_runner)

def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Init the metadata for a forward pass."""
self.forward_metadata.block_tables = (
forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : forward_batch.seq_lens.max()
][:, :: self.page_size]
// self.page_size
)
if forward_batch.extend_seq_lens is not None:
self.forward_metadata.extend_seq_lens_cpu_int = (
forward_batch.extend_seq_lens.cpu().int()
)
self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()

def forward_extend(
self,
q,
k,
v,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
):
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)

k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)

if not self.use_mla:
query = q.view(-1, layer.tp_q_head_num * layer.qk_head_dim)
output = torch.empty(
(query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
dtype=query.dtype,
device=query.device,
)

torch_npu._npu_flash_attention_qlens(
query=query,
key_cache=k_cache,
value_cache=v_cache,
mask=self.mask,
block_table=self.forward_metadata.block_tables,
seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
context_lens=self.forward_metadata.seq_lens_cpu_int,
scale_value=layer.scaling,
num_heads=layer.tp_q_head_num,
num_kv_heads=layer.tp_k_head_num,
out=output,
)
return output
else:
if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
else:
o = torch.empty_like(q)

use_gqa = layer.tp_q_head_num != layer.tp_k_head_num

q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)

causal = True
if (
layer.is_cross_attention
or layer.attn_type == AttentionType.ENCODER_ONLY
):
causal = False

self.native_attn._run_sdpa_forward_extend(
q_,
o_,
k_cache.view(
-1, layer.tp_k_head_num, (self.kv_lora_rank + self.qk_rope_head_dim)
),
v_cache.view(-1, layer.tp_v_head_num, self.kv_lora_rank),
forward_batch.req_to_token_pool.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.extend_prefix_lens,
forward_batch.extend_seq_lens,
scaling=layer.scaling,
enable_gqa=use_gqa,
causal=causal,
)
return o

def forward_decode(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
):
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)
if not self.use_mla:
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)

query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
num_tokens = query.shape[0]
output = torch.empty(
(num_tokens, layer.tp_q_head_num, layer.v_head_dim),
dtype=query.dtype,
device=query.device,
)

torch_npu._npu_paged_attention(
query=query,
key_cache=k_cache,
value_cache=v_cache,
num_heads=layer.tp_q_head_num,
num_kv_heads=layer.tp_k_head_num,
scale_value=layer.scaling,
block_table=self.forward_metadata.block_tables,
context_lens=self.forward_metadata.seq_lens_cpu_int,
out=output,
)
return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
else:
query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
num_tokens = query.shape[0]
kv_c_and_k_pe_cache = forward_batch.token_to_kv_pool.get_key_buffer(
layer.layer_id
)
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
-1,
self.page_size,
layer.tp_k_head_num,
self.kv_lora_rank + self.qk_rope_head_dim,
)

attn_output = torch.empty(
[num_tokens, layer.tp_q_head_num, self.kv_lora_rank],
dtype=q.dtype,
device=q.device,
)
torch_npu._npu_paged_attention_mla(
query=query,
key_cache=kv_c_and_k_pe_cache,
num_kv_heads=layer.tp_k_head_num,
num_heads=layer.tp_q_head_num,
scale_value=layer.scaling,
block_table=self.forward_metadata.block_tables,
context_lens=self.forward_metadata.seq_lens_cpu_int,
mla_vheadsize=self.kv_lora_rank,
out=attn_output,
)
return attn_output.view(num_tokens, layer.tp_q_head_num * self.kv_lora_rank)
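
As a side note on `init_forward_metadata` above: the page-level block table consumed by the NPU paged-attention kernels is derived from the token-level `req_to_token` map with a strided-slice-and-divide trick (`[:, ::page_size] // page_size`). The following is a minimal, self-contained sketch of that indexing on toy tensors; the shapes and values are illustrative only.

```python
import torch

# Toy setup: 2 requests, page_size = 4, tokens stored contiguously in the KV cache.
page_size = 4
# req_to_token[i, j] = KV-cache slot holding token j of request i.
req_to_token = torch.arange(2 * 16).reshape(2, 16)

# Sampling every page_size-th slot and integer-dividing by page_size yields,
# per request, the sequence of page indices it occupies (the block table).
block_tables = req_to_token[:, ::page_size] // page_size
print(block_tables)
# tensor([[0, 1, 2, 3],
#         [4, 5, 6, 7]])
```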
6 changes: 5 additions & 1 deletion python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -3,7 +3,6 @@

import einops
import torch
from sgl_kernel import silu_and_mul
from torch.nn import Module

from sglang.srt.custom_op import CustomOp
@@ -50,13 +49,18 @@
dispose_tensor,
get_bool_env_var,
is_hip,
is_npu,
set_weight_attrs,
)

_is_hip = is_hip()
_is_npu = is_npu()
_is_fp8_fnuz = is_fp8_fnuz()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

if not _is_npu:
from sgl_kernel import silu_and_mul

if _is_hip:
from vllm._custom_ops import scaled_fp8_quant

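
The `ep_moe` change only guards the `sgl_kernel` import on NPU, where the fused kernel is unavailable. For reference, the operation that `silu_and_mul` fuses is the standard SwiGLU gating; the sketch below shows those semantics in plain PyTorch and is illustrative, not the repository's implementation.

```python
import torch

def silu_and_mul_native(x: torch.Tensor) -> torch.Tensor:
    # Sketch of the fused kernel's semantics: split the last dimension in half,
    # apply SiLU to the gate half, and multiply elementwise by the other half.
    gate, up = x.chunk(2, dim=-1)
    return torch.nn.functional.silu(gate) * up

y = silu_and_mul_native(torch.randn(4, 2 * 128))  # -> shape [4, 128]
```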
38 changes: 38 additions & 0 deletions python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -321,6 +321,44 @@ def forward_cpu(
routed_scaling_factor,
)

def forward_npu(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
return moe_forward_native(
layer,
x,
use_grouped_topk,
top_k,
router_logits,
renormalize,
topk_group,
num_expert_group,
num_fused_shared_experts,
custom_routing_function,
correction_bias,
activation,
apply_router_weight_on_input,
inplace,
no_combine,
routed_scaling_factor,
)

def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError("The TPU backend currently does not support MoE.")

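
`forward_npu` above simply delegates to `moe_forward_native`, mirroring how the other `forward_*` variants plug into a per-platform dispatcher. The sketch below shows the general shape of that dispatch pattern; the class and method names are illustrative stand-ins, not SGLang's actual `CustomOp`.

```python
import torch

class PlatformDispatchedOp(torch.nn.Module):
    """Illustrative dispatcher: pick a per-platform forward_* method once,
    falling back to a pure-PyTorch native path when no fused kernel exists."""

    def __init__(self):
        super().__init__()
        if torch.cuda.is_available():
            self._forward_method = self.forward_cuda
        else:
            # NPU / CPU builds without a fused kernel take the native path,
            # just as forward_npu above routes to moe_forward_native.
            self._forward_method = self.forward_native

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._forward_method(x)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.silu(x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # A real backend would call a fused kernel here; same math for the sketch.
        return torch.nn.functional.silu(x)
```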
5 changes: 5 additions & 0 deletions python/sglang/srt/layers/moe/topk.py
@@ -35,13 +35,15 @@
is_cpu,
is_cuda,
is_hip,
is_npu,
)

_is_cuda = is_cuda()
_is_hip = is_hip()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_is_npu = is_npu()

if _is_cuda:
from sgl_kernel import moe_fused_gate
@@ -159,6 +161,9 @@ def grouped_topk_gpu(
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

scores = torch.softmax(gating_output, dim=-1)
# Workaround for an NPU compiler limitation: cast bfloat16 scores to float16
if _is_npu and scores.dtype == torch.bfloat16:
scores = scores.to(torch.float16)
num_token = scores.shape[0]
num_experts = scores.shape[1]
group_scores = (
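
The `topk.py` change casts post-softmax scores from bfloat16 to float16 on Ascend before group selection. For orientation, here is a simplified sketch of the grouped top-k routing that follows the cast (renormalization and correction bias omitted); it is illustrative, not the repository's `grouped_topk_gpu`.

```python
import torch

def grouped_topk_sketch(gating_output, top_k, num_expert_group, topk_group):
    scores = torch.softmax(gating_output, dim=-1)
    # Mirror the NPU workaround above: group selection runs in float16, not bfloat16.
    if scores.dtype == torch.bfloat16:
        scores = scores.to(torch.float16)

    num_token, num_experts = scores.shape
    # Score each expert group by its best expert and keep the topk_group groups.
    group_scores = scores.view(num_token, num_expert_group, -1).max(dim=-1).values
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, num_experts // num_expert_group)
        .reshape(num_token, num_experts)
    )
    # Pick the top_k experts restricted to the selected groups.
    masked_scores = scores.masked_fill(score_mask == 0, 0.0)
    topk_weights, topk_ids = torch.topk(masked_scores, k=top_k, dim=-1)
    return topk_weights, topk_ids

weights, ids = grouped_topk_sketch(torch.randn(2, 8), top_k=2, num_expert_group=4, topk_group=2)
```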
4 changes: 2 additions & 2 deletions python/sglang/srt/layers/rotary_embedding.py
@@ -660,7 +660,7 @@ def __init__(
beta_slow: int = 1,
mscale: float = 1,
mscale_all_dim: float = 0,
device: Optional[str] = "cuda",
device: Optional[str] = "cuda" if not _is_npu else "npu",
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
@@ -679,7 +679,7 @@
)

# Re-dispatch
if _is_hip:
if _is_hip or _is_npu:
self._forward_method = self.forward_native

def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
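
On HIP and NPU the DeepSeek scaling rotary embedding now re-dispatches to `forward_native`. As a reminder of what a native rotary path computes (plain NEOX-style rotation, without the DeepSeek/YaRN scaling factors), here is a self-contained sketch; the function name and shapes are illustrative.

```python
import torch

def rope_native_sketch(positions: torch.Tensor, x: torch.Tensor, base: float = 10000.0):
    """Plain rotary embedding in pure PyTorch, the kind of fallback a
    forward_native path provides when the fused kernel is unavailable.
    x: [num_tokens, head_dim] with even head_dim; positions: [num_tokens]."""
    head_dim = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    freqs = positions.to(torch.float32)[:, None] * inv_freq[None, :]  # [num_tokens, head_dim/2]
    cos, sin = freqs.cos(), freqs.sin()
    x1, x2 = x.float().chunk(2, dim=-1)
    out = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
    return out.to(x.dtype)

q = torch.randn(4, 64)
q_rot = rope_native_sketch(torch.arange(4), q)
```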
6 changes: 5 additions & 1 deletion python/sglang/srt/managers/schedule_batch.py
@@ -1673,6 +1673,7 @@ def get_model_worker_batch(
)
or global_server_args_dict["attention_backend"] == "flashmla"
or global_server_args_dict["attention_backend"] == "cutlass_mla"
or global_server_args_dict["attention_backend"] == "ascend"
or global_server_args_dict["enable_two_batch_overlap"]
):
seq_lens_cpu = (
@@ -1875,7 +1876,10 @@ def get_last_loc(
req_pool_indices_tensor: torch.Tensor,
prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
if global_server_args_dict["attention_backend"] != "torch_native":
if (
global_server_args_dict["attention_backend"] != "ascend"
and global_server_args_dict["attention_backend"] != "torch_native"
):
impl = get_last_loc_triton
else:
impl = get_last_loc_torch
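
`get_last_loc` now routes to the torch implementation for the `ascend` and `torch_native` backends instead of the Triton kernel. Below is a hedged sketch of what such a pure-PyTorch fallback computes (the KV-cache location of each request's last cached prefix token, or -1 when there is no prefix); it is illustrative, not the repository's `get_last_loc_torch`.

```python
import torch

def get_last_loc_torch_sketch(
    req_to_token: torch.Tensor,      # [num_reqs, max_context_len]
    req_pool_indices: torch.Tensor,  # [bs]
    prefix_lens: torch.Tensor,       # [bs]
) -> torch.Tensor:
    # For each request, look up the slot of its last cached prefix token;
    # requests with an empty prefix get -1.
    last_loc = req_to_token[req_pool_indices, (prefix_lens - 1).clamp(min=0)]
    return torch.where(prefix_lens > 0, last_loc, torch.full_like(last_loc, -1))
```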