diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index bb4ef3ad0b56..2b4d9f65e990 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -49,7 +49,7 @@
     set_mla_kv_buffer_triton,
     set_mla_kv_scale_buffer_triton,
 )
-from sglang.srt.utils import is_cuda, is_float4_e2m1fn_x2, is_npu, next_power_of_2
+from sglang.srt.utils import is_cuda, is_npu, next_power_of_2
 
 if TYPE_CHECKING:
     from sglang.srt.managers.cache_controller import LayerDoneCounter
@@ -611,65 +611,24 @@ def _create_buffers(self):
                 if self.enable_custom_mem_pool
                 else nullcontext()
             ):
-                if is_float4_e2m1fn_x2(self.dtype):
-                    m = self.size + self.page_size
-                    n = self.head_num
-                    k = self.head_dim
-
-                    scale_block_size = 16
-                    self.store_dtype = torch.uint8
-                    self.k_buffer = [
-                        torch.zeros(
-                            (m, n, k // 2),
-                            dtype=self.store_dtype,
-                            device=self.device,
-                        )
-                        for _ in range(self.layer_num)
-                    ]
-                    self.v_buffer = [
-                        torch.zeros(
-                            (m, n, k // 2),
-                            dtype=self.store_dtype,
-                            device=self.device,
-                        )
-                        for _ in range(self.layer_num)
-                    ]
-
-                    self.k_scale_buffer = [
-                        torch.zeros(
-                            (m, (n * k) // scale_block_size),
-                            dtype=self.store_dtype,
-                            device=self.device,
-                        )
-                        for _ in range(self.layer_num)
-                    ]
-                    self.v_scale_buffer = [
-                        torch.zeros(
-                            (m, (n * k) // scale_block_size),
-                            dtype=self.store_dtype,
-                            device=self.device,
-                        )
-                        for _ in range(self.layer_num)
-                    ]
-                else:
-                    # [size, head_num, head_dim] for each layer
-                    # The padded slot 0 is used for writing dummy outputs from padded tokens.
-                    self.k_buffer = [
-                        torch.zeros(
-                            (self.size + self.page_size, self.head_num, self.head_dim),
-                            dtype=self.store_dtype,
-                            device=self.device,
-                        )
-                        for _ in range(self.layer_num)
-                    ]
-                    self.v_buffer = [
-                        torch.zeros(
-                            (self.size + self.page_size, self.head_num, self.head_dim),
-                            dtype=self.store_dtype,
-                            device=self.device,
-                        )
-                        for _ in range(self.layer_num)
-                    ]
+                # [size, head_num, head_dim] for each layer
+                # The padded slot 0 is used for writing dummy outputs from padded tokens.
+                self.k_buffer = [
+                    torch.zeros(
+                        (self.size + self.page_size, self.head_num, self.head_dim),
+                        dtype=self.store_dtype,
+                        device=self.device,
+                    )
+                    for _ in range(self.layer_num)
+                ]
+                self.v_buffer = [
+                    torch.zeros(
+                        (self.size + self.page_size, self.head_num, self.head_dim),
+                        dtype=self.store_dtype,
+                        device=self.device,
+                    )
+                    for _ in range(self.layer_num)
+                ]
 
         self.k_data_ptrs = torch.tensor(
             [x.data_ptr() for x in self.k_buffer],
@@ -770,22 +729,7 @@ def load_cpu_copy(self, kv_cache_cpu, indices):
     def _get_key_buffer(self, layer_id: int):
         # for internal use of referencing
         if self.store_dtype != self.dtype:
-            if is_float4_e2m1fn_x2(self.dtype):
-                cache_k_nope_fp4 = self.k_buffer[layer_id - self.start_layer].view(
-                    torch.uint8
-                )
-                cache_k_nope_fp4_sf = self.k_scale_buffer[layer_id - self.start_layer]
-
-                from sglang.srt.layers.quantization.kvfp4_tensor import (
-                    KVFP4QuantizeUtil,
-                )
-
-                cache_k_nope_fp4_dequant = KVFP4QuantizeUtil.batched_dequantize(
-                    cache_k_nope_fp4, cache_k_nope_fp4_sf
-                )
-                return cache_k_nope_fp4_dequant
-            else:
-                return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
+            return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
         return self.k_buffer[layer_id - self.start_layer]
 
     def get_key_buffer(self, layer_id: int):
@@ -799,22 +743,7 @@ def get_key_buffer(self, layer_id: int):
     def _get_value_buffer(self, layer_id: int):
         # for internal use of referencing
         if self.store_dtype != self.dtype:
-            if is_float4_e2m1fn_x2(self.dtype):
-                cache_v_nope_fp4 = self.v_buffer[layer_id - self.start_layer].view(
-                    torch.uint8
-                )
-                cache_v_nope_fp4_sf = self.v_scale_buffer[layer_id - self.start_layer]
-
-                from sglang.srt.layers.quantization.kvfp4_tensor import (
-                    KVFP4QuantizeUtil,
-                )
-
-                cache_v_nope_fp4_dequant = KVFP4QuantizeUtil.batched_dequantize(
-                    cache_v_nope_fp4, cache_v_nope_fp4_sf
-                )
-                return cache_v_nope_fp4_dequant
-            else:
-                return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
+            return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
         return self.v_buffer[layer_id - self.start_layer]
 
     def get_value_buffer(self, layer_id: int):
@@ -846,44 +775,24 @@ def set_kv_buffer(
                 cache_k.div_(k_scale)
             if v_scale is not None:
                 cache_v.div_(v_scale)
-            if is_float4_e2m1fn_x2(self.dtype):
-                from sglang.srt.layers.quantization.kvfp4_tensor import (
-                    KVFP4QuantizeUtil,
-                )
-
-                cache_k, cache_k_fp4_sf = KVFP4QuantizeUtil.batched_quantize(cache_k)
-                cache_v, cache_v_fp4_sf = KVFP4QuantizeUtil.batched_quantize(cache_v)
-            else:
-                cache_k = cache_k.to(self.dtype)
-                cache_v = cache_v.to(self.dtype)
+            cache_k = cache_k.to(self.dtype)
+            cache_v = cache_v.to(self.dtype)
 
         if self.store_dtype != self.dtype:
             cache_k = cache_k.view(self.store_dtype)
             cache_v = cache_v.view(self.store_dtype)
-            if is_float4_e2m1fn_x2(self.dtype):
-                cache_k_fp4_sf = cache_k_fp4_sf.view(self.store_dtype)
-                cache_v_fp4_sf = cache_v_fp4_sf.view(self.store_dtype)
 
         if get_is_capture_mode() and self.alt_stream is not None:
             # Overlap the copy of K and V cache for small batch size
             current_stream = self.device_module.current_stream()
             self.alt_stream.wait_stream(current_stream)
             self.k_buffer[layer_id - self.start_layer][loc] = cache_k
-            if is_float4_e2m1fn_x2(self.dtype):
-                self.k_scale_buffer[layer_id - self.start_layer][loc] = cache_k_fp4_sf
             with self.device_module.stream(self.alt_stream):
                 self.v_buffer[layer_id - self.start_layer][loc] = cache_v
-                if is_float4_e2m1fn_x2(self.dtype):
-                    self.v_scale_buffer[layer_id - self.start_layer][
-                        loc
-                    ] = cache_v_fp4_sf
             current_stream.wait_stream(self.alt_stream)
         else:
             self.k_buffer[layer_id - self.start_layer][loc] = cache_k
             self.v_buffer[layer_id - self.start_layer][loc] = cache_v
-            if is_float4_e2m1fn_x2(self.dtype):
-                self.k_scale_buffer[layer_id - self.start_layer][loc] = cache_k_fp4_sf
-                self.v_scale_buffer[layer_id - self.start_layer][loc] = cache_v_fp4_sf
 
     def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor):
         N = tgt_loc.numel()
@@ -911,6 +820,149 @@ def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor):
         )
+
+
+class MHATokenToKVPoolFP4(MHATokenToKVPool):
+    """MHA KV pool that stores K/V as packed FP4 (e2m1x2) with per-block scale factors."""
+
+    def _create_buffers(self):
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
+            with (
+                torch.cuda.use_mem_pool(self.custom_mem_pool)
+                if self.enable_custom_mem_pool
+                else nullcontext()
+            ):
+                # [size, head_num, head_dim] for each layer
+                # The padded slot 0 is used for writing dummy outputs from padded tokens.
+                m = self.size + self.page_size
+                n = self.head_num
+                k = self.head_dim
+
+                scale_block_size = 16
+                self.store_dtype = torch.uint8
+                self.k_buffer = [
+                    torch.zeros(
+                        (m, n, k // 2),
+                        dtype=self.store_dtype,
+                        device=self.device,
+                    )
+                    for _ in range(self.layer_num)
+                ]
+                self.v_buffer = [
+                    torch.zeros(
+                        (m, n, k // 2),
+                        dtype=self.store_dtype,
+                        device=self.device,
+                    )
+                    for _ in range(self.layer_num)
+                ]
+
+                self.k_scale_buffer = [
+                    torch.zeros(
+                        (m, (n * k) // scale_block_size),
+                        dtype=self.store_dtype,
+                        device=self.device,
+                    )
+                    for _ in range(self.layer_num)
+                ]
+                self.v_scale_buffer = [
+                    torch.zeros(
+                        (m, (n * k) // scale_block_size),
+                        dtype=self.store_dtype,
+                        device=self.device,
+                    )
+                    for _ in range(self.layer_num)
+                ]
+
+    def _clear_buffers(self):
+        del self.k_buffer
+        del self.v_buffer
+        del self.k_scale_buffer
+        del self.v_scale_buffer
+
+    def _get_key_buffer(self, layer_id: int):
+        # for internal use of referencing
+        if self.store_dtype != self.dtype:
+            cache_k_nope_fp4 = self.k_buffer[layer_id - self.start_layer].view(
+                torch.uint8
+            )
+            cache_k_nope_fp4_sf = self.k_scale_buffer[layer_id - self.start_layer]
+
+            from sglang.srt.layers.quantization.kvfp4_tensor import KVFP4QuantizeUtil
+
+            cache_k_nope_fp4_dequant = KVFP4QuantizeUtil.batched_dequantize(
+                cache_k_nope_fp4, cache_k_nope_fp4_sf
+            )
+            return cache_k_nope_fp4_dequant
+        return self.k_buffer[layer_id - self.start_layer]
+
+    def _get_value_buffer(self, layer_id: int):
+        # for internal use of referencing
+        if self.store_dtype != self.dtype:
+            cache_v_nope_fp4 = self.v_buffer[layer_id - self.start_layer].view(
+                torch.uint8
+            )
+            cache_v_nope_fp4_sf = self.v_scale_buffer[layer_id - self.start_layer]
+
+            from sglang.srt.layers.quantization.kvfp4_tensor import KVFP4QuantizeUtil
+
+            cache_v_nope_fp4_dequant = KVFP4QuantizeUtil.batched_dequantize(
+                cache_v_nope_fp4, cache_v_nope_fp4_sf
+            )
+            return cache_v_nope_fp4_dequant
+        return self.v_buffer[layer_id - self.start_layer]
+
+    def set_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+        k_scale: Optional[float] = None,
+        v_scale: Optional[float] = None,
+        layer_id_override: Optional[int] = None,
+    ):
+        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+
+        if layer_id_override is not None:
+            layer_id = layer_id_override
+        else:
+            layer_id = layer.layer_id
+        if cache_k.dtype != self.dtype:
+            if k_scale is not None:
+                cache_k.div_(k_scale)
+            if v_scale is not None:
+                cache_v.div_(v_scale)
+
+            from sglang.srt.layers.quantization.kvfp4_tensor import KVFP4QuantizeUtil
+
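+            # Judging by the buffer shapes in _create_buffers above, each uint8
+            # in k_buffer/v_buffer packs two e2m1 values (hence the head_dim // 2
+            # payload width), and batched_quantize also returns the per-16-element
+            # scale factors that are written to k/v_scale_buffer below.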
+            cache_k, cache_k_fp4_sf = KVFP4QuantizeUtil.batched_quantize(cache_k)
+            cache_v, cache_v_fp4_sf = KVFP4QuantizeUtil.batched_quantize(cache_v)
+
+        if self.store_dtype != self.dtype:
+            cache_k = cache_k.view(self.store_dtype)
+            cache_v = cache_v.view(self.store_dtype)
+
+            cache_k_fp4_sf = cache_k_fp4_sf.view(self.store_dtype)
+            cache_v_fp4_sf = cache_v_fp4_sf.view(self.store_dtype)
+
+        if get_is_capture_mode() and self.alt_stream is not None:
+            # Overlap the copy of K and V cache for small batch size
+            current_stream = self.device_module.current_stream()
+            self.alt_stream.wait_stream(current_stream)
+            self.k_buffer[layer_id - self.start_layer][loc] = cache_k
+
+            self.k_scale_buffer[layer_id - self.start_layer][loc] = cache_k_fp4_sf
+            with self.device_module.stream(self.alt_stream):
+                self.v_buffer[layer_id - self.start_layer][loc] = cache_v
+
+                self.v_scale_buffer[layer_id - self.start_layer][loc] = cache_v_fp4_sf
+            current_stream.wait_stream(self.alt_stream)
+        else:
+            self.k_buffer[layer_id - self.start_layer][loc] = cache_k
+            self.v_buffer[layer_id - self.start_layer][loc] = cache_v
+
+            self.k_scale_buffer[layer_id - self.start_layer][loc] = cache_k_fp4_sf
+            self.v_scale_buffer[layer_id - self.start_layer][loc] = cache_v_fp4_sf
+
+
 class HybridLinearKVPool(KVCache):
     """KV cache with separate pools for full and linear attention layers."""
 
@@ -1362,47 +1414,7 @@ def __init__(
             else (kv_lora_rank + qk_rope_head_dim)
         )
 
-        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
-            with (
-                torch.cuda.use_mem_pool(self.custom_mem_pool)
-                if self.custom_mem_pool
-                else nullcontext()
-            ):
-                if is_float4_e2m1fn_x2(self.dtype):
-                    m = size + page_size
-                    n = 1  # head_num
-                    k = self.kv_cache_dim  # head_dim
-
-                    scale_block_size = 16
-                    self.store_dtype = torch.uint8
-
-                    self.kv_buffer = [
-                        torch.zeros(
-                            (m, n, k // 2),
-                            dtype=self.store_dtype,
-                            device=device,
-                        )
-                        for _ in range(layer_num)
-                    ]
-
-                    self.kv_scale_buffer = [
-                        torch.zeros(
-                            (m, k // scale_block_size),
-                            dtype=self.store_dtype,
-                            device=device,
-                        )
-                        for _ in range(layer_num)
-                    ]
-                else:
-                    # The padded slot 0 is used for writing dummy outputs from padded tokens.
-                    self.kv_buffer = [
-                        torch.zeros(
-                            (size + page_size, 1, self.kv_cache_dim),
-                            dtype=self.store_dtype,
-                            device=device,
-                        )
-                        for _ in range(layer_num)
-                    ]
+        self._create_buffers()
 
         self.data_ptrs = torch.tensor(
             [x.data_ptr() for x in self.kv_buffer],
@@ -1413,6 +1425,26 @@ def __init__(
         # NSA will allocate indexer KV cache later and then log the total size
        self._finalize_allocation_log(size)
 
+    def _create_buffers(self):
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
+            with (
+                torch.cuda.use_mem_pool(self.custom_mem_pool)
+                if self.custom_mem_pool
+                else nullcontext()
+            ):
+                # The padded slot 0 is used for writing dummy outputs from padded tokens.
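+                # One latent slot per token: kv_cache_dim defaults to
+                # kv_lora_rank + qk_rope_head_dim (see __init__ above).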
+                self.kv_buffer = [
+                    torch.zeros(
+                        (self.size + self.page_size, 1, self.kv_cache_dim),
+                        dtype=self.store_dtype,
+                        device=self.device,
+                    )
+                    for _ in range(self.layer_num)
+                ]
+
+    def _clear_buffers(self):
+        del self.kv_buffer
+
     def get_kv_size_bytes(self):
         assert hasattr(self, "kv_buffer")
         kv_size_bytes = 0
@@ -1435,22 +1467,7 @@ def get_key_buffer(self, layer_id: int):
             self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
 
         if self.store_dtype != self.dtype:
-            if is_float4_e2m1fn_x2(self.dtype):
-                cache_k_nope_fp4 = self.kv_buffer[layer_id - self.start_layer].view(
-                    torch.uint8
-                )
-                cache_k_nope_fp4_sf = self.kv_scale_buffer[layer_id - self.start_layer]
-
-                from sglang.srt.layers.quantization.kvfp4_tensor import (
-                    KVFP4QuantizeUtil,
-                )
-
-                cache_k_nope_fp4_dequant = KVFP4QuantizeUtil.batched_dequantize(
-                    cache_k_nope_fp4, cache_k_nope_fp4_sf
-                )
-                return cache_k_nope_fp4_dequant
-            else:
-                return self.kv_buffer[layer_id - self.start_layer].view(self.dtype)
+            return self.kv_buffer[layer_id - self.start_layer].view(self.dtype)
 
         return self.kv_buffer[layer_id - self.start_layer]
 
@@ -1477,29 +1494,12 @@ def set_kv_buffer(
         layer_id = layer.layer_id
         assert not (self.use_nsa and self.nsa_kv_cache_store_fp8)
         if cache_k.dtype != self.dtype:
-            if is_float4_e2m1fn_x2(self.dtype):
-                from sglang.srt.layers.quantization.kvfp4_tensor import (
-                    KVFP4QuantizeUtil,
-                )
-
-                cache_k_fp4, cache_k_fp4_sf = KVFP4QuantizeUtil.batched_quantize(
-                    cache_k
-                )
-            else:
-                cache_k = cache_k.to(self.dtype)
+            cache_k = cache_k.to(self.dtype)
 
         if self.store_dtype != self.dtype:
-            if is_float4_e2m1fn_x2(self.dtype):
-                self.kv_buffer[layer_id - self.start_layer][loc] = cache_k_fp4.view(
-                    self.store_dtype
-                )
-                self.kv_scale_buffer[layer_id - self.start_layer][loc] = (
-                    cache_k_fp4_sf.view(self.store_dtype)
-                )
-            else:
-                self.kv_buffer[layer_id - self.start_layer][loc] = cache_k.view(
-                    self.store_dtype
-                )
+            self.kv_buffer[layer_id - self.start_layer][loc] = cache_k.view(
+                self.store_dtype
+            )
         else:
             self.kv_buffer[layer_id - self.start_layer][loc] = cache_k
 
@@ -1521,44 +1521,18 @@ def set_mla_kv_buffer(
             self.kv_buffer[layer_id - self.start_layer][loc] = cache_k
         else:
             if cache_k_nope.dtype != self.dtype:
-                if is_float4_e2m1fn_x2(self.dtype):
-                    from sglang.srt.layers.quantization.kvfp4_tensor import (
-                        KVFP4QuantizeUtil,
-                    )
-
-                    cache_k_nope_fp4, cache_k_nope_fp4_sf = (
-                        KVFP4QuantizeUtil.batched_quantize(cache_k_nope)
-                    )
-                    cache_k_rope_fp4, cache_k_rope_fp4_sf = (
-                        KVFP4QuantizeUtil.batched_quantize(cache_k_rope)
-                    )
-                else:
-                    cache_k_nope = cache_k_nope.to(self.dtype)
-                    cache_k_rope = cache_k_rope.to(self.dtype)
+                cache_k_nope = cache_k_nope.to(self.dtype)
+                cache_k_rope = cache_k_rope.to(self.dtype)
 
             if self.store_dtype != self.dtype:
                 cache_k_nope = cache_k_nope.view(self.store_dtype)
                 cache_k_rope = cache_k_rope.view(self.store_dtype)
 
-            if is_float4_e2m1fn_x2(self.dtype):
-                set_mla_kv_buffer_triton(
-                    self.kv_buffer[layer_id - self.start_layer],
-                    loc,
-                    cache_k_nope_fp4,
-                    cache_k_rope_fp4,
-                )
-                set_mla_kv_scale_buffer_triton(
-                    self.kv_scale_buffer[layer_id - self.start_layer],
-                    loc,
-                    cache_k_nope_fp4_sf,
-                    cache_k_rope_fp4_sf,
-                )
-            else:
-                set_mla_kv_buffer_triton(
-                    self.kv_buffer[layer_id - self.start_layer],
-                    loc,
-                    cache_k_nope,
-                    cache_k_rope,
-                )
+            set_mla_kv_buffer_triton(
+                self.kv_buffer[layer_id - self.start_layer],
+                loc,
+                cache_k_nope,
+                cache_k_rope,
+            )
 
     def get_mla_kv_buffer(
         self,
@@ -1611,6 +1585,135 @@ def load_cpu_copy(self, kv_cache_cpu, indices):
         torch.cuda.synchronize()
 
+
+class MLATokenToKVPoolFP4(MLATokenToKVPool):
+    """MLA KV pool that stores the latent KV cache as packed FP4 (e2m1x2) with per-block scale factors."""
+
+    def _create_buffers(self):
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
+            with (
+                torch.cuda.use_mem_pool(self.custom_mem_pool)
+                if self.custom_mem_pool
+                else nullcontext()
+            ):
+                # The padded slot 0 is used for writing dummy outputs from padded tokens.
+                m = self.size + self.page_size
+                n = 1  # head_num
+                k = self.kv_cache_dim  # head_dim
+
+                scale_block_size = 16
+                self.store_dtype = torch.uint8
+
+                self.kv_buffer = [
+                    torch.zeros(
+                        (m, n, k // 2),
+                        dtype=self.store_dtype,
+                        device=self.device,
+                    )
+                    for _ in range(self.layer_num)
+                ]
+
+                self.kv_scale_buffer = [
+                    torch.zeros(
+                        (m, k // scale_block_size),
+                        dtype=self.store_dtype,
+                        device=self.device,
+                    )
+                    for _ in range(self.layer_num)
+                ]
+
+    def _clear_buffers(self):
+        del self.kv_buffer
+        del self.kv_scale_buffer
+
+    def get_key_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+
+        if self.store_dtype != self.dtype:
+            cache_k_nope_fp4 = self.kv_buffer[layer_id - self.start_layer].view(
+                torch.uint8
+            )
+            cache_k_nope_fp4_sf = self.kv_scale_buffer[layer_id - self.start_layer]
+
+            from sglang.srt.layers.quantization.kvfp4_tensor import KVFP4QuantizeUtil
+
+            cache_k_nope_fp4_dequant = KVFP4QuantizeUtil.batched_dequantize(
+                cache_k_nope_fp4, cache_k_nope_fp4_sf
+            )
+            return cache_k_nope_fp4_dequant
+
+        return self.kv_buffer[layer_id - self.start_layer]
+
+    def set_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ):
+        layer_id = layer.layer_id
+        assert not (self.use_nsa and self.nsa_kv_cache_store_fp8)
+        if cache_k.dtype != self.dtype:
+            from sglang.srt.layers.quantization.kvfp4_tensor import KVFP4QuantizeUtil
+
+            cache_k_fp4, cache_k_fp4_sf = KVFP4QuantizeUtil.batched_quantize(cache_k)
+
+        if self.store_dtype != self.dtype:
+            self.kv_buffer[layer_id - self.start_layer][loc] = cache_k_fp4.view(
+                self.store_dtype
+            )
+            self.kv_scale_buffer[layer_id - self.start_layer][loc] = (
+                cache_k_fp4_sf.view(self.store_dtype)
+            )
+        else:
+            self.kv_buffer[layer_id - self.start_layer][loc] = cache_k
+
+    def set_mla_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k_nope: torch.Tensor,
+        cache_k_rope: torch.Tensor,
+    ):
+        layer_id = layer.layer_id
+
+        if self.use_nsa and self.nsa_kv_cache_store_fp8:
+            # original cache_k: (num_tokens, num_heads 1, hidden 576); we unsqueeze the page_size=1 dim here
+            # TODO no need to cat
+            cache_k = torch.cat([cache_k_nope, cache_k_rope], dim=-1)
+            cache_k = quantize_k_cache(cache_k.unsqueeze(1)).squeeze(1)
+            cache_k = cache_k.view(self.store_dtype)
+            self.kv_buffer[layer_id - self.start_layer][loc] = cache_k
+        else:
+            if cache_k_nope.dtype != self.dtype:
+                from sglang.srt.layers.quantization.kvfp4_tensor import (
+                    KVFP4QuantizeUtil,
+                )
+
+                cache_k_nope_fp4, cache_k_nope_fp4_sf = (
+                    KVFP4QuantizeUtil.batched_quantize(cache_k_nope)
+                )
+                cache_k_rope_fp4, cache_k_rope_fp4_sf = (
+                    KVFP4QuantizeUtil.batched_quantize(cache_k_rope)
+                )
+
+            if self.store_dtype != self.dtype:
+                cache_k_nope = cache_k_nope.view(self.store_dtype)
+                cache_k_rope = cache_k_rope.view(self.store_dtype)
+
+            set_mla_kv_buffer_triton(
+                self.kv_buffer[layer_id - self.start_layer],
+                loc,
+                cache_k_nope_fp4,
+                cache_k_rope_fp4,
+            )
+            set_mla_kv_scale_buffer_triton(
+                self.kv_scale_buffer[layer_id - self.start_layer],
+                loc,
+                cache_k_nope_fp4_sf,
+                cache_k_rope_fp4_sf,
+            )
+
+
 class NSATokenToKVPool(MLATokenToKVPool):
     quant_block_size = 128
     index_k_with_scale_buffer_dtype = torch.uint8
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index c8dc396eca9c..b1dabf3fce50 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -111,7 +111,9 @@
     HybridLinearKVPool,
     HybridReqToTokenPool,
     MHATokenToKVPool,
+    MHATokenToKVPoolFP4,
     MLATokenToKVPool,
+    MLATokenToKVPoolFP4,
     NSATokenToKVPool,
     ReqToTokenPool,
     SWAKVPool,
@@ -1802,18 +1804,32 @@ def init_memory_pool(
             )
         elif self.use_mla_backend and not self.mambaish_config:
             assert not is_nsa_model
-            self.token_to_kv_pool = MLATokenToKVPool(
-                self.max_total_num_tokens,
-                page_size=self.page_size,
-                dtype=self.kv_cache_dtype,
-                kv_lora_rank=self.model_config.kv_lora_rank,
-                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
-                layer_num=self.num_effective_layers,
-                device=self.device,
-                enable_memory_saver=self.server_args.enable_memory_saver,
-                start_layer=self.start_layer,
-                end_layer=self.end_layer,
-            )
+            if is_float4_e2m1fn_x2(self.kv_cache_dtype):
+                self.token_to_kv_pool = MLATokenToKVPoolFP4(
+                    self.max_total_num_tokens,
+                    page_size=self.page_size,
+                    dtype=self.kv_cache_dtype,
+                    kv_lora_rank=self.model_config.kv_lora_rank,
+                    qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+                    layer_num=self.num_effective_layers,
+                    device=self.device,
+                    enable_memory_saver=self.server_args.enable_memory_saver,
+                    start_layer=self.start_layer,
+                    end_layer=self.end_layer,
+                )
+            else:
+                self.token_to_kv_pool = MLATokenToKVPool(
+                    self.max_total_num_tokens,
+                    page_size=self.page_size,
+                    dtype=self.kv_cache_dtype,
+                    kv_lora_rank=self.model_config.kv_lora_rank,
+                    qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+                    layer_num=self.num_effective_layers,
+                    device=self.device,
+                    enable_memory_saver=self.server_args.enable_memory_saver,
+                    start_layer=self.start_layer,
+                    end_layer=self.end_layer,
+                )
         elif self.server_args.enable_double_sparsity:
             self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                 self.max_total_num_tokens,
@@ -1870,24 +1886,44 @@ def init_memory_pool(
                 **extra_args,
             )
         else:
-            self.token_to_kv_pool = MHATokenToKVPool(
-                self.max_total_num_tokens,
-                page_size=self.page_size,
-                dtype=self.kv_cache_dtype,
-                head_num=self.model_config.get_num_kv_heads(
-                    get_attention_tp_size()
-                ),
-                head_dim=self.model_config.head_dim,
-                layer_num=self.num_effective_layers,
-                device=self.device,
-                enable_memory_saver=self.server_args.enable_memory_saver,
-                start_layer=self.start_layer,
-                end_layer=self.end_layer,
-                enable_alt_stream=not self.server_args.enable_pdmux,
-                enable_kv_cache_copy=(
-                    self.server_args.speculative_algorithm is not None
-                ),
-            )
+            if is_float4_e2m1fn_x2(self.kv_cache_dtype):
+                self.token_to_kv_pool = MHATokenToKVPoolFP4(
+                    self.max_total_num_tokens,
+                    page_size=self.page_size,
+                    dtype=self.kv_cache_dtype,
+                    head_num=self.model_config.get_num_kv_heads(
+                        get_attention_tp_size()
+                    ),
+                    head_dim=self.model_config.head_dim,
+                    layer_num=self.num_effective_layers,
+                    device=self.device,
+                    enable_memory_saver=self.server_args.enable_memory_saver,
+                    start_layer=self.start_layer,
+                    end_layer=self.end_layer,
+                    enable_alt_stream=not self.server_args.enable_pdmux,
+                    enable_kv_cache_copy=(
+                        self.server_args.speculative_algorithm is not None
+                    ),
+                )
+            else:
+                self.token_to_kv_pool = MHATokenToKVPool(
+                    self.max_total_num_tokens,
+                    page_size=self.page_size,
+                    dtype=self.kv_cache_dtype,
+                    head_num=self.model_config.get_num_kv_heads(
+                        get_attention_tp_size()
+                    ),
+                    head_dim=self.model_config.head_dim,
+                    layer_num=self.num_effective_layers,
+                    device=self.device,
+                    enable_memory_saver=self.server_args.enable_memory_saver,
+                    start_layer=self.start_layer,
+                    end_layer=self.end_layer,
+                    enable_alt_stream=not self.server_args.enable_pdmux,
+                    enable_kv_cache_copy=(
+                        self.server_args.speculative_algorithm is not None
+                    ),
+                )
 
         # Initialize token_to_kv_pool_allocator
         need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")