diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 771c4d9dbdf7..e5168498daff 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -15,6 +15,7 @@ import dataclasses import logging +from contextlib import contextmanager from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -67,6 +68,27 @@ _is_npu = is_npu() _is_cpu = is_cpu() +# When set, LogitsProcessor.forward returns an empty output and skips the +# LM head + tensor-parallel all-gather. FlashInfer autotune only profiles +# attention/MoE/GEMM kernels, so the LM-head all-gather is wasted work -- +# and its [batch * dp_size, vocab] output OOMs under DP attention with a +# tight mem_fraction_static. +_in_autotune_dummy_run = False + + +def get_in_autotune_dummy_run() -> bool: + return _in_autotune_dummy_run + + +@contextmanager +def autotune_dummy_run_mode(): + global _in_autotune_dummy_run + _in_autotune_dummy_run = True + try: + yield + finally: + _in_autotune_dummy_run = False + @dataclasses.dataclass class LogitsProcessorOutput: @@ -300,6 +322,12 @@ def forward( multi_item_delimiter_indices = logits_metadata.multi_item_delimiter_indices logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata) + # Autotune dummy run discards this output; see _in_autotune_dummy_run. + # Placed before the MIS / DLLM / common dispatch so all three LM-head + # paths are skipped. + if _in_autotune_dummy_run: + return LogitsProcessorOutput(next_token_logits=None) + # Multi-item scoring only for prefill-only requests with pre-computed indices. if multi_item_delimiter_indices is not None and logits_metadata.is_prefill_only: return self.compute_logprobs_for_multi_item_scoring( diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 6d9db720a953..e8a7b330a735 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -2380,6 +2380,8 @@ def _flashinfer_autotune(self): """Run flashinfer autotune.""" from flashinfer.autotuner import autotune + from sglang.srt.layers.logits_processor import autotune_dummy_run_mode + cache_path = self._flashinfer_autotune_cache_path() if envs.SGLANG_FLASHINFER_AUTOTUNE_CACHE.get(): autotune_cache = cache_path @@ -2401,7 +2403,11 @@ def _flashinfer_autotune(self): # calls on default stream (unsupported by CUDA) when --enable-symm-mem is used. self.forward_stream.wait_stream(torch.cuda.current_stream()) with torch.get_device_module(self.device).stream(self.forward_stream): - with torch.inference_mode(), autotune(True, cache=str(autotune_cache)): + with ( + torch.inference_mode(), + autotune(True, cache=str(autotune_cache)), + autotune_dummy_run_mode(), + ): self._dummy_run(batch_size=self.req_to_token_pool.size) torch.cuda.current_stream().wait_stream(self.forward_stream) logger.info("FlashInfer autotune completed.")