diff --git a/python/sglang/srt/layers/attention/nsa_backend.py b/python/sglang/srt/layers/attention/nsa_backend.py
index eb7d8cd0a8c1..8d0b00a136d8 100644
--- a/python/sglang/srt/layers/attention/nsa_backend.py
+++ b/python/sglang/srt/layers/attention/nsa_backend.py
@@ -1263,7 +1263,16 @@ def forward_extend(
             page_size=1,
         )
 
-        if self.nsa_prefill_impl == "tilelang":
+        nsa_impl = (
+            self.nsa_decode_impl
+            if (
+                forward_batch.forward_mode.is_target_verify()
+                or forward_batch.forward_mode.is_draft_extend(include_v2=True)
+            )
+            else self.nsa_prefill_impl
+        )
+
+        if nsa_impl == "tilelang":
             if q_rope is not None:
                 q_all = _concat_mla_absorb_q_general(q_nope, q_rope)
             return self._forward_tilelang(
@@ -1273,7 +1282,7 @@ def forward_extend(
                 sm_scale=layer.scaling,
                 v_head_dim=layer.v_head_dim,
             )
-        elif self.nsa_prefill_impl == "flashmla_sparse":
+        elif nsa_impl == "flashmla_sparse":
             if q_rope is not None:
                 q_all = _concat_mla_absorb_q_general(q_nope, q_rope)
 
@@ -1297,7 +1306,7 @@ def forward_extend(
                 sm_scale=layer.scaling,
                 v_head_dim=layer.v_head_dim,
             )
-        elif self.nsa_prefill_impl == "flashmla_kv":
+        elif nsa_impl == "flashmla_kv":
             if q_rope is not None:
                 q_all = _concat_mla_absorb_q_general(q_nope, q_rope)
             return self._forward_flashmla_kv(
@@ -1310,7 +1319,7 @@ def forward_extend(
                 sm_scale=layer.scaling,
                 v_head_dim=layer.v_head_dim,
                 metadata=metadata,
                 page_table_1=page_table_1,
             )
-        elif self.nsa_prefill_impl == "fa3":
+        elif nsa_impl == "fa3":
             return self._forward_fa3(
                 q_rope=q_rope,
                 kv_cache=kv_cache,
@@ -1326,7 +1335,7 @@ def forward_extend(
                 page_size=1,
             )
         else:
-            raise ValueError(f"Unsupported {self.nsa_prefill_impl = }")
+            raise ValueError(f"Unsupported {nsa_impl = }")
 
     def forward_decode(
         self,
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 6be1b11c68a8..55a23e333d19 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -186,7 +186,7 @@ def set_torch_compile_config():
     monkey_patch_torch_compile()
 
 
-def get_batch_sizes_to_capture(model_runner: ModelRunner):
+def get_batch_sizes_to_capture(model_runner: ModelRunner, num_tokens_per_bs=1):
     server_args = model_runner.server_args
     capture_bs = server_args.cuda_graph_bs
 
@@ -199,11 +199,13 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
 
     if server_args.enable_two_batch_overlap:
         mul_base *= 2
+        num_tokens_per_bs = 1  # TBO is untested with num_tokens_per_bs > 1; force it to 1
 
     if require_gathered_buffer(server_args):
         mul_base *= get_attention_tp_size()
 
-    capture_bs = [bs for bs in capture_bs if bs % mul_base == 0]
+    # The model input token count is bs * num_tokens_per_bs; it, not bs alone, must be a multiple of mul_base.
+    capture_bs = [bs for bs in capture_bs if bs * num_tokens_per_bs % mul_base == 0]
     capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
     capture_bs = list(sorted(set(capture_bs)))
 
@@ -267,11 +269,6 @@ def __init__(self, model_runner: ModelRunner):
         self.dllm_config = DllmConfig.from_server_args(model_runner.server_args)
         self.is_dllm = self.dllm_config is not None
 
-        # Batch sizes to capture
-        self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
-        log_info_on_rank0(logger, f"Capture cuda graph bs {self.capture_bs}")
-        if KTRANSFORMERS_AVAILABLE:
-            KTMoEWrapper.set_capture_batch_sizes(self.capture_bs)
         self.capture_forward_mode = ForwardMode.DECODE
         self.capture_hidden_mode = CaptureHiddenMode.NULL
         self.num_tokens_per_bs = 1
@@ -291,6 +288,14 @@ def __init__(self, model_runner: ModelRunner):
             self.capture_forward_mode = ForwardMode.DLLM_EXTEND
             self.num_tokens_per_bs = self.dllm_config.block_size
 
+        # Batch sizes to capture
+        self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(
+            model_runner, self.num_tokens_per_bs
+        )
+        log_info_on_rank0(logger, f"Capture cuda graph bs {self.capture_bs}")
+        if KTRANSFORMERS_AVAILABLE:
+            KTMoEWrapper.set_capture_batch_sizes(self.capture_bs)
+
         # If returning hidden states is enabled, set initial capture hidden mode to full to avoid double-capture on startup
         if model_runner.server_args.enable_return_hidden_states:
             self.capture_hidden_mode = CaptureHiddenMode.FULL
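
Note on the `nsa_backend.py` change: target-verify and draft-extend batches flow through `forward_extend`, but the diff now routes them to the decode implementation instead of the prefill one. A minimal standalone sketch of the selection logic (the `Mode` stub and the impl names here are hypothetical, for illustration only):

```python
# Sketch of the nsa_impl selection above. `Mode` is a stand-in for
# forward_batch.forward_mode; the impl names are illustrative.
from dataclasses import dataclass


@dataclass
class Mode:
    name: str

    def is_target_verify(self) -> bool:
        return self.name == "target_verify"

    def is_draft_extend(self, include_v2: bool = False) -> bool:
        return self.name == "draft_extend" or (
            include_v2 and self.name == "draft_extend_v2"
        )


def pick_nsa_impl(mode: Mode, decode_impl: str, prefill_impl: str) -> str:
    # Speculative-decoding phases reuse forward_extend but dispatch to the
    # decode kernel, mirroring the conditional added in the diff.
    if mode.is_target_verify() or mode.is_draft_extend(include_v2=True):
        return decode_impl
    return prefill_impl


assert pick_nsa_impl(Mode("target_verify"), "flashmla_kv", "tilelang") == "flashmla_kv"
assert pick_nsa_impl(Mode("draft_extend_v2"), "flashmla_kv", "tilelang") == "flashmla_kv"
assert pick_nsa_impl(Mode("extend"), "flashmla_kv", "tilelang") == "tilelang"
```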
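
Note on the `cuda_graph_runner.py` change: a graph captured at batch size `bs` now feeds the model `bs * num_tokens_per_bs` tokens (e.g. DLLM extend graphs capture a whole block per request), so the divisibility filter must apply to the token count rather than to `bs` alone; this is also why the capture-bs computation now runs after `num_tokens_per_bs` is known. A small sketch of the filter's behavior (`mul_base` and the batch sizes are made-up values, not repo defaults):

```python
# Sketch of the updated capture-bs filter; numbers are illustrative.
def filter_capture_bs(capture_bs, mul_base, num_tokens_per_bs=1):
    # A graph captured at batch size bs sees bs * num_tokens_per_bs input
    # tokens; that token count must divide evenly by mul_base.
    return [bs for bs in capture_bs if bs * num_tokens_per_bs % mul_base == 0]


capture_bs = [1, 2, 3, 4, 8, 16]

# Old behavior (one token per request): bs itself must be a multiple of 4.
print(filter_capture_bs(capture_bs, mul_base=4))  # [4, 8, 16]

# With 4 tokens per request (e.g. a DLLM block), bs=1 already yields 4
# tokens, so small batch sizes are no longer filtered out.
print(filter_capture_bs(capture_bs, mul_base=4, num_tokens_per_bs=4))  # [1, 2, 3, 4, 8, 16]
```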