From b0da694c703a589a7b3e81276711c1c3b663768a Mon Sep 17 00:00:00 2001 From: "luoyuan.luo" Date: Sat, 27 Sep 2025 14:53:19 +0800 Subject: [PATCH] Support torch compile based pass manager --- .../sglang/compilation/collective_fusion.py | 385 ++++++++++++++++++ python/sglang/compilation/inductor_pass.py | 125 ++++++ python/sglang/compilation/pass_manager.py | 105 +++++ .../compilation/sglang_inductor_pass.py | 100 +++++ .../sglang/srt/configs/compilation_config.py | 86 ++++ python/sglang/srt/managers/tp_worker.py | 11 + .../srt/model_executor/cuda_graph_runner.py | 40 +- .../sglang/srt/model_executor/model_runner.py | 25 +- python/sglang/srt/utils/common.py | 9 + test/srt/test_forward_split_prefill.py | 2 + test/srt/test_vlm_accuracy.py | 2 + 11 files changed, 886 insertions(+), 4 deletions(-) create mode 100644 python/sglang/compilation/collective_fusion.py create mode 100644 python/sglang/compilation/inductor_pass.py create mode 100644 python/sglang/compilation/pass_manager.py create mode 100644 python/sglang/compilation/sglang_inductor_pass.py create mode 100644 python/sglang/srt/configs/compilation_config.py diff --git a/python/sglang/compilation/collective_fusion.py b/python/sglang/compilation/collective_fusion.py new file mode 100644 index 000000000000..1094c6de4c89 --- /dev/null +++ b/python/sglang/compilation/collective_fusion.py @@ -0,0 +1,385 @@ +import logging +from importlib.util import find_spec +from typing import Optional + +import torch +import torch._inductor.pattern_matcher as pm +import torch.fx as fx +from torch._higher_order_ops.auto_functionalize import auto_functionalized_v2 +from torch._inductor.pattern_matcher import PatternMatcherPass +from torch.distributed._symmetric_memory import enable_symm_mem_for_group + +from sglang.srt.configs.compilation_config import CompilationConfig +from sglang.srt.configs.device_config import DeviceConfig +from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.distributed import ( + 
get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + get_tp_group, + tensor_model_parallel_all_reduce, +) +from sglang.srt.utils import ( + direct_register_custom_op, + is_cpu, + is_cuda, + is_flashinfer_available, + is_hip, + is_npu, + is_xpu, +) + +_is_cuda = is_cuda() +_is_flashinfer_available = is_flashinfer_available() + + +if _is_cuda: + if _is_flashinfer_available: + from flashinfer.norm import fused_add_rmsnorm + else: + from sgl_kernel import fused_add_rmsnorm + +from .inductor_pass import enable_fake_mode +from .sglang_inductor_pass import SglangInductorPass, SglangPatternMatcherPass + +if find_spec("flashinfer"): + try: + import flashinfer.comm as flashinfer_comm + + flashinfer_comm = ( + flashinfer_comm + if hasattr(flashinfer_comm, "trtllm_allreduce_fusion") + else None + ) + except ImportError: + flashinfer_comm = None +else: + flashinfer_comm = None + +logger = logging.getLogger(__name__) + + +class BasePattern: + + def __init__(self, dtype: torch.dtype, device: str): + self.dtype = dtype + self.device = device + self.tp = get_tp_group() + self.tp_size = get_tensor_model_parallel_world_size() + + +if flashinfer_comm is not None: + _FI_WORKSPACE_TENSOR = None + + MiB = 1024 * 1024 + # Max size of the input tensor per world size + # to use flashinfer fused allreduce + _FI_MAX_SIZES = { + 2: 64 * MiB, # 64MB + 4: MiB, # 1MB + 6: MiB // 2, # 512KB + 8: MiB // 2, # 512KB + } + + # opt for a more conservative default value + # when world size is not in _FI_MAX_SIZES + _DEFAULT_FI_MAX_SIZE = MiB // 2 + + def call_trtllm_fused_allreduce_norm( + allreduce_in: torch.Tensor, + residual: torch.Tensor, + rms_gamma: torch.Tensor, + rms_eps: float, + world_rank: int, + world_size: int, + launch_with_pdl: bool, + trigger_completion_at_end: bool, + fp32_acc: bool, + max_token_num: int, + pattern_code: int, + fuse_rms_quant: bool, + norm_out: Optional[torch.Tensor] = None, + quant_out: Optional[torch.Tensor] = None, + scale_out: 
Optional[torch.Tensor] = None, + scale_factor: Optional[torch.Tensor] = None, + ) -> None: + num_tokens, hidden_size = allreduce_in.shape + element_size = allreduce_in.element_size() + current_tensor_size = num_tokens * hidden_size * element_size + max_fusion_size = max_token_num * hidden_size * element_size + use_flashinfer = current_tensor_size <= min( + _FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE), + max_fusion_size, + ) + if use_flashinfer: + assert ( + _FI_WORKSPACE_TENSOR is not None + ), "Flashinfer must be enabled when using flashinfer" + if norm_out is None: + norm_out = allreduce_in + residual_out = residual + else: + # return residual_out as allreduce_out with zeroed residual_in + # as flashinfer does not support rms_norm + # and allreduce_out together + residual_out = allreduce_in + # For the sizes that are smaller than the max size, + # we only use flashinfer one shot allreduce + flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=allreduce_in, + token_num=allreduce_in.shape[0], + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + world_rank=world_rank, + world_size=world_size, + hidden_dim=allreduce_in.shape[-1], + workspace_ptrs=_FI_WORKSPACE_TENSOR, + launch_with_pdl=launch_with_pdl, + use_oneshot=True, + trigger_completion_at_end=trigger_completion_at_end, + fp32_acc=fp32_acc, + pattern_code=pattern_code, + allreduce_out=None, + quant_out=quant_out, + scale_out=scale_out, + # in sglang we only support swizzled layout + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, + scale_factor=scale_factor, + ) + else: + # TODO + pass + + def call_trtllm_fused_allreduce_norm_fake( + allreduce_in: torch.Tensor, + residual: torch.Tensor, + rms_gamma: torch.Tensor, + rms_eps: float, + world_rank: int, + world_size: int, + launch_with_pdl: bool, + trigger_completion_at_end: bool, + fp32_acc: bool, + max_token_num: int, + pattern_code: int, + fuse_rms_quant: bool, + 
norm_out: Optional[torch.Tensor] = None, + quant_out: Optional[torch.Tensor] = None, + scale_out: Optional[torch.Tensor] = None, + scale_factor: Optional[torch.Tensor] = None, + ) -> None: + pass + + direct_register_custom_op( + op_name="flashinfer_allreduce_residual_rmsnorm", + op_func=call_trtllm_fused_allreduce_norm, + mutates_args=[ + "allreduce_in", + "residual", + "norm_out", + "quant_out", + "scale_out", + ], + fake_impl=call_trtllm_fused_allreduce_norm_fake, + ) + flashinfer_allreduce_residual_rmsnorm = ( + torch.ops.sglang.flashinfer_allreduce_residual_rmsnorm.default + ) + + +class FlashInferFusedAllReduceParams: + """Parameters for FlashInfer fused allreduce operations.""" + + def __init__( + self, + rank: int, + world_size: int, + use_fp32_lamport: bool = False, + max_token_num: int = 1024, + fuse_rms_quant: bool = False, + ): + self.rank = rank + self.world_size = world_size + self.use_fp32_lamport = use_fp32_lamport + self.trigger_completion_at_end = True + self.launch_with_pdl = True + self.fp32_acc = True + self.use_oneshot = False + self.max_token_num = max_token_num + self.fuse_rms_quant = fuse_rms_quant + + def get_trtllm_fused_allreduce_kwargs(self): + return { + "world_rank": self.rank, + "world_size": self.world_size, + "launch_with_pdl": self.launch_with_pdl, + "trigger_completion_at_end": self.trigger_completion_at_end, + "fp32_acc": self.fp32_acc, + "max_token_num": self.max_token_num, + "fuse_rms_quant": self.fuse_rms_quant, + } + + +class AllReduceFusedAddRMSNormPattern(BasePattern): + """ + This pattern replaces the allreduce + rms norm (with residual) + with fused flashinfer implementation. + Applies to o_proj + rmsnorm after attn and mlp + rmsnorm before attn. 
+    """
+
+    def __init__(
+        self,
+        epsilon: float,
+        dtype: torch.dtype,
+        device: str,
+        allreduce_params: FlashInferFusedAllReduceParams,
+    ):
+        super().__init__(dtype, device)
+        self.epsilon = epsilon
+        self.allreduce_params = allreduce_params
+
+    def get_inputs(self):
+        input = torch.empty([4, 4], device=self.device, dtype=self.dtype)
+        residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
+        weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
+        return [
+            residual,
+            input,
+            weight,
+        ]
+
+    def register(self, pm_pass: PatternMatcherPass):
+        """Register the allreduce + fused_add_rmsnorm rewrite on ``pm_pass``."""
+        def pattern(residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor):
+            allreduce_output = tensor_model_parallel_all_reduce(input)
+            rms = auto_functionalized_v2(
+                torch.ops.sgl_kernel.fused_add_rmsnorm.default,
+                input=allreduce_output,
+                residual=residual,
+                weight=weight,
+                epsilon=self.epsilon,
+            )
+            # input, residual
+            return rms[1], rms[2]
+
+        def replacement(
+            residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
+        ):
+            allreduce = auto_functionalized_v2(
+                flashinfer_allreduce_residual_rmsnorm,  # custom op registered in this module
+                allreduce_in=input,
+                residual=residual,
+                norm_out=None,
+                quant_out=None,
+                scale_out=None,
+                rms_gamma=weight,
+                rms_eps=self.epsilon,
+                pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
+                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
+            )
+            # mutated args in order: allreduce_in (norm output when norm_out is None), residual
+            return allreduce[1], allreduce[2]
+
+        pm.register_replacement(
+            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
+        )
+
+
+class AllReduceFusionPass(SglangPatternMatcherPass):
+    # TODO(yuan-luo): replace with SglangConfig
+    def __init__(
+        self,
+        compilation_config: CompilationConfig,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+    ):
+        super().__init__(compilation_config, model_config, device_config)
+        self.disabled = True
+        self.tp_size = get_tensor_model_parallel_world_size()
+        if self.tp_size <= 1:
+            return
+        self.patterns: 
PatternMatcherPass = PatternMatcherPass( + pass_name="all_reduce_fusion_pass" + ) + if model_config is None: + return + self.hidden_dim = model_config.get_hidden_size() + self.group = get_tp_group().device_group + rank = get_tensor_model_parallel_rank() + use_fp32_lamport = self.model_dtype == torch.float32 + if flashinfer_comm is None: + logger.warning( + "Flashinfer is not installed or comm module not found, " + "skipping allreduce fusion pass" + ) + return + # Check if the world size is supported + if self.tp_size not in _FI_MAX_SIZES: + logger.warning( + "Flashinfer allreduce fusion is not " "supported for world size %s", + self.tp_size, + ) + return + max_num_token = min( + _FI_MAX_SIZES.get(self.tp_size, _DEFAULT_FI_MAX_SIZE) + // (self.hidden_dim * self.tp_size * (4 if use_fp32_lamport else 2)), + compilation_config.pass_config.fi_allreduce_fusion_max_token_num, + ) + self.ipc_handles, workspace_tensor = ( + flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( + tp_rank=rank, + tp_size=self.tp_size, + max_token_num=max_num_token, + hidden_dim=self.hidden_dim, + group=self.group, + use_fp32_lamport=use_fp32_lamport, + ) + ) + + global _FI_WORKSPACE_TENSOR + _FI_WORKSPACE_TENSOR = workspace_tensor + self.allreduce_params = FlashInferFusedAllReduceParams( + rank=rank, + world_size=self.tp_size, + use_fp32_lamport=use_fp32_lamport, + max_token_num=max_num_token, + # fuse rms norm static fp8 quant fused op + # in fallback path, when we don't use flashinfer + fuse_rms_quant=compilation_config.pass_config.enable_fusion, + ) + + self.register_patterns() + + @enable_fake_mode + def register_patterns(self): + for epsilon in [1e-5, 1e-6]: + AllReduceFusedAddRMSNormPattern( + epsilon, + self.model_dtype, + self.device, + self.allreduce_params, + ).register(self.patterns) + + # WARNING: This is a hack to clear the pattern matcher cache + # and allow multiple values of epsilon. 
+ torch._inductor.pattern_matcher._seen_patterns.clear() + + self.disabled = False + + def __call__(self, graph: fx.Graph): + if self.disabled: + logger.debug("AllReduceFusionPass disabled") + return + + self.matched_count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", self.matched_count) + + def __del__(self): + if getattr(self, "disabled", True): + return + if flashinfer_comm is not None: + flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce( + self.ipc_handles, self.group + ) diff --git a/python/sglang/compilation/inductor_pass.py b/python/sglang/compilation/inductor_pass.py new file mode 100644 index 000000000000..23a47eb5acb3 --- /dev/null +++ b/python/sglang/compilation/inductor_pass.py @@ -0,0 +1,125 @@ +# Adapted from vLLM Compilation framework. + +import functools +import hashlib +import inspect +import json +import types +from contextlib import contextmanager +from typing import Any, Callable, Optional, Union + +import torch +from torch import fx +from torch._inductor.custom_graph_pass import CustomGraphPass +from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily + +_pass_context = None + + +class PassContext: + + def __init__(self, runtime_shape: Optional[int]): + self.runtime_shape = runtime_shape + + +def get_pass_context() -> PassContext: + """Get the current pass context.""" + assert _pass_context is not None + return _pass_context + + +@contextmanager +def pass_context(runtime_shape: Optional[int]): + """A context manager that stores the current pass context, + usually it is a list of sizes to specialize. + """ + global _pass_context + prev_context = _pass_context + _pass_context = PassContext(runtime_shape) + try: + yield + finally: + _pass_context = prev_context + + +class InductorPass(CustomGraphPass): + """ + A custom graph pass that uses a hash of its source as the UUID. + This is defined as a convenience and should work in most cases. 
+ """ + + def uuid(self) -> Any: + """ + Provide a unique identifier for the pass, used in Inductor code cache. + This should depend on the pass implementation, so that changes to the + pass result in recompilation. + By default, the object source is hashed. + """ + return InductorPass.hash_source(self) + + @staticmethod + def hash_source(*srcs: Union[str, Any]): + """ + Utility method to hash the sources of functions or objects. + :param srcs: strings or objects to add to the hash. + Objects and functions have their source inspected. + :return: + """ + hasher = hashlib.sha256() + for src in srcs: + if isinstance(src, str): + src_str = src + elif isinstance(src, (types.FunctionType, type)): + src_str = inspect.getsource(src) + else: + # object instance + src_str = inspect.getsource(src.__class__) + hasher.update(src_str.encode("utf-8")) + return hasher.hexdigest() + + @staticmethod + def hash_dict(dict_: dict[Any, Any]): + """ + Utility method to hash a dictionary, can alternatively be used for uuid. + :return: A sha256 hash of the json rep of the dictionary. + """ + encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") + return hashlib.sha256(encoded).hexdigest() + + def is_applicable_for_shape(self, shape: Optional[int]): + return True + + +class CallableInductorPass(InductorPass): + """ + This class is a wrapper for a callable that automatically provides an + implementation of the UUID. + """ + + def __init__( + self, callable: Callable[[fx.Graph], None], uuid: Optional[Any] = None + ): + self.callable = callable + self._uuid = self.hash_source(callable) if uuid is None else uuid + + def __call__(self, graph: torch.fx.Graph): + self.callable(graph) + + def uuid(self) -> Any: + return self._uuid + + +def enable_fake_mode(fn: Callable[..., Any]) -> Callable[..., Any]: + """ + Applies a FakeTensorMode context. This is useful when you don't want to + create or run things with real tensors. 
+ """ + + @functools.wraps(fn) + def fn_new(*args, **kwargs) -> Any: + with torch._guards.tracing(None), unset_fake_temporarily(), FakeTensorMode(): + result = fn(*args, **kwargs) + + return result + + return fn_new diff --git a/python/sglang/compilation/pass_manager.py b/python/sglang/compilation/pass_manager.py new file mode 100644 index 000000000000..6f8e348f080f --- /dev/null +++ b/python/sglang/compilation/pass_manager.py @@ -0,0 +1,105 @@ +# This design is borrowed from https://blog.vllm.ai/2025/08/20/torch-compile.html + +import logging + +from torch import fx + +from sglang.srt.configs.compilation_config import CompilationConfig +from sglang.srt.configs.device_config import DeviceConfig +from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.utils import ( + cpu_has_amx_support, + get_bool_env_var, + is_cpu, + is_cuda, + is_flashinfer_available, + is_hip, + is_npu, + is_xpu, + supports_custom_op, +) + +from .sglang_inductor_pass import SglangInductorPass + +_is_cuda = is_cuda() + +if _is_cuda: + from .collective_fusion import AllReduceFusionPass + +from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context + +logger = logging.getLogger(__name__) + + +class PostGradPassManager(CustomGraphPass): + """ + The pass manager for post-grad passes. + It handles configuration, adding custom passes, and running passes. + It supports uuid for the Inductor code cache. That includes torch<2.6 + support using pickling (in .inductor_pass.CustomGraphPass). + + The order of the post-grad post-passes is: + 1. passes (constructor parameter) + 2. default passes (NoopEliminationPass, FusionPass) + 3. config["post_grad_custom_post_pass"] (if it exists) + 4. fix_functionalization + This way, all passes operate on a functionalized graph. 
+    """
+
+    def __init__(
+        self,
+        compilation_config: CompilationConfig,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+    ):
+        self.passes: list[InductorPass] = []
+        self.configure(compilation_config, model_config, device_config)
+
+    def __call__(self, graph: fx.Graph):
+        SglangInductorPass.dump_prefix = 0  # reset dump index
+
+        # Inductor may invoke the manager outside a pass_context() block, in
+        # which case get_pass_context() fails; fall back to "no known shape".
+        try:
+            shape = get_pass_context().runtime_shape
+        except (AssertionError, AttributeError):
+            shape = None
+        for pass_ in self.passes:
+            if pass_.is_applicable_for_shape(shape):
+                pass_(graph)
+                SglangInductorPass.dump_prefix += 1
+
+        # post_cleanup / fix_functionalization are not ported to SGLang yet.
+        SglangInductorPass.dump_prefix = None  # Cleanup index
+
+    # TODO: wrap three configs into a SglangConfig
+    def configure(
+        self,
+        compilation_config: CompilationConfig,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+    ):
+        self.pass_config = compilation_config.pass_config
+        # AllReduceFusionPass is only imported when running on CUDA.
+        if _is_cuda and self.pass_config.enable_fi_allreduce_fusion:
+            self.passes += [
+                AllReduceFusionPass(compilation_config, model_config, device_config)
+            ]
+
+        # TODO: add more pass for fusion and tp.
+
+    def add(self, pass_: InductorPass):
+        assert isinstance(pass_, InductorPass)
+        self.passes.append(pass_)
+
+    def uuid(self):
+        """
+        The PostGradPassManager is set as a custom pass in the Inductor and
+        affects compilation caching. Its uuid depends on the UUIDs of all
+        dependent passes and the pass config. See InductorPass for more info.
+        """
+        pass_uuids = [pass_.uuid() for pass_ in self.passes]
+        # NOTE: no fix_functionalization pass exists in SGLang yet, so only
+        # the registered passes and the pass config feed the hash.
+        state = {"pass_config": self.pass_config.uuid(), "passes": pass_uuids}
+        return InductorPass.hash_dict(state)
diff --git a/python/sglang/compilation/sglang_inductor_pass.py b/python/sglang/compilation/sglang_inductor_pass.py
new file mode 100644
index 000000000000..eafbbb32ea4b
--- /dev/null
+++ b/python/sglang/compilation/sglang_inductor_pass.py
@@ -0,0 +1,100 @@
+import functools
+import logging
+import operator
+import time
+from pathlib import Path
+from typing import ClassVar, Optional
+
+import regex as re
+import torch
+from torch._dynamo.utils import lazy_format_graph_code
+from torch._inductor.pattern_matcher import PatternMatcherPass, PatternPrettyPrinter
+
+from sglang.srt.configs.compilation_config import CompilationConfig
+from sglang.srt.configs.device_config import DeviceConfig
+from sglang.srt.configs.model_config import ModelConfig
+
+from .inductor_pass import InductorPass
+
+logger = logging.getLogger(__name__)
+
+
+class SglangInductorPass(InductorPass):
+    """
+    An inductor pass with access to SGLang PassConfig.
+    It provides timing, logging, and dumping utilities.
+ """ + + dump_prefix: ClassVar[Optional[int]] = None + + def __init__( + self, + compilation_config: CompilationConfig, + model_config: ModelConfig, + device_config: DeviceConfig, + ): + self.pass_config = compilation_config.pass_config + self.model_dtype = model_config.dtype if model_config else None + self.device = device_config.device if device_config else None + self.pass_name = self.__class__.__name__ + + @staticmethod + def time_and_log(call_fn): + + @functools.wraps(call_fn) + def wrapped(self: SglangInductorPass, graph: torch.fx.Graph): + self.begin() + self.dump_graph(graph, "before") + call_fn(self, graph) + self.dump_graph(graph, "after") + self.end_and_log() + + return wrapped + + def dump_graph(self, graph: torch.fx.Graph, stage: str): + i = SglangInductorPass.dump_prefix + i_str = "" if i is None else f".{i}" + lazy_format_graph_code( + f"post_grad{i_str}.{self.pass_name}.{stage}", graph.owning_module + ) + + def begin(self): + self._start_time = time.perf_counter_ns() + + def end_and_log(self): + self._end_time = time.perf_counter_ns() + duration_ms = float(self._end_time - self._start_time) / 1.0e6 + logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms) + + +class SglangPatternMatcherPass(SglangInductorPass): + """ + A SglangInductorPass that uses the Inductor pattern matcher. + Its main use is providing the dump_patterns utility that dumps the + Inductor pattern matcher patterns into a file, which greatly aids debugging. 
+
+    """
+
+    matched_count: int = 0  # The number of matched patterns in the pass.
+
+    # Matches the repr of an OpOverload; \x3c / \x3e are the angle brackets.
+    _OP_OVERLOAD_PATTERN: ClassVar[re.Pattern] = re.compile(
+        r"\x3cOpOverload\(op='([^']*)', overload='([^']*)'\)\x3e"
+    )
+
+    def _replace_op_overloads(self, string: str) -> str:
+        """Replace with nicer formulations"""
+        return self._OP_OVERLOAD_PATTERN.sub(
+            lambda m: f"torch.ops.{m.group(1)}.{m.group(2)}",
+            string,
+        )
+
+    def dump_patterns(
+        self,
+        compilation_config: CompilationConfig,
+        model_config: ModelConfig,
+        pm_pass: PatternMatcherPass,
+    ):
+        """
+        TODO(yuan-luo): use pattern object to manually produce pattern graph
+        """
diff --git a/python/sglang/srt/configs/compilation_config.py b/python/sglang/srt/configs/compilation_config.py
new file mode 100644
index 000000000000..caf963597cde
--- /dev/null
+++ b/python/sglang/srt/configs/compilation_config.py
@@ -0,0 +1,87 @@
+import logging
+from dataclasses import asdict, dataclass, field
+from typing import Optional
+
+import torch
+
+from sglang.compilation.inductor_pass import InductorPass
+from sglang.srt.server_args import ServerArgs
+
+logger = logging.getLogger(__name__)
+
+
+class CompilationLevel:
+    # constants for the levels of the compilation process
+    NO_COMPILATION = 0
+    DYNAMO_AS_IS = 1
+    DYNAMO_ONCE = 2
+    PIECEWISE = 3
+
+
+@dataclass
+class PassConfig:
+    """Configuration for custom Inductor passes.
+
+    This is separate from general `CompilationConfig` so that inductor passes
+    don't all have access to full configuration - that would create a cycle as
+    the `PassManager` is set as a property of config."""
+
+    # Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass.
+    enable_fusion: bool = True
+
+    # Whether to enable flashinfer allreduce fusion.
+    enable_fi_allreduce_fusion: bool = True
+
+    # Max number of tokens to use in flashinfer allreduce fusion.
+    fi_allreduce_fusion_max_token_num: int = 16384
+
+    def uuid(self):
+        """
+        Produces a hash unique to the pass configuration.
+        Any new fields that affect compilation should be added to the hash.
+        Any future fields that don't affect compilation should be excluded.
+        """
+        return InductorPass.hash_dict(asdict(self))
+
+
+@dataclass
+class CompilationConfig:
+    """Configuration for compilation. It has three parts:
+
+    - Top-level Compilation control:
+        -[''] TODO
+    - CudaGraph capture:
+        -[''] TODO
+    - Inductor compilation:
+        -[''] TODO
+    Why we have different sizes for cudagraph and inductor:
+    - cudagraph: a cudagraph captured for a specific size can only be used
+        for the same size. We need to capture all the sizes we want to use.
+    - inductor: a graph compiled by inductor for a general shape can be used
+        for different sizes. Inductor can also compile for specific sizes,
+        where it can have more information to optimize the graph with fully
+        static shapes. However, we find the general shape compilation is
+        sufficient for most cases. It might be beneficial to compile for
+        certain small batchsizes, where inductor is good at optimizing.
+    """
+
+    # Top-level Compilation control
+    level: Optional[int] = None
+
+    custom_ops: list[str] = field(default_factory=list)
+
+    # Inductor capture
+    use_inductor: bool = True
+
+    inductor_compile_config: dict = field(default_factory=dict)
+
+    inductor_passes: dict[str, str] = field(default_factory=dict)
+
+    pass_config: PassConfig = field(default_factory=PassConfig)
+
+    @staticmethod
+    def from_server_args(server_args: ServerArgs):
+        # TODO(yuan-luo): Complete the Config.
+ return CompilationConfig( + level=CompilationLevel.DYNAMO_AS_IS, + ) diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 051df74d7247..7be79c55ff6e 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -19,6 +19,8 @@ import torch +from sglang.srt.configs.compilation_config import CompilationConfig +from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed import get_pp_group, get_world_group from sglang.srt.managers.io_struct import ( @@ -92,8 +94,17 @@ def __init__( is_draft_model=is_draft_worker, ) + # TODO(yuan-luo) hard-code compilation_config and device_config for POC, + # will refine in the former version. + self.compilation_config = CompilationConfig.from_server_args( + server_args, + ) + self.device_config = DeviceConfig("cuda") + self.model_runner = ModelRunner( model_config=self.model_config, + compilation_config=self.compilation_config, + device_config=self.device_config, mem_fraction_static=server_args.mem_fraction_static, gpu_id=gpu_id, tp_rank=tp_rank, diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 8bfb077f9c69..f57332759b5a 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -118,6 +118,23 @@ def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int): _to_torch(sub, reverse, num_tokens) +def _torch_compile_wrapper(forward): + return torch.compile( + torch.no_grad()(forward), + mode=os.environ.get("SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"), + dynamic=False, + ) + + +def torch_compile( + model: torch.nn.Module, compilation_config, model_config, device_config +): + set_torch_compile_config(compilation_config, model_config, device_config) + _to_torch(model, reverse=False, num_tokens=1) + 
model.forward = _torch_compile_wrapper(model.forward)
+    _to_torch(model, reverse=True, num_tokens=1)
+
+
 @contextmanager
 def patch_model(
     model: torch.nn.Module,
@@ -151,7 +168,11 @@ def patch_model(
             tp_group.ca_comm = backup_ca_comm
 
 
-def set_torch_compile_config():
+def set_torch_compile_config(
+    compilation_config: CompilationConfig,
+    model_config: ModelConfig,
+    device_config: DeviceConfig,
+):
     import torch._dynamo.config
     import torch._inductor.config
 
@@ -159,6 +180,17 @@ def set_torch_compile_config():
     torch._inductor.config.triton.unique_kernel_names = True
     torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
 
+    # TODO: Add server_args enable_post_grad_pass
+    if True:
+        from sglang.compilation.pass_manager import PostGradPassManager
+
+        pass_manager = PostGradPassManager(
+            compilation_config, model_config, device_config
+        )
+
+        # TODO(yuan-luo): handling torch._inductor.compile_fx
+        torch._inductor.config.post_grad_custom_post_pass = pass_manager
+
     # FIXME: tmp workaround
     torch._dynamo.config.accumulated_cache_size_limit = 1024
     if hasattr(torch._dynamo.config, "cache_size_limit"):
@@ -281,7 +313,11 @@ def __init__(self, model_runner: ModelRunner):
         )
 
         if self.enable_torch_compile:
-            set_torch_compile_config()
+            set_torch_compile_config(
+                self.model_runner.compilation_config,
+                self.model_runner.model_config,
+                self.model_runner.device_config,
+            )
 
         if self.model_runner.server_args.enable_lora:
             self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs)
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index e92fe4250f60..3dc990c94fe4 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -29,6 +29,7 @@
 import torch
 import torch.distributed as dist
 
+from sglang.srt.configs.compilation_config import CompilationConfig, CompilationLevel
 from 
sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.load_config import LoadConfig, LoadFormat from sglang.srt.configs.model_config import ( @@ -106,7 +107,7 @@ SWAKVPool, ) from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner -from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner +from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner, torch_compile from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner from sglang.srt.model_loader import get_model @@ -146,6 +147,7 @@ monkey_patch_vllm_gguf_config, set_cuda_arch, slow_rank_detector, + supports_dynamo, ) from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions from sglang.srt.weight_sync.tensor_bucket import ( @@ -212,6 +214,8 @@ class ModelRunner: def __init__( self, model_config: ModelConfig, + compilation_config: CompilationConfig, + device_config: DeviceConfig, mem_fraction_static: float, gpu_id: int, tp_rank: int, @@ -239,6 +243,8 @@ def __init__( self.pp_rank = pp_rank self.pp_size = pp_size self.model_config = model_config + self.compilation_config = compilation_config + self.device_config = device_config self.dist_port = nccl_port self.server_args = server_args self.is_draft_worker = is_draft_worker @@ -789,7 +795,7 @@ def load_model(self): self.model = get_model( model_config=self.model_config, load_config=self.load_config, - device_config=DeviceConfig(self.device, self.gpu_id), + device_config=self.device_config, ) monkey_patch_vllm_parallel_state(reverse=True) monkey_patch_isinstance_for_vllm_base_layer(reverse=True) @@ -853,6 +859,14 @@ def load_model(self): f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node." 
) from None + if ( + self.compilation_config.level == CompilationLevel.DYNAMO_AS_IS + and supports_dynamo() + ): + # TODO(yuan-luo): compile model in fullgraph + # self.model.compile(fullgraph=True, backend=backend) + return + def update_expert_location( self, new_expert_location_metadata: ExpertLocationMetadata, @@ -1825,6 +1839,13 @@ def init_device_graphs(self): return if self.device != "cpu" and self.server_args.disable_cuda_graph: + if self.server_args.enable_torch_compile: + torch_compile( + self.model, + self.compilation_config, + self.model_config, + self.device_config, + ) return if self.device == "cpu" and not self.server_args.enable_torch_compile: diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 0ab2783c3597..b3c945049ed5 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -80,6 +80,7 @@ import zmq from fastapi.responses import ORJSONResponse from packaging import version as pkg_version +from packaging.version import Version from PIL import Image from starlette.routing import Mount from torch import nn @@ -2213,6 +2214,14 @@ def set_cuda_arch(): os.environ["TORCH_CUDA_ARCH_LIST"] = f"{arch}{'+PTX' if arch == '9.0' else ''}" +# Using dynamo with SGLang doesn't really work well with PyTorch versions < 2.4.0. +# In particular, the FakeScalarType is not supported for earlier versions of +# PyTorch which breaks dynamo for any ops registered using ScalarType. 
+def supports_dynamo() -> bool: + base_torch_version = Version(Version(torch.__version__).base_version) + return base_torch_version >= Version("2.4.0") + + def next_power_of_2(n: int): return 1 << (n - 1).bit_length() if n > 0 else 1 diff --git a/test/srt/test_forward_split_prefill.py b/test/srt/test_forward_split_prefill.py index 314e35ec9724..b8bdc9db9c83 100644 --- a/test/srt/test_forward_split_prefill.py +++ b/test/srt/test_forward_split_prefill.py @@ -12,6 +12,7 @@ import numpy as np import torch +from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch @@ -50,6 +51,7 @@ def setUpClass(cls): # Load model and tokenizer cls.model_config = ModelConfig.from_server_args(cls.server_args) + cls.device_config = DeviceConfig("cuda") cls.model_runner = ModelRunner( model_config=cls.model_config, mem_fraction_static=cls.server_args.mem_fraction_static, diff --git a/test/srt/test_vlm_accuracy.py b/test/srt/test_vlm_accuracy.py index ef9a2ad51b09..f2fcdc4c5b67 100644 --- a/test/srt/test_vlm_accuracy.py +++ b/test/srt/test_vlm_accuracy.py @@ -12,6 +12,7 @@ from PIL import Image from transformers import AutoModel, AutoProcessor, AutoTokenizer +from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.model_config import ModelConfig from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest from sglang.srt.managers.mm_utils import embed_mm_inputs, init_embedding_cache @@ -146,6 +147,7 @@ def get_processor_output(self, req: Optional[ChatCompletionRequest] = None): def get_sglang_model(self): self.model_runner = ModelRunner( model_config=ModelConfig(self.model_path, model_override_args="{}"), + device_config=DeviceConfig("cuda"), mem_fraction_static=0.8, gpu_id=0, tp_rank=0,