Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 29 additions & 15 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1583,15 +1583,21 @@ def _set_default_nsa_backends(self, kv_cache_dtype: str, major: int) -> str:
user_set_prefill = self.nsa_prefill_backend is not None
user_set_decode = self.nsa_decode_backend is not None

# HiSparse requires flashmla_sparse for both prefill and decode
# HiSparse: BF16 KV -> flashmla_sparse (native BF16 sparse).
# FP8 KV -> flashmla_kv (native FP8 + sparse via is_fp8_kvcache=True + indices=...).
# flashmla_sparse does not accept FP8, and flashmla_kv does not accept BF16 sparse,
# so the KV dtype determines the backend when the user does not override.
if self.enable_hisparse:
hisparse_default_backend = (
"flashmla_kv" if kv_cache_dtype == "fp8_e4m3" else "flashmla_sparse"
)
if not user_set_prefill:
self.nsa_prefill_backend = "flashmla_sparse"
self.nsa_prefill_backend = hisparse_default_backend
if not user_set_decode:
self.nsa_decode_backend = "flashmla_sparse"
self.nsa_decode_backend = hisparse_default_backend
logger.warning(
f"HiSparse enabled: using flashmla_sparse NSA backends "
f"(prefill={self.nsa_prefill_backend}, decode={self.nsa_decode_backend})."
f"HiSparse enabled ({kv_cache_dtype}): using NSA backends "
f"prefill={self.nsa_prefill_backend}, decode={self.nsa_decode_backend}."
)
return

Expand Down Expand Up @@ -6769,24 +6775,32 @@ def check_server_args(self):
assert (
self.disable_radix_cache
), "Hierarchical sparse attention currently requires --disable-radix-cache."
if self.kv_cache_dtype not in ("bfloat16", "auto", "fp8_e4m3"):
raise ValueError(
f"HiSparse requires bfloat16 or fp8_e4m3 KV cache, "
f"but got --kv-cache-dtype={self.kv_cache_dtype}. "
f"Please use --kv-cache-dtype=bfloat16 or fp8_e4m3."
)

# Backend/dtype pairing: flashmla_sparse only takes BF16 KV;
# flashmla_kv only supports FP8 (it always reads KV as FP8 via
# is_fp8_kvcache=True; inline-quantizing BF16 would defeat HiSparse).
allowed_backends_for_dtype = {
"bfloat16": {"flashmla_sparse"},
"fp8_e4m3": {"flashmla_kv"},
}.get(self.kv_cache_dtype, {"flashmla_sparse", "flashmla_kv"})
for attr, label in [
("nsa_prefill_backend", "prefill"),
("nsa_decode_backend", "decode"),
]:
backend = getattr(self, attr)
if backend is not None and backend != "flashmla_sparse":
if backend is not None and backend not in allowed_backends_for_dtype:
raise ValueError(
f"HiSparse requires flashmla_sparse NSA {label} backend, "
f"but got --nsa-{label}-backend={backend}. "
f"Please use --nsa-{label}-backend=flashmla_sparse or omit it."
f"HiSparse with --kv-cache-dtype={self.kv_cache_dtype} requires "
f"--nsa-{label}-backend in {sorted(allowed_backends_for_dtype)}, "
f"but got {backend}."
)

if self.kv_cache_dtype != "bfloat16":
raise ValueError(
f"HiSparse requires bfloat16 KV cache, but got --kv-cache-dtype={self.kv_cache_dtype}. "
f"Please use --kv-cache-dtype=bfloat16."
)

assert (
self.schedule_conservativeness >= 0
), "schedule_conservativeness must be non-negative"
Expand Down
Loading