Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions python/sglang/srt/layers/logits_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import dataclasses
import logging
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -62,6 +63,27 @@

_is_npu = is_npu()

# When set, LogitsProcessor.forward returns an empty output and skips the
# LM head + tensor-parallel all-gather. FlashInfer autotune only profiles
# attention/MoE/GEMM kernels, so the LM-head all-gather is wasted work --
# and its [batch * dp_size, vocab] output OOMs under DP attention with a
# tight mem_fraction_static.
_in_autotune_dummy_run = False


def get_in_autotune_dummy_run() -> bool:
return _in_autotune_dummy_run


@contextmanager
def autotune_dummy_run_mode():
global _in_autotune_dummy_run
_in_autotune_dummy_run = True
try:
yield
finally:
_in_autotune_dummy_run = False


@dataclasses.dataclass
class LogitsProcessorOutput:
Expand Down Expand Up @@ -297,6 +319,12 @@ def forward(
multi_item_delimiter_indices = logits_metadata.multi_item_delimiter_indices
logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)

# Autotune dummy run discards this output; see _in_autotune_dummy_run.
# Placed before the MIS / DLLM / common dispatch so all three LM-head
# paths are skipped.
if _in_autotune_dummy_run:
return LogitsProcessorOutput(next_token_logits=None)

# Multi-item scoring only for prefill-only requests with pre-computed indices.
if multi_item_delimiter_indices is not None and logits_metadata.is_prefill_only:
return self.compute_logprobs_for_multi_item_scoring(
Expand Down
4 changes: 3 additions & 1 deletion python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2258,13 +2258,15 @@ def _flashinfer_autotune(self):
"""Run flashinfer autotune."""
from flashinfer.autotuner import autotune

from sglang.srt.layers.logits_processor import autotune_dummy_run_mode

logger.info("Running FlashInfer autotune...")

# Run warmup on the non-default stream to avoid NCCL 2.29+ cudaMemcpyBatchAsync
# calls on default stream (unsupported by CUDA) when --enable-symm-mem is used.
self.forward_stream.wait_stream(torch.cuda.current_stream())
with torch.get_device_module(self.device).stream(self.forward_stream):
with torch.inference_mode(), autotune():
with torch.inference_mode(), autotune(), autotune_dummy_run_mode():
self._dummy_run(
batch_size=self.req_to_token_pool.size, run_ctx=autotune()
)
Expand Down
Loading