diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 817ae43fbed2..e772cd791b3e 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1583,15 +1583,21 @@ def _set_default_nsa_backends(self, kv_cache_dtype: str, major: int) -> str: user_set_prefill = self.nsa_prefill_backend is not None user_set_decode = self.nsa_decode_backend is not None - # HiSparse requires flashmla_sparse for both prefill and decode + # HiSparse: BF16 KV -> flashmla_sparse (native BF16 sparse). + # FP8 KV -> flashmla_kv (native FP8 + sparse via is_fp8_kvcache=True + indices=...). + # flashmla_sparse does not accept FP8, and flashmla_kv does not accept BF16 sparse, + # so the KV dtype determines the backend when the user does not override. if self.enable_hisparse: + hisparse_default_backend = ( + "flashmla_kv" if kv_cache_dtype == "fp8_e4m3" else "flashmla_sparse" + ) if not user_set_prefill: - self.nsa_prefill_backend = "flashmla_sparse" + self.nsa_prefill_backend = hisparse_default_backend if not user_set_decode: - self.nsa_decode_backend = "flashmla_sparse" + self.nsa_decode_backend = hisparse_default_backend logger.warning( - f"HiSparse enabled: using flashmla_sparse NSA backends " - f"(prefill={self.nsa_prefill_backend}, decode={self.nsa_decode_backend})." + f"HiSparse enabled ({kv_cache_dtype}): using NSA backends " + f"prefill={self.nsa_prefill_backend}, decode={self.nsa_decode_backend}." ) return @@ -6769,24 +6775,32 @@ def check_server_args(self): assert ( self.disable_radix_cache ), "Hierarchical sparse attention currently requires --disable-radix-cache." + if self.kv_cache_dtype not in ("bfloat16", "auto", "fp8_e4m3"): + raise ValueError( + f"HiSparse requires bfloat16 or fp8_e4m3 KV cache, " + f"but got --kv-cache-dtype={self.kv_cache_dtype}. " + f"Please use --kv-cache-dtype=bfloat16 or fp8_e4m3." + ) + + # Backend/dtype pairing: flashmla_sparse only takes BF16 KV; + # flashmla_kv only supports FP8 (it always reads KV as FP8 via + # is_fp8_kvcache=True, inline-quantizing BF16 would defeat HiSparse). + allowed_backends_for_dtype = { + "bfloat16": {"flashmla_sparse"}, + "fp8_e4m3": {"flashmla_kv"}, + }.get(self.kv_cache_dtype, {"flashmla_sparse", "flashmla_kv"}) for attr, label in [ ("nsa_prefill_backend", "prefill"), ("nsa_decode_backend", "decode"), ]: backend = getattr(self, attr) - if backend is not None and backend != "flashmla_sparse": + if backend is not None and backend not in allowed_backends_for_dtype: raise ValueError( - f"HiSparse requires flashmla_sparse NSA {label} backend, " - f"but got --nsa-{label}-backend={backend}. " - f"Please use --nsa-{label}-backend=flashmla_sparse or omit it." + f"HiSparse with --kv-cache-dtype={self.kv_cache_dtype} requires " + f"--nsa-{label}-backend in {sorted(allowed_backends_for_dtype)}, " + f"but got {backend}." ) - if self.kv_cache_dtype != "bfloat16": - raise ValueError( - f"HiSparse requires bfloat16 KV cache, but got --kv-cache-dtype={self.kv_cache_dtype}. " - f"Please use --kv-cache-dtype=bfloat16." - ) - assert ( self.schedule_conservativeness >= 0 ), "schedule_conservativeness must be non-negative"