diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py
index c3613fe837d..6ba68662132 100644
--- a/python/sglang/srt/layers/quantization/fp8_kernel.py
+++ b/python/sglang/srt/layers/quantization/fp8_kernel.py
@@ -58,10 +58,8 @@
     ):
         _enable_jit_deepgemm = True
 
 
-logger = logging.getLogger(__name__)
-
 
 if supports_custom_op():
 
     def deep_gemm_fp8_fp8_bf16_nt(
@@ -897,16 +895,20 @@ def _per_tensor_quant_mla_fp8_stage2(
 
 
 def per_tensor_quant_mla_fp8(
-    x: torch.Tensor, eps: float = 1e-12
+    x: torch.Tensor, x_s_out: torch.Tensor, eps: float = 1e-12
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     This function quantizes input values to float8 values with tensor-wise quantization
     and specialized for mla absorbed case.
     """
     assert x.dim() == 3, "`x` is not a 3d-tensor"
+    assert (
+        x_s_out.shape == (1,)
+        and x_s_out.dtype == torch.float32
+        and x_s_out.device == x.device
+    )
 
     x_q = x.new_empty(x.size(), dtype=_fp8_type)
-    x_s = torch.zeros((1,), dtype=torch.float32, device=x.device)
 
     num_head, num_seq, head_size = x.shape
     BLOCK_SIZE = triton.next_power_of_2(head_size)
@@ -914,7 +916,7 @@ def per_tensor_quant_mla_fp8(
 
     _per_tensor_quant_mla_fp8_stage1[grid](
         x,
-        x_s,
+        x_s_out,
         head_size,
         x.stride(0),
         x.stride(1),
@@ -924,7 +926,7 @@ def per_tensor_quant_mla_fp8(
     )
     _per_tensor_quant_mla_fp8_stage2[grid](
         x,
-        x_s,
+        x_s_out,
         x_q,
         num_seq,
         head_size,
@@ -935,7 +937,7 @@ def per_tensor_quant_mla_fp8(
         BLOCK_SIZE,
     )
 
-    return x_q, x_s
+    return x_q, x_s_out
 
 
 def scaled_fp8_quant(
diff --git a/python/sglang/srt/models/deepseek_nextn.py b/python/sglang/srt/models/deepseek_nextn.py
index 645b13d7414..b77ad0c9a35 100644
--- a/python/sglang/srt/models/deepseek_nextn.py
+++ b/python/sglang/srt/models/deepseek_nextn.py
@@ -40,7 +40,7 @@
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
-from sglang.srt.utils import add_prefix, is_cuda, is_hip
+from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda, is_hip
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
@@ -91,6 +91,12 @@ def forward(
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
+        zero_allocator = BumpAllocator(
+            buffer_size=2,
+            dtype=torch.float32,
+            device=input_ids.device,
+        )
+
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
         else:
@@ -108,7 +114,7 @@ def forward(
 
         residual = None
         hidden_states, residual = self.decoder(
-            positions, hidden_states, forward_batch, residual
+            positions, hidden_states, forward_batch, residual, zero_allocator
         )
 
         if not forward_batch.forward_mode.is_idle():
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 26073bd6769..fffa8cf150d 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -76,7 +76,7 @@
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip
+from sglang.srt.utils import BumpAllocator, DeepEPMode, add_prefix, is_cuda, is_hip
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
@@ -97,7 +97,6 @@
 
 
 class AttnForwardMethod(IntEnum):
-
     # Use multi-head attention
     MHA = auto()
 
@@ -590,6 +589,7 @@ def forward(
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         if hidden_states.shape[0] == 0:
             assert (
@@ -615,9 +615,13 @@ def forward(
                         positions, hidden_states, forward_batch
                     )
                 else:
-                    return self.forward_absorb(positions, hidden_states, forward_batch)
+                    return self.forward_absorb(
+                        positions, hidden_states, forward_batch, zero_allocator
+                    )
             else:
-                return self.forward_absorb(positions, hidden_states, forward_batch)
+                return self.forward_absorb(
+                    positions, hidden_states, forward_batch, zero_allocator
+                )
 
     def forward_normal(
         self,
@@ -666,6 +670,7 @@ def forward_absorb(
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         q_len = hidden_states.shape[0]
         q_input = hidden_states.new_empty(
@@ -690,6 +695,7 @@ def forward_absorb(
         elif self.w_kc.dtype == torch.float8_e4m3fn:
             q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                 q_nope.transpose(0, 1),
+                zero_allocator.allocate(1),
             )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -721,6 +727,7 @@ def forward_absorb(
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                 attn_output.transpose(0, 1),
+                zero_allocator.allocate(1),
             )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
@@ -741,6 +748,7 @@ def forward_absorb_fused_mla_rope(
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         enable_rope_fusion = (
             os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
@@ -767,7 +775,9 @@ def forward_absorb_fused_mla_rope(
             )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
             q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
-                q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
+                q_nope.transpose(0, 1),
+                zero_allocator.allocate(1),
+                dtype=torch.float8_e4m3fn,
             )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -863,7 +873,9 @@ def forward_absorb_fused_mla_rope(
             )
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
-                attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
+                attn_output.transpose(0, 1),
+                zero_allocator.allocate(1),
+                dtype=torch.float8_e4m3fn,
             )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
@@ -1115,14 +1127,15 @@ def forward(
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
             return self.forward_ffn_with_scattered_input(
-                positions, hidden_states, forward_batch, residual
+                positions, hidden_states, forward_batch, residual, zero_allocator
             )
         elif self.info.ffn_input_mode == _FFNInputMode.FULL:
             return self.forward_ffn_with_full_input(
-                positions, hidden_states, forward_batch, residual
+                positions, hidden_states, forward_batch, residual, zero_allocator
             )
         else:
             raise NotImplementedError
@@ -1133,6 +1146,7 @@ def forward_ffn_with_full_input(
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
 
         if hidden_states.shape[0] == 0:
@@ -1153,6 +1167,7 @@ def forward_ffn_with_full_input(
             positions=positions,
             hidden_states=hidden_states,
             forward_batch=forward_batch,
+            zero_allocator=zero_allocator,
         )
 
         # Gather
@@ -1200,6 +1215,7 @@ def forward_ffn_with_scattered_input(
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
 
         if hidden_states.shape[0] == 0:
@@ -1225,6 +1241,7 @@ def forward_ffn_with_scattered_input(
                 positions=positions,
                 hidden_states=hidden_states,
                 forward_batch=forward_batch,
+                zero_allocator=zero_allocator,
             )
 
         if self.attn_tp_size != 1:
@@ -1312,6 +1329,12 @@ def forward(
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
+        zero_allocator = BumpAllocator(
+            # TODO for two-batch-overlap, we need a larger buffer size
+            buffer_size=len(self.layers) * 2,
+            dtype=torch.float32,
+            device=input_ids.device,
+        )
 
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
@@ -1323,7 +1346,7 @@ def forward(
             expert_distribution_recorder.set_current_layer(i)
             layer = self.layers[i]
             hidden_states, residual = layer(
-                positions, hidden_states, forward_batch, residual
+                positions, hidden_states, forward_batch, residual, zero_allocator
             )
         if not forward_batch.forward_mode.is_idle():
             if residual is None:
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index a6c2d910bb5..8a191588c46 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -1932,3 +1932,16 @@ def is_fa3_default_architecture(hf_config):
         "MistralForCausalLM",
     }
     return architectures[0] in default_archs
+
+
+# Can be more general if it is used in multiple places (keep it simple and thus not general now)
+class BumpAllocator:
+    def __init__(self, buffer_size: int, dtype, device):
+        self._buffer = torch.zeros((buffer_size,), dtype=dtype, device=device)
+        self._pointer = 0
+
+    def allocate(self, size: int):
+        assert self._pointer + size <= len(self._buffer)
+        output = self._buffer[self._pointer : self._pointer + size]
+        self._pointer += size
+        return output
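
A minimal sketch (not part of the patch) of the allocation pattern this diff introduces: `per_tensor_quant_mla_fp8` no longer calls `torch.zeros((1,), ...)` internally but takes a caller-provided `x_s_out` buffer, and a per-forward-pass `BumpAllocator` supplies those buffers as slices of one pre-zeroed tensor. The `BumpAllocator` below is copied from the addition to python/sglang/srt/utils.py; the layer count and CPU device are illustrative assumptions, and the loop stands in for `DeepseekV2Model.forward` handing the allocator to each `DeepseekV2DecoderLayer`.

import torch


class BumpAllocator:
    # Same class as added to python/sglang/srt/utils.py: one torch.zeros up front,
    # later allocations are just views into the pre-zeroed buffer.
    def __init__(self, buffer_size: int, dtype, device):
        self._buffer = torch.zeros((buffer_size,), dtype=dtype, device=device)
        self._pointer = 0

    def allocate(self, size: int):
        assert self._pointer + size <= len(self._buffer)
        output = self._buffer[self._pointer : self._pointer + size]
        self._pointer += size
        return output


# Illustrative stand-in for DeepseekV2Model.forward: one allocator per forward pass,
# two scale slots per layer, handed down to every decoder layer.
num_layers = 4  # assumption for the sketch; the model uses len(self.layers)
zero_allocator = BumpAllocator(
    buffer_size=num_layers * 2, dtype=torch.float32, device="cpu"
)

for _ in range(num_layers):
    # Each slice satisfies the new assert in per_tensor_quant_mla_fp8:
    # shape (1,), dtype float32 (the assert also requires the same device as x).
    q_nope_scale_buf = zero_allocator.allocate(1)
    attn_output_scale_buf = zero_allocator.allocate(1)
    assert q_nope_scale_buf.shape == (1,) and q_nope_scale_buf.dtype == torch.float32

Because every slice is a view into the single pre-zeroed buffer, the per-layer FP8 scale tensors no longer each trigger a separate torch.zeros allocation and kernel launch during the forward pass.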