diff --git a/python/sglang/srt/compilation/backend.py b/python/sglang/srt/compilation/backend.py
index 38f3479992c6..abeed1c831d2 100644
--- a/python/sglang/srt/compilation/backend.py
+++ b/python/sglang/srt/compilation/backend.py
@@ -20,6 +20,7 @@
 from sglang.srt.compilation.compiler_interface import EagerAdapter, InductorAdaptor
 from sglang.srt.compilation.cuda_piecewise_backend import CUDAPiecewiseBackend
 from sglang.srt.compilation.pass_manager import PostGradPassManager
+from sglang.srt.utils.common import rank0_log
 
 logger = logging.getLogger(__name__)
 
@@ -357,6 +358,7 @@ def __init__(
         config: CompilationConfig,
         graph_pool: Any,
     ):
+        rank0_log("Initializing SGLangBackend")
         assert graph_pool is not None
         self.graph_pool = graph_pool
 
@@ -375,6 +377,7 @@ def configure_post_pass(self):
         self.inductor_config["post_grad_custom_post_pass"] = self.post_grad_pass_manager
 
     def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
+        rank0_log("SGLangBackend __call__")
         base_cache_dir = os.path.expanduser(
             os.getenv("SGLANG_CACHE_DIR", "~/.cache/sglang/")
         )
@@ -441,7 +444,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
         with open(graph_path, "w") as f:
             f.write(src)
 
-        logger.debug("Computation graph saved to %s", graph_path)
+        rank0_log(f"Computation graph saved to {graph_path}")
 
         self._called = True
         return self.split_gm
diff --git a/python/sglang/srt/compilation/compile.py b/python/sglang/srt/compilation/compile.py
index e0ec9a1f327e..b9ff7f6bdb93 100644
--- a/python/sglang/srt/compilation/compile.py
+++ b/python/sglang/srt/compilation/compile.py
@@ -11,6 +11,7 @@
 import torch
 
 from sglang.srt.compilation.compilation_config import CompilationConfig
+from sglang.srt.utils.common import rank0_log
 
 logger = logging.getLogger(__name__)
 
@@ -129,6 +130,7 @@ def install_torch_compiled(
     fullgraph: bool = True,
     graph_pool: Any = None,
 ):
+    rank0_log("install_torch_compiled")
     unbound_fwd = module.__class__.forward
     if not callable(unbound_fwd):
         raise TypeError("module.__class__.forward must be callable")
diff --git a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py
index 52f50aad381a..95f486abe5c3 100644
--- a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py
@@ -59,6 +59,27 @@
     from sglang.srt.model_executor.model_runner import ModelRunner
 
 
+@contextmanager
+def disable_ca_comm(tp_group):
+    """
+    Context manager to temporarily disable custom allreduce communication.
+
+    This is used during piecewise CUDA graph capture to avoid custom allreduce
+    operations that may not be compatible with graph capture.
+
+    TODO(yuwei): Fix this
+    """
+    old_disabled = None
+    try:
+        if tp_group.ca_comm is not None:
+            old_disabled = tp_group.ca_comm.disabled
+            tp_group.ca_comm.disabled = True
+        yield
+    finally:
+        if tp_group.ca_comm is not None and old_disabled is not None:
+            tp_group.ca_comm.disabled = old_disabled
+
+
 @contextmanager
 def freeze_gc(enable_cudagraph_gc: bool):
     """
@@ -207,7 +228,7 @@ def __init__(self, model_runner: ModelRunner):
         )
 
         with set_compiled(True):
-            self.warmup_and_capture()
+            self.warmup_torch_compile()
 
         # Capture
         try:
@@ -219,7 +240,8 @@ def __init__(self, model_runner: ModelRunner):
 
         self.raw_num_tokens = 0
 
-    def warmup_and_capture(self):
+    def warmup_torch_compile(self):
+        """Warm up the model with a simple forward pass before CUDA graph capture."""
         num_tokens = 2
         with torch.device(self.device):
             forward_batch = ForwardBatch(
@@ -283,7 +305,7 @@ def warmup_and_capture(self):
 
         with set_forward_context(
             forward_batch, self.attention_layers, self.quant_config
-        ):
+        ), disable_ca_comm(self.model_runner.tp_group):
             _ = self.model_runner.model.forward(
                 forward_batch.input_ids,
                 forward_batch.positions,
@@ -311,10 +333,9 @@ def capture(self) -> None:
         # Trigger CUDA graph capture for specific shapes.
         # Capture the large shapes first so that the smaller shapes
         # can reuse the memory pool allocated for the large shapes.
-        with freeze_gc(self.model_runner.server_args.enable_cudagraph_gc):
-            if self.model_runner.tp_group.ca_comm is not None:
-                old_ca_disable = self.model_runner.tp_group.ca_comm.disabled
-                self.model_runner.tp_group.ca_comm.disabled = True
+        with freeze_gc(
+            self.model_runner.server_args.enable_cudagraph_gc
+        ), disable_ca_comm(self.model_runner.tp_group):
             avail_mem = get_available_gpu_memory(
                 self.model_runner.device,
                 self.model_runner.gpu_id,
@@ -342,8 +363,6 @@ def capture(self) -> None:
                 # Save gemlite cache after each capture
                 save_gemlite_cache()
 
-            if self.model_runner.tp_group.ca_comm is not None:
-                self.model_runner.tp_group.ca_comm.disabled = old_ca_disable
 
     def capture_one_batch_size(self, num_tokens: int):
         bs = 1
@@ -565,10 +584,7 @@ def replay(
         forward_batch: ForwardBatch,
         **kwargs,
     ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
-        with enable_piecewise_cuda_graph():
-            if self.model_runner.tp_group.ca_comm is not None:
-                old_ca_disable = self.model_runner.tp_group.ca_comm.disabled
-                self.model_runner.tp_group.ca_comm.disabled = True
+        with enable_piecewise_cuda_graph(), disable_ca_comm(self.model_runner.tp_group):
             self.model_runner.attn_backend.init_forward_metadata(forward_batch)
             static_forward_batch = self.replay_prepare(forward_batch, **kwargs)
             # Replay
@@ -599,8 +615,6 @@ def replay(
                 raise NotImplementedError(
                     "PPProxyTensors is not supported in PiecewiseCudaGraphRunner yet."
                 )
-            if self.model_runner.tp_group.ca_comm is not None:
-                self.model_runner.tp_group.ca_comm.disabled = old_ca_disable
 
     def get_spec_info(self, num_tokens: int):
         spec_info = None
diff --git a/python/sglang/srt/tokenizer/tiktoken_tokenizer.py b/python/sglang/srt/tokenizer/tiktoken_tokenizer.py
index c1f2a91b0946..b29015547276 100644
--- a/python/sglang/srt/tokenizer/tiktoken_tokenizer.py
+++ b/python/sglang/srt/tokenizer/tiktoken_tokenizer.py
@@ -127,6 +127,7 @@ def apply_chat_template(
         add_generation_prompt,
         tools=None,
         reasoning_effort=None,
+        **kwargs,  # Accept additional parameters (e.g., return_dict) for compatibility
     ):
         ret = self.chat_template_jinja.render(
             messages=messages, add_generation_prompt=add_generation_prompt
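The new `disable_ca_comm` helper replaces three hand-rolled copies of save/restore logic around `ca_comm.disabled`. Below is a minimal, self-contained sketch of the pattern, assuming only that `tp_group.ca_comm` exposes a boolean `disabled` attribute; the `SimpleNamespace` objects are hypothetical stand-ins, not SGLang's real TP group or custom-allreduce classes.

```python
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def disable_ca_comm(tp_group):
    """Temporarily force tp_group.ca_comm.disabled = True, restoring the old value on exit."""
    old_disabled = None
    try:
        if tp_group.ca_comm is not None:
            old_disabled = tp_group.ca_comm.disabled
            tp_group.ca_comm.disabled = True
        yield
    finally:
        # Restore only if a communicator was actually touched.
        if tp_group.ca_comm is not None and old_disabled is not None:
            tp_group.ca_comm.disabled = old_disabled


# Hypothetical stand-in for a TP group with a custom-allreduce communicator.
tp_group = SimpleNamespace(ca_comm=SimpleNamespace(disabled=False))

with disable_ca_comm(tp_group):
    assert tp_group.ca_comm.disabled is True   # custom allreduce is off inside the block
assert tp_group.ca_comm.disabled is False      # previous state restored on exit

# A group without a custom-allreduce communicator is handled as a no-op.
with disable_ca_comm(SimpleNamespace(ca_comm=None)):
    pass
```

Stacking the helper with the other context managers (`freeze_gc(...)`, `enable_piecewise_cuda_graph()`) in a single `with` statement, as the diff does, also guarantees the restore runs on the exception path, which the previous manual save/restore did not.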
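The `**kwargs` addition to `apply_chat_template` is a call-signature compatibility change. A minimal sketch, using a hypothetical `ChatTemplateStub` rather than SGLang's real `TiktokenTokenizer`, of why accepting and ignoring extra keywords such as `return_dict` avoids a `TypeError` for callers written against a richer Hugging Face-style signature:

```python
class ChatTemplateStub:
    """Hypothetical stand-in illustrating the **kwargs compatibility pattern."""

    def apply_chat_template(
        self,
        messages,
        add_generation_prompt,
        tools=None,
        reasoning_effort=None,
        **kwargs,  # extra parameters (e.g., return_dict) are accepted and ignored
    ):
        return f"{len(messages)} message(s), add_generation_prompt={add_generation_prompt}"


stub = ChatTemplateStub()
# Passing return_dict=True no longer raises TypeError; it is simply ignored here.
print(
    stub.apply_chat_template(
        [{"role": "user", "content": "hi"}],
        add_generation_prompt=True,
        return_dict=True,
    )
)
```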