Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions python/sglang/srt/layers/attention/nsa_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1263,7 +1263,16 @@ def forward_extend(
page_size=1,
)

if self.nsa_prefill_impl == "tilelang":
nsa_impl = (
self.nsa_decode_impl
if (
forward_batch.forward_mode.is_target_verify()
or forward_batch.forward_mode.is_draft_extend(include_v2=True)
)
else self.nsa_prefill_impl
)

if nsa_impl == "tilelang":
if q_rope is not None:
q_all = _concat_mla_absorb_q_general(q_nope, q_rope)
return self._forward_tilelang(
Expand All @@ -1273,7 +1282,7 @@ def forward_extend(
sm_scale=layer.scaling,
v_head_dim=layer.v_head_dim,
)
elif self.nsa_prefill_impl == "flashmla_sparse":
elif nsa_impl == "flashmla_sparse":
if q_rope is not None:
q_all = _concat_mla_absorb_q_general(q_nope, q_rope)

Expand All @@ -1297,7 +1306,7 @@ def forward_extend(
sm_scale=layer.scaling,
v_head_dim=layer.v_head_dim,
)
elif self.nsa_prefill_impl == "flashmla_kv":
elif nsa_impl == "flashmla_kv":
if q_rope is not None:
q_all = _concat_mla_absorb_q_general(q_nope, q_rope)
return self._forward_flashmla_kv(
Expand All @@ -1310,7 +1319,7 @@ def forward_extend(
metadata=metadata,
page_table_1=page_table_1,
)
elif self.nsa_prefill_impl == "fa3":
elif nsa_impl == "fa3":
return self._forward_fa3(
q_rope=q_rope,
kv_cache=kv_cache,
Expand All @@ -1326,7 +1335,7 @@ def forward_extend(
page_size=1,
)
else:
raise ValueError(f"Unsupported {self.nsa_prefill_impl = }")
raise ValueError(f"Unsupported {nsa_impl = }")

def forward_decode(
self,
Expand Down
19 changes: 12 additions & 7 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def set_torch_compile_config():
monkey_patch_torch_compile()


def get_batch_sizes_to_capture(model_runner: ModelRunner):
def get_batch_sizes_to_capture(model_runner: ModelRunner, num_tokens_per_bs=1):
server_args = model_runner.server_args
capture_bs = server_args.cuda_graph_bs

Expand All @@ -199,11 +199,13 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):

if server_args.enable_two_batch_overlap:
mul_base *= 2
num_tokens_per_bs = 1 # tbo not test, set num_tokens_per_bs to 1

if require_gathered_buffer(server_args):
mul_base *= get_attention_tp_size()

capture_bs = [bs for bs in capture_bs if bs % mul_base == 0]
# Model input token count = bs * num_tokens_per_bs; must be a multiple of attn_tp_size.
capture_bs = [bs for bs in capture_bs if bs * num_tokens_per_bs % mul_base == 0]
Comment thread
Fridge003 marked this conversation as resolved.

capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
capture_bs = list(sorted(set(capture_bs)))
Expand Down Expand Up @@ -267,11 +269,6 @@ def __init__(self, model_runner: ModelRunner):
self.dllm_config = DllmConfig.from_server_args(model_runner.server_args)
self.is_dllm = self.dllm_config is not None

# Batch sizes to capture
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
log_info_on_rank0(logger, f"Capture cuda graph bs {self.capture_bs}")
if KTRANSFORMERS_AVAILABLE:
KTMoEWrapper.set_capture_batch_sizes(self.capture_bs)
self.capture_forward_mode = ForwardMode.DECODE
self.capture_hidden_mode = CaptureHiddenMode.NULL
self.num_tokens_per_bs = 1
Expand All @@ -291,6 +288,14 @@ def __init__(self, model_runner: ModelRunner):
self.capture_forward_mode = ForwardMode.DLLM_EXTEND
self.num_tokens_per_bs = self.dllm_config.block_size

# Batch sizes to capture
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(
model_runner, self.num_tokens_per_bs
)
log_info_on_rank0(logger, f"Capture cuda graph bs {self.capture_bs}")
if KTRANSFORMERS_AVAILABLE:
KTMoEWrapper.set_capture_batch_sizes(self.capture_bs)

# If returning hidden states is enabled, set initial capture hidden mode to full to avoid double-capture on startup
if model_runner.server_args.enable_return_hidden_states:
self.capture_hidden_mode = CaptureHiddenMode.FULL
Expand Down
Loading