Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/diffusion/performance/attention_backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ When using the diffusers backend, `--attention-backend` is passed through to dif
- **CUDA**: prefers FlashAttention (FA3/FA4) when supported; otherwise falls back to PyTorch SDPA.
- **ROCm**: uses FlashAttention when available; otherwise falls back to PyTorch SDPA.
- **MPS**: always uses PyTorch SDPA.
- **NPU**: always uses PyTorch SDPA.
- **NPU**: uses FlashAttention for ring attention; otherwise falls back to PyTorch SDPA.

## Backend options

Expand Down Expand Up @@ -87,7 +87,7 @@ Some backends require additional configuration. You can pass these parameters vi

| Backend | CUDA | ROCm | MPS | NPU | Notes |
|---|---:|---:|---:|---:|---|
| `fa` | ✅ | ✅ | ❌ | | CUDA requires SM80+ and fp16/bf16. FlashAttention is only used when the required runtime is installed; otherwise it falls back to `torch_sdpa`. |
| `fa` | ✅ | ✅ | ❌ | ✅ | CUDA requires SM80+ and fp16/bf16. FlashAttention is only used when the required runtime is installed; otherwise it falls back to `torch_sdpa`. No extra installation is required on NPU. |
| `torch_sdpa` | ✅ | ✅ | ✅ | ✅ | Most compatible option across platforms. |
| `sliding_tile_attn` | ✅ | ❌ | ❌ | ❌ | CUDA-only. Requires `st_attn`. Configure via `--attention-backend-config`. |
| `sage_attn` | ✅ | ❌ | ❌ | ❌ | CUDA-only (optional dependency). |
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from dataclasses import dataclass
from typing import Any

import torch

from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (
AttentionBackend,
AttentionImpl,
AttentionMetadata,
AttentionMetadataBuilder,
)
from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger

logger = init_logger(__name__)


@dataclass
class AscendFAMetadata:
    """Attention metadata for the Ascend FlashAttention backend.

    The fused NPU kernel needs no precomputed per-step state, so this
    dataclass intentionally carries no fields.
    """


class AscendFAMetadataBuilder(AttentionMetadataBuilder):
    """Builder producing :class:`AscendFAMetadata` instances.

    The Ascend backend requires no setup or per-step preparation, so
    ``__init__`` and ``prepare`` are deliberate no-ops and ``build``
    simply returns an empty metadata object.
    """

    def __init__(self) -> None:
        # Nothing to initialize; the fused NPU kernel takes no prebuilt state.
        pass

    def prepare(self) -> None:
        # No-op: no buffers or indices need refreshing between steps.
        pass

    def build(self, **kwargs: dict[str, Any]) -> AttentionMetadata:
        # All keyword arguments are accepted for interface compatibility
        # and ignored.
        return AscendFAMetadata()


class AscendFABackend(AttentionBackend):
    """FlashAttention backend for Ascend NPU devices.

    Wires the backend enum, implementation class, metadata class, and
    metadata builder together for the attention backend registry.
    """

    @staticmethod
    def get_enum() -> AttentionBackendEnum:
        return AttentionBackendEnum.FA

    @staticmethod
    def get_impl_cls() -> type["AscendFAImpl"]:
        return AscendFAImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        # Return the metadata type this backend's builder actually produces,
        # consistent with the other accessors, instead of raising
        # NotImplementedError while AscendFAMetadata exists.
        return AscendFAMetadata

    @staticmethod
    def get_builder_cls() -> type["AttentionMetadataBuilder"]:
        return AscendFAMetadataBuilder


class AscendFAImpl(AttentionImpl):
    """FlashAttention implementation backed by the Ascend fused NPU kernel.

    Wraps ``torch.ops.npu.npu_fused_infer_attention_score``. Inputs are
    assumed to be in (batch, seq_len, num_heads, head_dim) layout — inferred
    from the transposes below, TODO confirm against callers — and are
    transposed to the kernel's "BNSD" layout internally.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        causal: bool,
        softmax_scale: float,
        num_kv_heads: int | None = None,
        prefix: str = "",
        **extra_impl_args,
    ) -> None:
        # head_size, prefix, and extra_impl_args are accepted for interface
        # compatibility with other AttentionImpl subclasses but unused here.
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.num_heads = num_heads
        # Default to MHA (kv heads == query heads) when not specified.
        self.num_kv_heads = num_kv_heads or num_heads

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AttentionMetadata,
        return_softmax_lse: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run fused attention on the NPU.

        Args:
            query: query tensor, (batch, seq_len, heads, head_dim).
            key: key tensor, same layout as ``query``.
            value: value tensor, same layout as ``query``.
            attn_metadata: unused by this backend.
            return_softmax_lse: when True, also return the softmax
                log-sum-exp tensor produced by the kernel.

        Returns:
            The attention output in the input layout, or a tuple of
            (output, lse) when ``return_softmax_lse`` is True.
        """
        mask = None
        if self.causal:
            # Upper-triangular True entries mark future positions to mask.
            # Built directly as bool to avoid a float temporary + cast.
            seq_len = query.shape[1]
            mask = torch.triu(
                torch.ones(seq_len, seq_len, dtype=torch.bool, device=query.device),
                diagonal=1,
            )
        # Transpose to the kernel's BNSD layout: (batch, heads, seq, head_dim).
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)
        output, lse = torch.ops.npu.npu_fused_infer_attention_score(
            query,
            key,
            value,
            num_heads=self.num_heads,
            num_key_value_heads=self.num_kv_heads,
            scale=self.softmax_scale,
            input_layout="BNSD",
            softmax_lse_flag=return_softmax_lse,
            atten_mask=mask,
        )
        # Transpose back to the caller's (batch, seq, heads, head_dim) layout.
        output = output.transpose(1, 2)
        if return_softmax_lse:
            return output, lse
        return output
4 changes: 4 additions & 0 deletions python/sglang/multimodal_gen/runtime/platforms/npu.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ def get_attn_backend_cls_str(
head_size: int,
dtype: torch.dtype,
) -> str:
if selected_backend == AttentionBackendEnum.FA:
logger.info("Using Ascend Flash Attention backend.")
return "sglang.multimodal_gen.runtime.layers.attention.backends.ascend_fa.AscendFABackend"

logger.info("Using Torch SDPA backend.")
return (
"sglang.multimodal_gen.runtime.layers.attention.backends.sdpa.SDPABackend"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,71 @@
"expected_e2e_ms": 91733.92,
"expected_avg_denoise_ms": 2091.33,
"expected_median_denoise_ms": 2090.72
},
"qwen_image_t2i_2npu": {
"stages_ms": {
"InputValidationStage": 0.07,
"TextEncodingStage": 629.24,
"LatentPreparationStage": 0.69,
"TimestepPreparationStage": 35.29,
"DenoisingStage": 30529.83,
"DecodingStage": 74.25
},
"denoise_step_ms": {
"0": 477.43,
"1": 511.96,
"2": 607.78,
"3": 615.12,
"4": 616.29,
"5": 614.61,
"6": 623.04,
"7": 607.12,
"8": 615.32,
"9": 615.47,
"10": 616.93,
"11": 623.26,
"12": 607.12,
"13": 615.48,
"14": 615.07,
"15": 614.83,
"16": 623.18,
"17": 609.0,
"18": 614.8,
"19": 623.08,
"20": 607.64,
"21": 614.2,
"22": 615.58,
"23": 615.43,
"24": 623.59,
"25": 606.57,
"26": 616.02,
"27": 615.48,
"28": 615.76,
"29": 623.13,
"30": 608.73,
"31": 615.04,
"32": 616.08,
"33": 616.59,
"34": 623.77,
"35": 608.0,
"36": 616.1,
"37": 615.79,
"38": 615.34,
"39": 617.43,
"40": 610.99,
"41": 614.22,
"42": 623.27,
"43": 606.98,
"44": 615.87,
"45": 615.99,
"46": 614.66,
"47": 622.93,
"48": 607.97,
"49": 614.69
},
"expected_e2e_ms": 34362.34,
"expected_avg_denoise_ms": 610.41,
"expected_median_denoise_ms": 615.39
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,18 @@
),
T2I_sampling_params,
),
DiffusionTestCase(
"qwen_image_t2i_2npu",
DiffusionServerArgs(
model_path="/root/.cache/modelscope/hub/models/Qwen/Qwen-Image",
modality="image",
num_gpus=2,
# test ring attn
ulysses_degree=1,
ring_degree=2,
),
T2I_sampling_params,
),
]

EIGHT_NPU_CASES: list[DiffusionTestCase] = [
Expand Down
Loading