From 09afb56c1c6807b2af8c8fb2a54d9964397d9193 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 30 Oct 2025 00:39:32 +0000 Subject: [PATCH 01/18] Fix trtllm mla backend when chunked prefix cache is disabled --- .../sglang/srt/layers/attention/trtllm_mla_backend.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index 1882881e5d7f..1f6647977d8a 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -860,6 +860,13 @@ def forward_extend( cos_sin_cache: Optional[torch.Tensor] = None, is_neox: Optional[bool] = False, ) -> torch.Tensor: + + # When chunked prefix cache is disabled, fallback to normal MLA path + if self.disable_chunked_prefix_cache: + return super().forward_extend( + q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope + ) + # TODO refactor to avoid code duplication merge_query = q_rope is not None if ( @@ -1003,6 +1010,10 @@ def forward_extend( output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim) return output + # When chunked prefix cache is enabled, dispatch to different path for ragged attention. + assert ( + not self.disable_chunked_prefix_cache + ), "Chunked prefix cache should be enabled when using ragged attention." if forward_batch.attn_attend_prefix_cache: # MHA for chunked prefix kv cache when running model with MLA assert forward_batch.prefix_chunk_idx is not None From 6c06bb5e598717b514fc5fc3186bac374a2e709a Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 30 Oct 2025 01:01:26 +0000 Subject: [PATCH 02/18] upd --- .../srt/layers/attention/trtllm_mla_backend.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index 1f6647977d8a..644088b0c891 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -284,6 +284,9 @@ def __init__( self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens + # Whether to fallback to flashinfer MLA kernel + self.fallback_to_flashinfer_mla = False + def _calc_padded_blocks(self, max_seq_len: int) -> int: """ Calculate padded block count that satisfies both TRT-LLM and Triton constraints. @@ -516,7 +519,12 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): and not forward_batch.forward_mode.is_target_verify() and not forward_batch.forward_mode.is_draft_extend(include_v2=True) ): - if self.disable_chunked_prefix_cache: + # For extend batch with prefix length > 0, fallback to flashinfer MLA kernel when chunked prefix cache is disabled. + extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu) + self.fallback_to_flashinfer_mla = ( + self.disable_chunked_prefix_cache and not extend_no_prefix + ) + if self.fallback_to_flashinfer_mla: super().init_forward_metadata(forward_batch) seq_lens = forward_batch.seq_lens - forward_batch.extend_prefix_lens @@ -537,6 +545,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): or forward_batch.forward_mode.is_target_verify() or forward_batch.forward_mode.is_draft_extend(include_v2=True) ): + self.fallback_to_flashinfer_mla = False bs = forward_batch.batch_size # Get maximum sequence length. @@ -583,6 +592,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata else: + self.fallback_to_flashinfer_mla = True return super().init_forward_metadata(forward_batch) def init_mha_chunk_metadata(self, forward_batch: ForwardBatch): @@ -861,8 +871,7 @@ def forward_extend( is_neox: Optional[bool] = False, ) -> torch.Tensor: - # When chunked prefix cache is disabled, fallback to normal MLA path - if self.disable_chunked_prefix_cache: + if self.fallback_to_flashinfer_mla: return super().forward_extend( q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope ) @@ -1011,9 +1020,6 @@ def forward_extend( return output # When chunked prefix cache is enabled, dispatch to different path for ragged attention. - assert ( - not self.disable_chunked_prefix_cache - ), "Chunked prefix cache should be enabled when using ragged attention." if forward_batch.attn_attend_prefix_cache: # MHA for chunked prefix kv cache when running model with MLA assert forward_batch.prefix_chunk_idx is not None From 75b5c4a1b161ba6f2839021b708f28886ca9a611 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 30 Oct 2025 01:34:25 +0000 Subject: [PATCH 03/18] fix --- .../srt/layers/attention/trtllm_mla_backend.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index 644088b0c891..9fd802941248 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -207,6 +207,7 @@ class TRTLLMMLAPrefillMetadata: max_seq_len: int cum_seq_lens: torch.Tensor seq_lens: torch.Tensor + fallback_to_flashinfer_mla: bool = False @dataclass @@ -284,9 +285,6 @@ def __init__( self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens - # Whether to fallback to flashinfer MLA kernel - self.fallback_to_flashinfer_mla = False - def _calc_padded_blocks(self, max_seq_len: int) -> int: """ Calculate padded block count that satisfies both TRT-LLM and Triton constraints. @@ -521,10 +519,10 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): ): # For extend batch with prefix length > 0, fallback to flashinfer MLA kernel when chunked prefix cache is disabled. extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu) - self.fallback_to_flashinfer_mla = ( + fallback_to_flashinfer_mla = ( self.disable_chunked_prefix_cache and not extend_no_prefix ) - if self.fallback_to_flashinfer_mla: + if fallback_to_flashinfer_mla: super().init_forward_metadata(forward_batch) seq_lens = forward_batch.seq_lens - forward_batch.extend_prefix_lens @@ -539,13 +537,13 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): max_seq_len, cum_seq_lens_q, seq_lens, + fallback_to_flashinfer_mla, ) elif ( forward_batch.forward_mode.is_decode_or_idle() or forward_batch.forward_mode.is_target_verify() or forward_batch.forward_mode.is_draft_extend(include_v2=True) ): - self.fallback_to_flashinfer_mla = False bs = forward_batch.batch_size # Get maximum sequence length. @@ -592,7 +590,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata else: - self.fallback_to_flashinfer_mla = True return super().init_forward_metadata(forward_batch) def init_mha_chunk_metadata(self, forward_batch: ForwardBatch): @@ -871,7 +868,7 @@ def forward_extend( is_neox: Optional[bool] = False, ) -> torch.Tensor: - if self.fallback_to_flashinfer_mla: + if self.forward_prefill_metadata.fallback_to_flashinfer_mla: return super().forward_extend( q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope ) From 8024c1e12568eb7b9b22285e495afca5a4f1e34e Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 30 Oct 2025 01:36:20 +0000 Subject: [PATCH 04/18] fix --- python/sglang/srt/layers/attention/trtllm_mla_backend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index 9fd802941248..7d5306e1c38f 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -518,9 +518,9 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): and not forward_batch.forward_mode.is_draft_extend(include_v2=True) ): # For extend batch with prefix length > 0, fallback to flashinfer MLA kernel when chunked prefix cache is disabled. - extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu) + has_prefix = any(forward_batch.extend_prefix_lens_cpu) fallback_to_flashinfer_mla = ( - self.disable_chunked_prefix_cache and not extend_no_prefix + self.disable_chunked_prefix_cache and has_prefix ) if fallback_to_flashinfer_mla: super().init_forward_metadata(forward_batch) From 4b74679ba6bbb6ac7c1015f3ac9c195aa1fd6324 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 30 Oct 2025 02:14:33 +0000 Subject: [PATCH 05/18] fix --- python/sglang/srt/layers/attention/trtllm_mla_backend.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index 7d5306e1c38f..4f80ab46efcc 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -868,7 +868,10 @@ def forward_extend( is_neox: Optional[bool] = False, ) -> torch.Tensor: - if self.forward_prefill_metadata.fallback_to_flashinfer_mla: + if ( + self.forward_prefill_metadata is not None + and self.forward_prefill_metadata.fallback_to_flashinfer_mla + ): return super().forward_extend( q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope ) From 16c5640d92793bcdfcd26c597b30dd6e63632055 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Wed, 5 Nov 2025 13:16:49 -0800 Subject: [PATCH 06/18] upd --- .../srt/layers/attention/trtllm_mla_backend.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index 7dcf98adef79..6fdbcefd61b6 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -207,7 +207,7 @@ class TRTLLMMLAPrefillMetadata: max_seq_len: int cum_seq_lens: torch.Tensor seq_lens: torch.Tensor - fallback_to_flashinfer_mla: bool = False + fallback_to_flashinfer_impl: bool = False @dataclass @@ -552,12 +552,13 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): and not forward_batch.forward_mode.is_target_verify() and not forward_batch.forward_mode.is_draft_extend(include_v2=True) ): - # For extend batch with prefix length > 0, fallback to flashinfer MLA kernel when chunked prefix cache is disabled. + # For extend batch with prefix length > 0, fallback to ragged kernel implemented in flashinfer MLA backend + # when chunked prefix cache is disabled. has_prefix = any(forward_batch.extend_prefix_lens_cpu) - fallback_to_flashinfer_mla = ( + fallback_to_flashinfer_impl = ( self.disable_chunked_prefix_cache and has_prefix ) - if fallback_to_flashinfer_mla: + if fallback_to_flashinfer_impl: super().init_forward_metadata(forward_batch) seq_lens = forward_batch.seq_lens - forward_batch.extend_prefix_lens @@ -572,7 +573,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): max_seq_len, cum_seq_lens_q, seq_lens, - fallback_to_flashinfer_mla, + fallback_to_flashinfer_impl, ) elif ( forward_batch.forward_mode.is_decode_or_idle() @@ -907,7 +908,7 @@ def forward_extend( if ( self.forward_prefill_metadata is not None - and self.forward_prefill_metadata.fallback_to_flashinfer_mla + and self.forward_prefill_metadata.fallback_to_flashinfer_impl ): return super().forward_extend( q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope From 311d143c238402f8fafeb882146ee2b1b1631938 Mon Sep 17 00:00:00 2001 From: ishandhanani Date: Wed, 5 Nov 2025 21:40:56 -0800 Subject: [PATCH 07/18] iter --- python/sglang/srt/layers/quantization/modelopt_quant.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 99ae27684f2e..0a9e9713bb64 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -1451,7 +1451,12 @@ def _slice_scale(w): w2_input_scale = _slice_scale(w2_input_scale) if CUTEDSL_MOE_NVFP4_DISPATCH: - assert torch.all(w13_input_scale == w13_input_scale[0]) + print(f"w13_input_scale shape: {w13_input_scale.shape}") + print(f"w13_input_scale values: {w13_input_scale}") + print(f"Unique values: {torch.unique(w13_input_scale)}") + print(f"All equal? {torch.all(w13_input_scale == w13_input_scale[0])}") + #assert torch.all(w13_input_scale == w13_input_scale[0]) + logger.warning("ISHAN: removed assert so this might break down the line lets see...") w13_input_scale = w13_input_scale[0] else: w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32) From 920344f4c78f901c715577cbd0b3359474563e5a Mon Sep 17 00:00:00 2001 From: ishandhanani Date: Wed, 5 Nov 2025 22:03:31 -0800 Subject: [PATCH 08/18] ``` refactor(quantization): replace print statements with logger warnings ``` --- python/sglang/srt/layers/quantization/modelopt_quant.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 0a9e9713bb64..948a0510d204 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -1451,10 +1451,10 @@ def _slice_scale(w): w2_input_scale = _slice_scale(w2_input_scale) if CUTEDSL_MOE_NVFP4_DISPATCH: - print(f"w13_input_scale shape: {w13_input_scale.shape}") - print(f"w13_input_scale values: {w13_input_scale}") - print(f"Unique values: {torch.unique(w13_input_scale)}") - print(f"All equal? {torch.all(w13_input_scale == w13_input_scale[0])}") + logger.warning(f"w13_input_scale shape: {w13_input_scale.shape}") + logger.warning(f"w13_input_scale values: {w13_input_scale}") + logger.warning(f"Unique values: {torch.unique(w13_input_scale)}") + logger.warning(f"All equal? {torch.all(w13_input_scale == w13_input_scale[0])}") #assert torch.all(w13_input_scale == w13_input_scale[0]) logger.warning("ISHAN: removed assert so this might break down the line lets see...") w13_input_scale = w13_input_scale[0] From 7483bc6a7aa9ef5bb61017c095ab32b1d1cbe328 Mon Sep 17 00:00:00 2001 From: ishandhanani Date: Wed, 5 Nov 2025 22:15:01 -0800 Subject: [PATCH 09/18] go --- python/sglang/srt/layers/quantization/modelopt_quant.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 948a0510d204..cb4548856215 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -1455,8 +1455,8 @@ def _slice_scale(w): logger.warning(f"w13_input_scale values: {w13_input_scale}") logger.warning(f"Unique values: {torch.unique(w13_input_scale)}") logger.warning(f"All equal? {torch.all(w13_input_scale == w13_input_scale[0])}") - #assert torch.all(w13_input_scale == w13_input_scale[0]) - logger.warning("ISHAN: removed assert so this might break down the line lets see...") + if not torch.all(w13_input_scale == w13_input_scale[0]): + logger.warning("This would have triggered an assert: w13_input_scale is not constant across experts, but continuing anyway.") w13_input_scale = w13_input_scale[0] else: w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32) From 09b881db728f6cd5941ee565d76c5af54b47573a Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Fri, 7 Nov 2025 19:57:03 +0000 Subject: [PATCH 10/18] print logs --- python/sglang/srt/eplb/expert_location.py | 4 ++-- python/sglang/srt/model_executor/model_runner.py | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/eplb/expert_location.py b/python/sglang/srt/eplb/expert_location.py index 4dff84152f11..ac7ad675523b 100644 --- a/python/sglang/srt/eplb/expert_location.py +++ b/python/sglang/srt/eplb/expert_location.py @@ -93,7 +93,7 @@ def init_trivial( if common is None: return None - + print(f"init_trivial: common={common}") num_physical_experts = common["num_physical_experts"] model_config_for_expert_location = common["model_config_for_expert_location"] num_layers = model_config_for_expert_location.num_layers @@ -103,7 +103,7 @@ def init_trivial( torch.arange(0, num_physical_experts).repeat(num_layers, 1) % num_logical_experts ) - + print(f"init_trivial: physical_to_logical_map={physical_to_logical_map}") return ExpertLocationMetadata.init_by_mapping( server_args, model_config, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 963e2cd814c6..daa77ce84756 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -358,6 +358,9 @@ def initialize(self, min_per_gpu_memory: float): ) if not self.is_draft_worker: + print( + f"compute_initial_expert_location_metadata: ep_rank={self.moe_ep_rank}" + ) set_global_expert_location_metadata( compute_initial_expert_location_metadata( server_args=server_args, From 34c400ee27f94578fe67f3c6978bbff0170c6580 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Fri, 7 Nov 2025 13:00:54 -0800 Subject: [PATCH 11/18] Revert "print logs" This reverts commit 09b881db728f6cd5941ee565d76c5af54b47573a. --- python/sglang/srt/eplb/expert_location.py | 4 ++-- python/sglang/srt/model_executor/model_runner.py | 3 --- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/eplb/expert_location.py b/python/sglang/srt/eplb/expert_location.py index ac7ad675523b..4dff84152f11 100644 --- a/python/sglang/srt/eplb/expert_location.py +++ b/python/sglang/srt/eplb/expert_location.py @@ -93,7 +93,7 @@ def init_trivial( if common is None: return None - print(f"init_trivial: common={common}") + num_physical_experts = common["num_physical_experts"] model_config_for_expert_location = common["model_config_for_expert_location"] num_layers = model_config_for_expert_location.num_layers @@ -103,7 +103,7 @@ def init_trivial( torch.arange(0, num_physical_experts).repeat(num_layers, 1) % num_logical_experts ) - print(f"init_trivial: physical_to_logical_map={physical_to_logical_map}") + return ExpertLocationMetadata.init_by_mapping( server_args, model_config, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index daa77ce84756..963e2cd814c6 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -358,9 +358,6 @@ def initialize(self, min_per_gpu_memory: float): ) if not self.is_draft_worker: - print( - f"compute_initial_expert_location_metadata: ep_rank={self.moe_ep_rank}" - ) set_global_expert_location_metadata( compute_initial_expert_location_metadata( server_args=server_args, From 3743e60f3b76259a2d4b2b04501bd86ad3794229 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Fri, 7 Nov 2025 13:19:17 -0800 Subject: [PATCH 12/18] add back assert --- python/sglang/srt/layers/quantization/modelopt_quant.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index cb4548856215..db2b004eb587 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -1454,9 +1454,10 @@ def _slice_scale(w): logger.warning(f"w13_input_scale shape: {w13_input_scale.shape}") logger.warning(f"w13_input_scale values: {w13_input_scale}") logger.warning(f"Unique values: {torch.unique(w13_input_scale)}") - logger.warning(f"All equal? {torch.all(w13_input_scale == w13_input_scale[0])}") - if not torch.all(w13_input_scale == w13_input_scale[0]): - logger.warning("This would have triggered an assert: w13_input_scale is not constant across experts, but continuing anyway.") + logger.warning( + f"All equal? {torch.all(w13_input_scale == w13_input_scale[0])}" + ) + assert torch.all(w13_input_scale == w13_input_scale[0]) w13_input_scale = w13_input_scale[0] else: w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32) From 124941a73b895550aa7ce480da41ac0e4a71c902 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Fri, 7 Nov 2025 19:57:03 +0000 Subject: [PATCH 13/18] print logs --- python/sglang/srt/eplb/expert_location.py | 4 ++-- python/sglang/srt/model_executor/model_runner.py | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/eplb/expert_location.py b/python/sglang/srt/eplb/expert_location.py index 4dff84152f11..ac7ad675523b 100644 --- a/python/sglang/srt/eplb/expert_location.py +++ b/python/sglang/srt/eplb/expert_location.py @@ -93,7 +93,7 @@ def init_trivial( if common is None: return None - + print(f"init_trivial: common={common}") num_physical_experts = common["num_physical_experts"] model_config_for_expert_location = common["model_config_for_expert_location"] num_layers = model_config_for_expert_location.num_layers @@ -103,7 +103,7 @@ def init_trivial( torch.arange(0, num_physical_experts).repeat(num_layers, 1) % num_logical_experts ) - + print(f"init_trivial: physical_to_logical_map={physical_to_logical_map}") return ExpertLocationMetadata.init_by_mapping( server_args, model_config, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 963e2cd814c6..daa77ce84756 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -358,6 +358,9 @@ def initialize(self, min_per_gpu_memory: float): ) if not self.is_draft_worker: + print( + f"compute_initial_expert_location_metadata: ep_rank={self.moe_ep_rank}" + ) set_global_expert_location_metadata( compute_initial_expert_location_metadata( server_args=server_args, From 8ea3cd5fa410c97f03a60a72988a15b3564797d5 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Fri, 7 Nov 2025 13:26:07 -0800 Subject: [PATCH 14/18] more logs --- python/sglang/srt/eplb/expert_location.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/eplb/expert_location.py b/python/sglang/srt/eplb/expert_location.py index ac7ad675523b..7a3450409275 100644 --- a/python/sglang/srt/eplb/expert_location.py +++ b/python/sglang/srt/eplb/expert_location.py @@ -128,6 +128,7 @@ def init_by_mapping( return None model_config_for_expert_location = common["model_config_for_expert_location"] + print(f"common={common}") logical_to_all_physical_map = _compute_logical_to_all_physical_map( server_args=server_args, physical_to_logical_map=physical_to_logical_map, @@ -135,7 +136,9 @@ def init_by_mapping( ep_size=common["ep_size"], moe_ep_rank=moe_ep_rank, ) - + print( + f"physical_to_logical_map={physical_to_logical_map}, logical_to_all_physical_map={logical_to_all_physical_map}" + ) return ExpertLocationMetadata._init_raw( server_args=server_args, ep_size=common["ep_size"], From 93b6c07848f5b37a4ff9a1fe3e9538b2cd87c375 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Sat, 8 Nov 2025 02:47:52 +0000 Subject: [PATCH 15/18] more logs --- python/sglang/srt/eplb/expert_location.py | 34 ++++++++++++++++++++--- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/eplb/expert_location.py b/python/sglang/srt/eplb/expert_location.py index 7a3450409275..315947572c5b 100644 --- a/python/sglang/srt/eplb/expert_location.py +++ b/python/sglang/srt/eplb/expert_location.py @@ -103,7 +103,9 @@ def init_trivial( torch.arange(0, num_physical_experts).repeat(num_layers, 1) % num_logical_experts ) - print(f"init_trivial: physical_to_logical_map={physical_to_logical_map}") + print( + f"init_trivial: physical_to_logical_map={physical_to_logical_map.shape} {physical_to_logical_map[0]}" + ) return ExpertLocationMetadata.init_by_mapping( server_args, model_config, @@ -136,9 +138,16 @@ def init_by_mapping( ep_size=common["ep_size"], moe_ep_rank=moe_ep_rank, ) - print( - f"physical_to_logical_map={physical_to_logical_map}, logical_to_all_physical_map={logical_to_all_physical_map}" - ) + if moe_ep_rank % 4 == 0: + print( + f"physical_to_logical_map={physical_to_logical_map.shape}, logical_to_all_physical_map={logical_to_all_physical_map.shape}" + ) + for i in range(physical_to_logical_map.shape[0]): + print(f"physical_to_logical_map[{i}]={physical_to_logical_map[i]}") + for i in range(logical_to_all_physical_map.shape[0]): + print( + f"logical_to_all_physical_map[{i}]={logical_to_all_physical_map[i]}" + ) return ExpertLocationMetadata._init_raw( server_args=server_args, ep_size=common["ep_size"], @@ -233,6 +242,23 @@ def _init_raw( logical_to_all_physical_map != -1, dim=-1 ) + print( + f"logical_to_all_physical_map_num_valid={logical_to_all_physical_map_num_valid.shape} {logical_to_all_physical_map_num_valid[0]}" + ) + print( + f"logical_to_all_physical_map_padded={logical_to_all_physical_map_padded.shape} {logical_to_all_physical_map_padded[0]}" + ) + dispatch_physical_map = compute_logical_to_rank_dispatch_physical_map( + server_args=server_args, + logical_to_all_physical_map=logical_to_all_physical_map, + ep_size=ep_size, + num_physical_experts=num_physical_experts, + # TODO improve when we have real EP rank + ep_rank=torch.distributed.get_rank() % ep_size, + ) + print( + f"dispatch_physical_map={dispatch_physical_map.shape} {dispatch_physical_map}" + ) return ExpertLocationMetadata( physical_to_logical_map=physical_to_logical_map, physical_to_logical_map_cpu=physical_to_logical_map.cpu(), From 05ffb025aec92cf91ca38df2a51cf103dc5424a7 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Sat, 8 Nov 2025 03:13:04 +0000 Subject: [PATCH 16/18] more log --- python/sglang/srt/eplb/expert_location.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/eplb/expert_location.py b/python/sglang/srt/eplb/expert_location.py index 315947572c5b..b321cec9dee9 100644 --- a/python/sglang/srt/eplb/expert_location.py +++ b/python/sglang/srt/eplb/expert_location.py @@ -142,12 +142,8 @@ def init_by_mapping( print( f"physical_to_logical_map={physical_to_logical_map.shape}, logical_to_all_physical_map={logical_to_all_physical_map.shape}" ) - for i in range(physical_to_logical_map.shape[0]): - print(f"physical_to_logical_map[{i}]={physical_to_logical_map[i]}") - for i in range(logical_to_all_physical_map.shape[0]): - print( - f"logical_to_all_physical_map[{i}]={logical_to_all_physical_map[i]}" - ) + print(f"physical_to_logical_map={physical_to_logical_map[0]}") + print(f"logical_to_all_physical_map={logical_to_all_physical_map[0]}") return ExpertLocationMetadata._init_raw( server_args=server_args, ep_size=common["ep_size"], @@ -362,7 +358,7 @@ def _compute_logical_to_all_physical_map( logical_to_all_physical_map[layer_id][logical_expert_id].append( physical_expert_id ) - + print(f"Preview logical_to_all_physical_map={logical_to_all_physical_map[0]}") # Replace by the physical expert on local GPU or node if possible if moe_ep_rank is not None: num_gpus_per_node = server_args.ep_size // server_args.nnodes From ba2c77deb95937ee03898c8f42c981fb6009b8d9 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Sat, 8 Nov 2025 03:43:32 +0000 Subject: [PATCH 17/18] upd --- python/sglang/srt/eplb/expert_location.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/eplb/expert_location.py b/python/sglang/srt/eplb/expert_location.py index b321cec9dee9..f45b7feb31ce 100644 --- a/python/sglang/srt/eplb/expert_location.py +++ b/python/sglang/srt/eplb/expert_location.py @@ -378,12 +378,19 @@ def _compute_logical_to_all_physical_map( num_gpus_per_node=num_gpus_per_node, num_local_node_physical_experts=num_local_node_physical_experts, ) + print( + f"layer_id={layer_id}, logical_expert_id={logical_expert_id}, nearest_expert={nearest_expert}" + ) # Replace by the nearest physical expert - if nearest_expert != -1: - logical_to_all_physical_map[layer_id][logical_expert_id] = [ - nearest_expert - ] + mapped_phsical_experts = logical_to_all_physical_map[layer_id][ + logical_expert_id + ] + if ( + nearest_expert != -1 + and nearest_expert not in mapped_phsical_experts + ): + mapped_phsical_experts[0] = nearest_expert logical_to_all_physical_map = _pad_nested_array( logical_to_all_physical_map, pad_value=-1 From 8229829d59c51654a9f1d21d7fc0aa0575952784 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Sat, 8 Nov 2025 04:14:41 +0000 Subject: [PATCH 18/18] remove logging --- python/sglang/srt/eplb/expert_location.py | 33 ------------------- .../srt/layers/quantization/modelopt_quant.py | 6 ---- .../sglang/srt/model_executor/model_runner.py | 3 -- 3 files changed, 42 deletions(-) diff --git a/python/sglang/srt/eplb/expert_location.py b/python/sglang/srt/eplb/expert_location.py index f45b7feb31ce..862d54c957b8 100644 --- a/python/sglang/srt/eplb/expert_location.py +++ b/python/sglang/srt/eplb/expert_location.py @@ -93,7 +93,6 @@ def init_trivial( if common is None: return None - print(f"init_trivial: common={common}") num_physical_experts = common["num_physical_experts"] model_config_for_expert_location = common["model_config_for_expert_location"] num_layers = model_config_for_expert_location.num_layers @@ -103,9 +102,6 @@ def init_trivial( torch.arange(0, num_physical_experts).repeat(num_layers, 1) % num_logical_experts ) - print( - f"init_trivial: physical_to_logical_map={physical_to_logical_map.shape} {physical_to_logical_map[0]}" - ) return ExpertLocationMetadata.init_by_mapping( server_args, model_config, @@ -130,7 +126,6 @@ def init_by_mapping( return None model_config_for_expert_location = common["model_config_for_expert_location"] - print(f"common={common}") logical_to_all_physical_map = _compute_logical_to_all_physical_map( server_args=server_args, physical_to_logical_map=physical_to_logical_map, @@ -138,12 +133,6 @@ def init_by_mapping( ep_size=common["ep_size"], moe_ep_rank=moe_ep_rank, ) - if moe_ep_rank % 4 == 0: - print( - f"physical_to_logical_map={physical_to_logical_map.shape}, logical_to_all_physical_map={logical_to_all_physical_map.shape}" - ) - print(f"physical_to_logical_map={physical_to_logical_map[0]}") - print(f"logical_to_all_physical_map={logical_to_all_physical_map[0]}") return ExpertLocationMetadata._init_raw( server_args=server_args, ep_size=common["ep_size"], @@ -237,24 +226,6 @@ def _init_raw( logical_to_all_physical_map_num_valid = torch.count_nonzero( logical_to_all_physical_map != -1, dim=-1 ) - - print( - f"logical_to_all_physical_map_num_valid={logical_to_all_physical_map_num_valid.shape} {logical_to_all_physical_map_num_valid[0]}" - ) - print( - f"logical_to_all_physical_map_padded={logical_to_all_physical_map_padded.shape} {logical_to_all_physical_map_padded[0]}" - ) - dispatch_physical_map = compute_logical_to_rank_dispatch_physical_map( - server_args=server_args, - logical_to_all_physical_map=logical_to_all_physical_map, - ep_size=ep_size, - num_physical_experts=num_physical_experts, - # TODO improve when we have real EP rank - ep_rank=torch.distributed.get_rank() % ep_size, - ) - print( - f"dispatch_physical_map={dispatch_physical_map.shape} {dispatch_physical_map}" - ) return ExpertLocationMetadata( physical_to_logical_map=physical_to_logical_map, physical_to_logical_map_cpu=physical_to_logical_map.cpu(), @@ -358,7 +329,6 @@ def _compute_logical_to_all_physical_map( logical_to_all_physical_map[layer_id][logical_expert_id].append( physical_expert_id ) - print(f"Preview logical_to_all_physical_map={logical_to_all_physical_map[0]}") # Replace by the physical expert on local GPU or node if possible if moe_ep_rank is not None: num_gpus_per_node = server_args.ep_size // server_args.nnodes @@ -378,9 +348,6 @@ def _compute_logical_to_all_physical_map( num_gpus_per_node=num_gpus_per_node, num_local_node_physical_experts=num_local_node_physical_experts, ) - print( - f"layer_id={layer_id}, logical_expert_id={logical_expert_id}, nearest_expert={nearest_expert}" - ) # Replace by the nearest physical expert mapped_phsical_experts = logical_to_all_physical_map[layer_id][ diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index db2b004eb587..99ae27684f2e 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -1451,12 +1451,6 @@ def _slice_scale(w): w2_input_scale = _slice_scale(w2_input_scale) if CUTEDSL_MOE_NVFP4_DISPATCH: - logger.warning(f"w13_input_scale shape: {w13_input_scale.shape}") - logger.warning(f"w13_input_scale values: {w13_input_scale}") - logger.warning(f"Unique values: {torch.unique(w13_input_scale)}") - logger.warning( - f"All equal? {torch.all(w13_input_scale == w13_input_scale[0])}" - ) assert torch.all(w13_input_scale == w13_input_scale[0]) w13_input_scale = w13_input_scale[0] else: diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index daa77ce84756..963e2cd814c6 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -358,9 +358,6 @@ def initialize(self, min_per_gpu_memory: float): ) if not self.is_draft_worker: - print( - f"compute_initial_expert_location_metadata: ep_rank={self.moe_ep_rank}" - ) set_global_expert_location_metadata( compute_initial_expert_location_metadata( server_args=server_args,