Merged
Changes from 2 commits
56 changes: 56 additions & 0 deletions .github/workflows/pr-test-npu.yml
@@ -0,0 +1,56 @@
name: PR Test (Ascend NPU)

on:
  push:
    branches: [ main ]
    paths:
      - "python/**"
      - "scripts/**"
      - "test/**"
      - ".github/workflows/pr-test-npu.yml"
  pull_request:
    branches: [ main ]
    paths:
      - "python/**"
      - "scripts/**"
      - "test/**"
      - ".github/workflows/pr-test-npu.yml"
  workflow_dispatch:

concurrency:
  group: pr-test-npu-${{ github.ref }}
  cancel-in-progress: true

jobs:
  unit-test-backend-1-npu-ascend:
    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
      github.event.pull_request.draft == false && contains(github.event.pull_request.labels.*.name, 'npu')
    strategy:
      fail-fast: false
    runs-on: self-hosted
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Run test
        timeout-minutes: 40
        run: |
          cd test/srt
          python3 run_suite.py --suite per-commit-npu

  finish:
    if: always()
    needs: [ unit-test-backend-1-npu-ascend ]
    runs-on: self-hosted
    steps:
      - name: Check all dependent job statuses
        run: |
          results=(${{ join(needs.*.result, ' ') }})
          for result in "${results[@]}"; do
            if [ "$result" = "failure" ] || [ "$result" = "cancelled" ]; then
              echo "Job failed with result: $result"
              exit 1
            fi
          done
          echo "All jobs completed successfully"
          exit 0
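Since the workflow also declares `workflow_dispatch`, a run can be triggered by hand; a minimal sketch using the GitHub CLI, assuming you have workflow permissions on the repository:

```bash
# Manually dispatch the Ascend NPU test workflow on main.
gh workflow run pr-test-npu.yml --ref main -R sgl-project/sglang

# List the resulting runs of this workflow.
gh run list --workflow=pr-test-npu.yml -R sgl-project/sglang
```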
6 changes: 6 additions & 0 deletions docs/backend/attention_backend.md
@@ -9,6 +9,7 @@
| **Triton** | ❌ | ✅ | ✅ | ✅ | ❌ |
| **Torch Native** | ❌ | ❌ | ❌ | ❌ | ❌ |
| **FlashMLA** | ✅ | ✅ | ✅ | ❌ | ❌ |
| **Ascend** | ✅ | ❌ | ❌ | ❌ | ❌ |

Note: Every kernel backend supports a page size > 1 when launched with an argument such as `--page-size 16`, because a page size of 16 can be converted to a page size of 1 inside the kernel backend.
@@ -46,3 +47,8 @@ python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --
python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend flashmla --trust-remote-code
python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend flashmla --kv-cache-dtype fp8_e4m3 --trust-remote-code
```

- Ascend
```bash
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend ascend
```
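The page-size note earlier in this file should carry over to the new backend, since `AscendAttnBackend` builds its block table from `model_runner.page_size` (see `ascend_backend.py` below). A hedged variant of the launch command, not verified here on NPU hardware:

```bash
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend ascend --page-size 16
```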
146 changes: 146 additions & 0 deletions python/sglang/srt/layers/attention/ascend_backend.py
@@ -0,0 +1,146 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional

import torch
import torch_npu
from torch.nn.functional import scaled_dot_product_attention

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.model_executor.forward_batch_info import ForwardBatch

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner


@dataclass
class ForwardMetadata:

    # calculated map for kv positions [bs * maxseqlen]
    block_tables: Optional[torch.Tensor] = None

    # seq len inputs
    extend_seq_lens_cpu_int: Optional[torch.Tensor] = None
    seq_lens_cpu_int: Optional[torch.Tensor] = None


class AscendAttnBackend(AttentionBackend):

    def gen_attention_mask(self, max_seq_len: int, dtype=torch.float16):
        # Causal mask: True above the diagonal marks positions to be masked out.
        mask_flag = torch.tril(
            torch.ones((max_seq_len, max_seq_len), dtype=torch.bool)
        ).view(max_seq_len, max_seq_len)
        mask_flag = ~mask_flag
        if dtype == torch.float16:
            mask_value = torch.finfo(torch.float32).min
        else:
            mask_value = 1
        self.mask = (
            torch.masked_fill(
                torch.zeros(size=(max_seq_len, max_seq_len)), mask_flag, mask_value
            )
            .to(dtype)
            .to(self.device)
        )
        self.mask_len = max_seq_len

    def __init__(self, model_runner: ModelRunner):
        super().__init__()
        self.forward_metadata = ForwardMetadata()
        self.device = model_runner.device
        self.gen_attention_mask(128, model_runner.dtype)
        self.page_size = model_runner.page_size

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Init the metadata for a forward pass."""
        # Keep one KV-cache slot per page, then convert slot indices to page indices.
        self.forward_metadata.block_tables = (
            forward_batch.req_to_token_pool.req_to_token[
                forward_batch.req_pool_indices, : forward_batch.seq_lens.max()
            ][:, :: self.page_size]
            // self.page_size
        )
        if forward_batch.extend_seq_lens is not None:
            self.forward_metadata.extend_seq_lens_cpu_int = (
                forward_batch.extend_seq_lens.cpu().int()
            )
        self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()

    def forward_extend(
        self,
        q,
        k,
        v,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        """Prefill/extend attention via the NPU qlens flash-attention kernel."""
        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )

        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
        v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)

        query = q.view(-1, layer.tp_q_head_num * layer.qk_head_dim)
        output = torch.empty(
            (query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
            dtype=query.dtype,
            device=query.device,
        )

        torch_npu._npu_flash_attention_qlens(
            query=query,
            key_cache=k_cache,
            value_cache=v_cache,
            mask=self.mask,
            block_table=self.forward_metadata.block_tables,
            seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
            context_lens=self.forward_metadata.seq_lens_cpu_int,
            scale_value=layer.scaling,
            num_heads=layer.tp_q_head_num,
            num_kv_heads=layer.tp_k_head_num,
            out=output,
        )
        return output

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        """Decode attention via the NPU paged-attention kernel."""
        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )

        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
        v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)

        query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
        num_tokens = query.shape[0]
        output = torch.empty(
            (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
            dtype=query.dtype,
            device=query.device,
        )

        torch_npu._npu_paged_attention(
            query=query,
            key_cache=k_cache,
            value_cache=v_cache,
            num_heads=layer.tp_q_head_num,
            num_kv_heads=layer.tp_k_head_num,
            scale_value=layer.scaling,
            block_table=self.forward_metadata.block_tables,
            context_lens=self.forward_metadata.seq_lens_cpu_int,
            out=output,
        )
        return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
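A toy illustration of the block-table derivation in `init_forward_metadata`, using made-up shapes and a fake `req_to_token` pool (plain `torch`, no NPU needed): one KV-cache slot is kept per page, and integer division by `page_size` turns slot indices into page indices.

```python
import torch

page_size = 16
max_seq_len = 48  # longest sequence in this toy batch
num_reqs = 4      # rows in the fake req_to_token pool

# Fake pool: row r maps request r's token positions to global KV-cache slot indices.
req_to_token = torch.arange(num_reqs * max_seq_len).reshape(num_reqs, max_seq_len)
req_pool_indices = torch.tensor([0, 2])  # two active requests

# Same indexing as init_forward_metadata: one slot per page, then slot -> page index.
block_tables = (
    req_to_token[req_pool_indices, :max_seq_len][:, ::page_size] // page_size
)
print(block_tables)
# tensor([[0, 1, 2],
#         [6, 7, 8]])
```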
6 changes: 5 additions & 1 deletion python/sglang/srt/managers/schedule_batch.py
@@ -1672,6 +1672,7 @@ def get_model_worker_batch(
            )
            or global_server_args_dict["attention_backend"] == "flashmla"
            or global_server_args_dict["attention_backend"] == "cutlass_mla"
            or global_server_args_dict["attention_backend"] == "ascend"
            or global_server_args_dict["enable_two_batch_overlap"]
        ):
            seq_lens_cpu = (
@@ -1874,7 +1875,10 @@ def get_last_loc(
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    if global_server_args_dict["attention_backend"] != "torch_native":
    if (
        global_server_args_dict["attention_backend"] != "ascend"
        and global_server_args_dict["attention_backend"] != "torch_native"
    ):
        impl = get_last_loc_triton
    else:
        impl = get_last_loc_torch