From 130930ca5fdc67b8b43dc7492f3ddc96f6463eb4 Mon Sep 17 00:00:00 2001 From: thomawan Date: Sat, 4 Apr 2026 12:42:19 +0800 Subject: [PATCH 1/2] Fix mi300 quant code path --- python/sglang/srt/mem_cache/memory_pool.py | 61 +++++++++++----------- 1 file changed, 31 insertions(+), 30 deletions(-) diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index e4c158cda9f6..8cbff023179b 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -1575,37 +1575,38 @@ def set_mla_kv_buffer( ): layer_id = layer.layer_id - if self.nsa_kv_cache_store_fp8: - if _is_hip: - # HIP FP8 path uses raw MLA KV layout (nope + rope) without per-block scales. - # Fuse BF16/FP16 -> FP8 cast with paged KV write. - fp8_dtype = ( - torch.float8_e4m3fnuz if _is_fp8_fnuz else torch.float8_e4m3fn - ) - set_mla_kv_buffer_triton_fp8_quant( - self.kv_buffer[layer_id - self.start_layer], - loc, - cache_k_nope, - cache_k_rope, - fp8_dtype, - ) - else: - # OPTIMIZATION: Quantize k_nope and k_rope separately to avoid concat overhead - # This also enables reuse of set_mla_kv_buffer_triton two-tensor write path - # quantize_k_cache_separate returns (nope_part, rope_part) as uint8 bytes - cache_k_nope_fp8, cache_k_rope_fp8 = quantize_k_cache_separate( - cache_k_nope, cache_k_rope - ) + if ( + _is_hip + and self.use_nsa + and self.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz) + ): + # HIP FP8 path uses raw MLA KV layout (nope + rope) without per-block scales. + # Fuse BF16/FP16 -> FP8 cast with paged KV write. + fp8_dtype = torch.float8_e4m3fnuz if _is_fp8_fnuz else torch.float8_e4m3fn + set_mla_kv_buffer_triton_fp8_quant( + self.kv_buffer[layer_id - self.start_layer], + loc, + cache_k_nope, + cache_k_rope, + fp8_dtype, + ) + elif self.nsa_kv_cache_store_fp8: + # OPTIMIZATION: Quantize k_nope and k_rope separately to avoid concat overhead + # This also enables reuse of set_mla_kv_buffer_triton two-tensor write path + # quantize_k_cache_separate returns (nope_part, rope_part) as uint8 bytes + cache_k_nope_fp8, cache_k_rope_fp8 = quantize_k_cache_separate( + cache_k_nope, cache_k_rope + ) - # Reuse existing two-tensor write kernel (works with FP8 byte layout) - # cache_k_nope_fp8: (num_tokens, 1, 528) uint8 [nope_fp8(512) | scales(16)] - # cache_k_rope_fp8: (num_tokens, 1, 128) uint8 [rope_bf16_bytes(128)] - set_mla_kv_buffer_triton( - self.kv_buffer[layer_id - self.start_layer], - loc, - cache_k_nope_fp8, - cache_k_rope_fp8, - ) + # Reuse existing two-tensor write kernel (works with FP8 byte layout) + # cache_k_nope_fp8: (num_tokens, 1, 528) uint8 [nope_fp8(512) | scales(16)] + # cache_k_rope_fp8: (num_tokens, 1, 128) uint8 [rope_bf16_bytes(128)] + set_mla_kv_buffer_triton( + self.kv_buffer[layer_id - self.start_layer], + loc, + cache_k_nope_fp8, + cache_k_rope_fp8, + ) else: if cache_k_nope.dtype != self.dtype: cache_k_nope = cache_k_nope.to(self.dtype) From e795cd06a9e4c373a57ed7c54af7d255441b07f2 Mon Sep 17 00:00:00 2001 From: thomawan Date: Wed, 8 Apr 2026 11:34:45 +0800 Subject: [PATCH 2/2] Use fp8 dtype from import --- python/sglang/srt/mem_cache/memory_pool.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 8cbff023179b..c385ec0536d0 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -45,7 +45,7 @@ quantize_k_cache, quantize_k_cache_separate, ) -from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz +from sglang.srt.layers.quantization.fp8_kernel import fp8_dtype, is_fp8_fnuz from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.mem_cache.utils import ( get_mla_kv_buffer_triton, @@ -1575,14 +1575,9 @@ def set_mla_kv_buffer( ): layer_id = layer.layer_id - if ( - _is_hip - and self.use_nsa - and self.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz) - ): + if _is_hip and self.use_nsa and self.dtype == fp8_dtype: # HIP FP8 path uses raw MLA KV layout (nope + rope) without per-block scales. # Fuse BF16/FP16 -> FP8 cast with paged KV write. - fp8_dtype = torch.float8_e4m3fnuz if _is_fp8_fnuz else torch.float8_e4m3fn set_mla_kv_buffer_triton_fp8_quant( self.kv_buffer[layer_id - self.start_layer], loc,