diff --git a/python/sglang/srt/compilation/backend.py b/python/sglang/srt/compilation/backend.py index 8af025707f55..c6b7510a6540 100644 --- a/python/sglang/srt/compilation/backend.py +++ b/python/sglang/srt/compilation/backend.py @@ -1,4 +1,4 @@ -# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/backend.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/backends.py import ast @@ -19,8 +19,10 @@ from sglang.srt.compilation.compilation_counter import compilation_counter from sglang.srt.compilation.compiler_interface import EagerAdapter, InductorAdaptor from sglang.srt.compilation.cuda_piecewise_backend import CUDAPiecewiseBackend +from sglang.srt.compilation.inductor_pass import InductorPass from sglang.srt.compilation.npu_piecewise_backend import NPUPiecewiseBackend from sglang.srt.compilation.pass_manager import PostGradPassManager +from sglang.srt.compilation.sglang_config import SGLangConfig from sglang.srt.utils.common import is_npu, rank0_log logger = logging.getLogger(__name__) @@ -103,21 +105,24 @@ def load( graph_index: int, runtime_shape: Optional[int] = None, ) -> Optional[Callable]: - handle = self.cache[(runtime_shape, graph_index, self.compiler.name)] + key = (runtime_shape, graph_index, self.compiler.name) + handle = self.cache.get(key, None) + if handle is None: + return None + compiled_graph = self.compiler.load( handle, graph, example_inputs, graph_index, runtime_shape ) if runtime_shape is None: logger.debug( - "Directly load the %s-th graph for dynamic shape from %s via " - "handle %s", + "Directly load the %s-th graph for dynamic shape from %s via handle %s", graph_index, self.compiler.name, handle, ) else: logger.debug( - "Directly load the %s-th graph for shape %s from %s via " "handle %s", + "Directly load the %s-th graph for shape %s from %s via handle %s", graph_index, str(runtime_shape), self.compiler.name, @@ -130,6 +135,7 @@ def compile( graph: fx.GraphModule, 
example_inputs, inductor_config: dict[str, Any], + compilation_config: CompilationConfig, graph_index: int = 0, num_graphs: int = 1, runtime_shape: Optional[int] = None, @@ -143,7 +149,29 @@ def compile( compiled_graph = None - # TODO(Yuwei): support cache loading + # try to load from the cache + compiled_graph = self.load(graph, example_inputs, graph_index, runtime_shape) + if compiled_graph is not None: + if graph_index == num_graphs - 1: + # after loading the last graph for this shape, record the time. + # there can be multiple graphs due to piecewise compilation. + now = time.time() + elapsed = now - compilation_start_time + compilation_config.compilation_time += elapsed + if runtime_shape is None: + logger.info( + "Directly load the compiled graph(s) for dynamic shape " + "from the cache, took %.3f s", + elapsed, + ) + else: + logger.info( + "Directly load the compiled graph(s) for shape %s " + "from the cache, took %.3f s", + str(runtime_shape), + elapsed, + ) + return compiled_graph # no compiler cached the graph, or the cache is disabled, # we need to compile it @@ -190,6 +218,7 @@ def compile( if graph_index == num_graphs - 1: now = time.time() elapsed = now - compilation_start_time + compilation_config.compilation_time += elapsed if runtime_shape is None: logger.info("Compiling a graph for dynamic shape takes %.2f s", elapsed) else: @@ -256,20 +285,27 @@ def split_graph( return split_gm, outputs -# we share the global graph pool among all the backends -global_graph_pool = None - compilation_start_time = 0.0 class PiecewiseCompileInterpreter(torch.fx.Interpreter): + """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`. + It runs the given graph with fake inputs, and compile some + submodules specified by `compile_submod_names` with the given + compilation configs. + + NOTE: the order in `compile_submod_names` matters, because + it will be used to determine the order of the compiled piecewise + graphs. 
The first graph will handle logging, and the last graph + has some special cudagraph output handling. + """ + def __init__( self, module: torch.fx.GraphModule, compile_submod_names: list[str], - inductor_config: dict[str, Any], graph_pool, - compile_config: CompilationConfig, + sglang_config: SGLangConfig, sglang_backend: "SGLangBackend", ): super().__init__(module) @@ -281,8 +317,8 @@ def __init__( self.sglang_backend = sglang_backend # When True, it annoyingly dumps the torch.fx.Graph on errors. self.extra_traceback = False - self.inductor_config = inductor_config - self.compile_config = compile_config + self.sglang_config = sglang_config + self.compilation_config = sglang_config.compilation_config def run(self, *args): fake_args = [ @@ -312,7 +348,8 @@ def call_module( self.sglang_backend.compiler_manager.compile( submod, args, - self.inductor_config, + self.sglang_backend.inductor_config, + self.compilation_config, graph_index=index, num_graphs=len(self.compile_submod_names), runtime_shape=None, @@ -321,8 +358,8 @@ def call_module( self.module.__dict__[target] = make_backend( submod, - self.compile_config, - self.inductor_config, + self.compilation_config, + self.sglang_backend.inductor_config, self.graph_pool, index, len(self.compile_submod_names), @@ -355,7 +392,19 @@ def set_model_tag(tag: str): class SGLangBackend: + """The compilation backend for `torch.compile` with SGLang. + It is used for compilation mode of `CompilationMode.SGLANG_COMPILE`, + where we customize the compilation. + + The major work of this backend is to split the graph into + piecewise graphs, and pass them to the piecewise backend. + This backend also adds the PostGradPassManager to Inductor config, + which handles the post-grad passes. 
+ """ + + sglang_config: SGLangConfig + compilation_config: CompilationConfig graph_pool: Any _called: bool = False # the graph we compiled @@ -372,7 +421,7 @@ class SGLangBackend: def __init__( self, - config: CompilationConfig, + sglang_config: SGLangConfig, graph_pool: Any, ): rank0_log(f"Initializing SGLangBackend") @@ -383,15 +432,31 @@ def __init__( self.sym_tensor_indices = [] self.input_buffers = [] - self.compiler_manager = CompilerManager(config) + self.sglang_config = sglang_config + self.compilation_config = sglang_config.compilation_config + + self.compiler_manager = CompilerManager(self.compilation_config) self.inductor_config = { "enable_auto_functionalized_v2": False, } - self.compile_config = config def configure_post_pass(self): - self.post_grad_pass_manager.configure() - self.inductor_config["post_grad_custom_post_pass"] = self.post_grad_pass_manager + config = self.compilation_config + self.post_grad_pass_manager.configure(self.sglang_config) + + # Post-grad custom passes are run using the post_grad_custom_post_pass + # hook. If a pass for that hook exists, add it to the pass manager. 
+ inductor_config = config.inductor_compile_config + PASS_KEY = "post_grad_custom_post_pass" + if PASS_KEY in inductor_config: + if isinstance(inductor_config[PASS_KEY], PostGradPassManager): + # PassManager already added to config + pass + else: + # Config should automatically wrap all inductor passes + assert isinstance(inductor_config[PASS_KEY], InductorPass) + self.post_grad_pass_manager.add(inductor_config[PASS_KEY]) + inductor_config[PASS_KEY] = self.post_grad_pass_manager def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: rank0_log(f"SGLangBackend __call__") @@ -423,7 +488,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: self.split_gm, self.piecewise_graphs = split_graph( graph, - self.compile_config.split_ops, + self.compilation_config.split_ops, ) from torch._dynamo.utils import lazy_format_graph_code @@ -443,9 +508,8 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: PiecewiseCompileInterpreter( self.split_gm, submod_names_to_compile, - self.inductor_config, self.graph_pool, - self.compile_config, + self.sglang_config, self, ).run(*example_inputs) diff --git a/python/sglang/srt/compilation/collective_fusion.py b/python/sglang/srt/compilation/collective_fusion.py new file mode 100644 index 000000000000..39a8c01f8404 --- /dev/null +++ b/python/sglang/srt/compilation/collective_fusion.py @@ -0,0 +1,459 @@ +import logging +from importlib.util import find_spec + +import torch +import torch._inductor.pattern_matcher as pm +import torch.fx as fx +from torch._higher_order_ops.auto_functionalize import auto_functionalized +from torch._inductor.pattern_matcher import PatternMatcherPass + +from sglang.srt.compilation.inductor_pass import enable_fake_mode +from sglang.srt.compilation.matcher_utils import MatcherFusedAddRMSNorm, MatcherRMSNorm +from sglang.srt.compilation.sglang_config import SGLangConfig +from sglang.srt.compilation.sglang_inductor_pass import SGLangPatternMatcherPass +from 
sglang.srt.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + get_tp_group, + tensor_model_parallel_all_reduce, +) +from sglang.srt.utils import direct_register_custom_op + +if find_spec("flashinfer"): + try: + import flashinfer.comm as flashinfer_comm + + flashinfer_comm = ( + flashinfer_comm + if hasattr(flashinfer_comm, "trtllm_allreduce_fusion") + else None + ) + except ImportError: + flashinfer_comm = None +else: + flashinfer_comm = None + +logger = logging.getLogger(__name__) + + +class BasePattern: + def __init__(self, dtype: torch.dtype, device: str): + self.dtype = dtype + self.device = device + self.tp = get_tp_group() + self.tp_size = get_tensor_model_parallel_world_size() + + +if flashinfer_comm is not None: + _FI_WORKSPACE_TENSOR = None + + MiB = 1024 * 1024 + # Max size of the input tensor per world size + # to use flashinfer fused allreduce + _FI_MAX_SIZES = { + 2: 64 * MiB, # 64MB + 4: 2 * MiB, # 2MB + 8: MiB // 2, # 512KB + } + + # opt for a more conservative default value + # when world size is not in _FI_MAX_SIZES + _DEFAULT_FI_MAX_SIZE = MiB // 2 + + def call_flashinfer_fused_allreduce_norm( + allreduce_in: torch.Tensor, + residual: torch.Tensor, + rms_gamma: torch.Tensor, + rms_eps: float, + world_rank: int, + world_size: int, + launch_with_pdl: bool, + trigger_completion_at_end: bool, + fp32_acc: bool, + max_token_num: int, + pattern_code: int, + fuse_rms_quant: bool, + norm_out: torch.Tensor | None = None, + quant_out: torch.Tensor | None = None, + scale_out: torch.Tensor | None = None, + scale_factor: torch.Tensor | None = None, + ) -> None: + num_tokens, hidden_size = allreduce_in.shape + element_size = allreduce_in.element_size() + current_tensor_size = num_tokens * hidden_size * element_size + max_fusion_size = max_token_num * hidden_size * element_size + use_flashinfer = current_tensor_size <= min( + _FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE), + max_fusion_size, + ) + if 
use_flashinfer: + assert ( + _FI_WORKSPACE_TENSOR is not None + ), "Flashinfer must be enabled when using flashinfer" + if norm_out is None: + norm_out = allreduce_in + residual_out = residual + else: + # return residual_out as allreduce_out with zeroed residual_in + # as flashinfer does not support rms_norm + # and allreduce_out together + residual_out = allreduce_in + # For the sizes that are smaller than the max size, + # we only use flashinfer one shot allreduce + flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=allreduce_in, + token_num=allreduce_in.shape[0], + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + world_rank=world_rank, + world_size=world_size, + hidden_dim=allreduce_in.shape[-1], + workspace_ptrs=_FI_WORKSPACE_TENSOR, + launch_with_pdl=launch_with_pdl, + use_oneshot=True, + trigger_completion_at_end=trigger_completion_at_end, + fp32_acc=fp32_acc, + pattern_code=pattern_code, + allreduce_out=None, + quant_out=quant_out, + scale_out=scale_out, + # in sglang we only support swizzled layout + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, + scale_factor=scale_factor, + ) + else: + allreduce_out = tensor_model_parallel_all_reduce(allreduce_in) + if scale_factor is not None and scale_out is None and fuse_rms_quant: + # Do fused rms norm static fp8 quant fused op + if norm_out is None: + torch.ops._C.fused_add_rms_norm_static_fp8_quant( + quant_out, + allreduce_out, + residual, + rms_gamma, + scale_factor, + rms_eps, + ) + else: + torch.ops._C.rms_norm_static_fp8_quant( + quant_out, allreduce_out, rms_gamma, scale_factor, rms_eps + ) + else: + if norm_out is None: + torch.ops._C.fused_add_rms_norm( + allreduce_out, residual, rms_gamma, rms_eps + ) + norm_out = allreduce_out + else: + torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma, rms_eps) + if scale_factor is not None: + if scale_out is not None: + torch.ops._C.scaled_fp4_quant( + quant_out, norm_out, 
scale_out, scale_factor + ) + else: + torch.ops._C.static_scaled_fp8_quant( + quant_out, norm_out, scale_factor + ) + if scale_factor is None or norm_out is not None: + # we need to return allreduce output + # in cases of non quant fused AR + RMS norm + # and fused AR + RMS norm + quant without fused add + allreduce_in.copy_(allreduce_out) + + def call_flashinfer_fused_allreduce_norm_fake( + allreduce_in: torch.Tensor, + residual: torch.Tensor, + rms_gamma: torch.Tensor, + rms_eps: float, + world_rank: int, + world_size: int, + launch_with_pdl: bool, + trigger_completion_at_end: bool, + fp32_acc: bool, + max_token_num: int, + pattern_code: int, + fuse_rms_quant: bool, + norm_out: torch.Tensor | None = None, + quant_out: torch.Tensor | None = None, + scale_out: torch.Tensor | None = None, + scale_factor: torch.Tensor | None = None, + ) -> None: + pass + + direct_register_custom_op( + op_name="flashinfer_trtllm_fused_allreduce_norm", + op_func=call_flashinfer_fused_allreduce_norm, + mutates_args=[ + "allreduce_in", + "residual", + "norm_out", + "quant_out", + "scale_out", + ], + fake_impl=call_flashinfer_fused_allreduce_norm_fake, + ) + flashinfer_trtllm_fused_allreduce_norm = ( + # TODO(yuan-luo): SGLang kernel + torch.ops.sglang.flashinfer_trtllm_fused_allreduce_norm.default + ) + + +class FlashInferFusedAllReduceParams: + """Parameters for FlashInfer fused allreduce operations.""" + + def __init__( + self, + rank: int, + world_size: int, + use_fp32_lamport: bool = False, + max_token_num: int = 1024, + fuse_rms_quant: bool = False, + ): + self.rank = rank + self.world_size = world_size + self.use_fp32_lamport = use_fp32_lamport + self.trigger_completion_at_end = True + self.launch_with_pdl = True + self.fp32_acc = True + self.use_oneshot = False + self.max_token_num = max_token_num + self.fuse_rms_quant = fuse_rms_quant + + def get_trtllm_fused_allreduce_kwargs(self): + return { + "world_rank": self.rank, + "world_size": self.world_size, + "launch_with_pdl": 
self.launch_with_pdl, + "trigger_completion_at_end": self.trigger_completion_at_end, + "fp32_acc": self.fp32_acc, + "max_token_num": self.max_token_num, + "fuse_rms_quant": self.fuse_rms_quant, + } + + +class AllReduceRMSNormPattern(BasePattern): + """ + This pattern replaces the allreduce + rms norm (without residual) + with fused flashinfer implementation. + Applies to allreduce + rmsnorm before attn in the first Transformer block. + """ + + def __init__( + self, + epsilon: float, + dtype: torch.dtype, + device: str, + allreduce_params: FlashInferFusedAllReduceParams, + ): + super().__init__(dtype, device) + self.epsilon = epsilon + self.allreduce_params = allreduce_params + self.rmsnorm_matcher = MatcherRMSNorm(epsilon) + + def get_inputs(self): + input, weight = self.rmsnorm_matcher.inputs() + + # input goes through allreduce first, always 16-bit + return [input.to(self.dtype), weight] + + def register(self, pm_pass: PatternMatcherPass): + def pattern(input: torch.Tensor, weight: torch.Tensor): + allreduce_output = tensor_model_parallel_all_reduce(input) + rms = self.rmsnorm_matcher(allreduce_output, weight) + + return rms, allreduce_output + + def replacement(input: torch.Tensor, weight: torch.Tensor): + residual = torch.zeros_like(input) + rms_result = torch.empty_like(input) + allreduce = auto_functionalized( + flashinfer_trtllm_fused_allreduce_norm, + allreduce_in=input, + residual=residual, + norm_out=rms_result, + quant_out=None, + scale_out=None, + rms_gamma=weight, + rms_eps=self.epsilon, + pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm, + **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + # rms_result, allreduce_in + return allreduce[3], allreduce[1] + + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) + + +class AllReduceFusedAddRMSNormPattern(BasePattern): + """ + This pattern replaces the allreduce + rms norm (with residual) + with fused flashinfer implementation. 
+ Applies to o_proj + rmsnorm after attn and mlp + rmsnorm before attn. + """ + + def __init__( + self, + epsilon: float, + dtype: torch.dtype, + device: str, + allreduce_params: FlashInferFusedAllReduceParams, + ): + super().__init__(dtype, device) + self.epsilon = epsilon + self.allreduce_params = allreduce_params + self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) + + def get_inputs(self): + input, residual, weight = self.rmsnorm_matcher.inputs() + + # input goes through allreduce first, always 16-bit + return [residual, input.to(self.dtype), weight] + + def register(self, pm_pass: PatternMatcherPass): + def pattern(residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor): + allreduce_output = tensor_model_parallel_all_reduce(input) + rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual) + return rms, residual + + def replacement( + residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor + ): + allreduce = auto_functionalized( + flashinfer_trtllm_fused_allreduce_norm, + allreduce_in=input, + residual=residual, + norm_out=None, + quant_out=None, + scale_out=None, + rms_gamma=weight, + rms_eps=self.epsilon, + pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm, + **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + # allreduce_in, residual + return allreduce[1], allreduce[2] + + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) + + # Same pattern, but only return the output and not residual + # (helpful for end of graph where residual is not used again) + first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0] + + pm.register_replacement( + first_return_only(pattern), + first_return_only(replacement), + self.get_inputs(), + pm.fwd_only, + pm_pass, + ) + + +class AllReduceFusionPass(SGLangPatternMatcherPass): + def __init__(self, config: SGLangConfig): + super().__init__(config) + self.disabled = True + self.tp_size = 
get_tensor_model_parallel_world_size() + if self.tp_size <= 1: + return + self.patterns: PatternMatcherPass = PatternMatcherPass( + pass_name="all_reduce_fusion_pass" + ) + if config.model_config is None: + return + self.hidden_dim = config.model_config.get_hidden_size() + self.group = get_tp_group().device_group + rank = get_tensor_model_parallel_rank() + use_fp32_lamport = self.model_dtype == torch.float32 + if flashinfer_comm is None: + logger.warning( + "Flashinfer is not installed or comm module not found, " + "skipping allreduce fusion pass" + ) + return + # Check if the world size is supported + if self.tp_size not in _FI_MAX_SIZES: + logger.warning( + "Flashinfer allreduce fusion is not supported for world size %s", + self.tp_size, + ) + return + max_num_token = min( + _FI_MAX_SIZES.get(self.tp_size, _DEFAULT_FI_MAX_SIZE) + // (self.hidden_dim * self.tp_size * (4 if use_fp32_lamport else 2)), + config.compilation_config.pass_config.fi_allreduce_fusion_max_token_num, + ) + self.ipc_handles, workspace_tensor = ( + flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( + tp_rank=rank, + tp_size=self.tp_size, + max_token_num=max_num_token, + hidden_dim=self.hidden_dim, + group=self.group, + use_fp32_lamport=use_fp32_lamport, + ) + ) + + global _FI_WORKSPACE_TENSOR + _FI_WORKSPACE_TENSOR = workspace_tensor + self.allreduce_params = FlashInferFusedAllReduceParams( + rank=rank, + world_size=self.tp_size, + use_fp32_lamport=use_fp32_lamport, + max_token_num=max_num_token, + # fuse rms norm static fp8 quant fused op + # in fallback path, when we don't use flashinfer + fuse_rms_quant=config.compilation_config.pass_config.enable_fusion, + ) + + self.register_patterns() + + @enable_fake_mode + def register_patterns(self): + for epsilon in [1e-5, 1e-6]: + # TODO(yuan-luo): + # register AllReduceFusedRMSNormStaticQuantFP8Pattern + # register AllReduceFusedAddRMSNormStaticQuantFP8Pattern + # register AllReduceFusedRMSNormStaticQuantNVFP4Pattern + # register 
AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern + AllReduceRMSNormPattern( + epsilon, + self.model_dtype, + self.device, + self.allreduce_params, + ).register(self.patterns) + # AllReduceFusedAddRMSNormPattern( + # epsilon, + # self.model_dtype, + # self.device, + # self.allreduce_params, + # ).register(self.patterns) + + # WARNING: This is a hack to clear the pattern matcher cache + # and allow multiple values of epsilon. + torch._inductor.pattern_matcher._seen_patterns.clear() + + self.disabled = False + + def __call__(self, graph: fx.Graph): + if self.disabled: + logger.debug("AllReduceFusionPass disabled") + return + + self.matched_count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", self.matched_count) + + def __del__(self): + if getattr(self, "disabled", True): + return + if flashinfer_comm is not None: + flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce( + self.ipc_handles, self.group + ) diff --git a/python/sglang/srt/compilation/compilation_config.py b/python/sglang/srt/compilation/compilation_config.py index d7687aaf1c93..6d66a4589d19 100644 --- a/python/sglang/srt/compilation/compilation_config.py +++ b/python/sglang/srt/compilation/compilation_config.py @@ -1,20 +1,97 @@ -# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compilation_config.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/config/compilation.py -from typing import List +import logging +from dataclasses import asdict, dataclass, field +from typing import List, Optional +from sglang.srt.compilation.inductor_pass import InductorPass -# TODO(Yuwei): support better compile config support +logger = logging.getLogger(__name__) + + +class CompilationMode: + """The compilation approach used for torch.compile-based compilation of the + model.""" + + NONE = 0 + """No torch.compile compilation is applied, model runs in fully eager pytorch mode. 
+ The model runs as-is.""" + STOCK_TORCH_COMPILE = 1 + """The standard `torch.compile` compilation pipeline.""" + DYNAMO_TRACE_ONCE = 2 + """Single Dynamo trace through the model, avoiding recompilation.""" + SGLANG_COMPILE = 3 + """Custom SGLang Inductor-based backend with caching, piecewise compilation, + shape specialization, and custom passes.""" + + +@dataclass +class PassConfig: + """Configuration for custom Inductor passes. + This is separate from general `CompilationConfig` so that inductor passes + don't all have access to full configuration - that would create a cycle as + the `PassManager` is set as a property of config.""" + + """Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass.""" + enable_fusion: bool = True + + """Whether to enable flashinfer allreduce fusion.""" + enable_fi_allreduce_fusion: bool = True + + """Max number of tokens to use in flashinfer allreduce fusion.""" + fi_allreduce_fusion_max_token_num: int = 16384 + + def uuid(self): + """ + Produces a hash unique to the pass configuration. + Any new fields that affect compilation should be added to the hash. + Any future fields that don't affect compilation should be excluded. + """ + return InductorPass.hash_dict(asdict(self)) + + +@dataclass class CompilationConfig: - def __init__( - self, - capture_sizes: List[int], - compiler: str = "eager", - enable_debug_mode: bool = False, - ): - self.traced_files = set() - self.capture_sizes = capture_sizes - self.compiler = compiler - self.enable_debug_mode = enable_debug_mode + + capture_sizes: Optional[List[int]] = None + + compiler: str = "eager" + + enable_debug_mode: bool = False + + traced_files: set[str] = field(default_factory=set, init=False) + + split_ops: list[str] = field(default_factory=list, init=False) + + # Top-level Compilation control + level: Optional[int] = None + + mode: CompilationMode | None = None + + # The backend for compilation.
It needs to be a string: + # (empty string): use the default backend ("inductor" on CUDA-alike + # platforms). + backend: str = "" + + custom_ops: list[str] = field(default_factory=list) + + # Inductor capture + use_inductor: bool = True + + splitting_ops: list[str] | None = None + + use_inductor_graph_partition: bool = False + + inductor_compile_config: dict = field(default_factory=dict) + + inductor_passes: dict[str, str] = field(default_factory=dict) + + pass_config: PassConfig = field(default_factory=PassConfig) + + # time taken for compilation + compilation_time: float = field(default=0.0, init=False) + + def __post_init__(self): self.split_ops = [ "sglang.unified_attention_with_output", "sglang.gdn_with_output", diff --git a/python/sglang/srt/compilation/compile.py b/python/sglang/srt/compilation/compile.py index b9ff7f6bdb93..c1a9534feeef 100644 --- a/python/sglang/srt/compilation/compile.py +++ b/python/sglang/srt/compilation/compile.py @@ -10,7 +10,7 @@ import torch -from sglang.srt.compilation.compilation_config import CompilationConfig +from sglang.srt.compilation.sglang_config import SGLangConfig from sglang.srt.utils.common import rank0_log logger = logging.getLogger(__name__) @@ -126,7 +126,7 @@ def install_torch_compiled( *, dynamic_arg_dims: dict[str, Union[int, list[int]]] | None = None, backend_factory: Optional[Callable[[torch.fx.GraphModule, list], Callable]] = None, - compile_config: CompilationConfig = None, + sglang_config: SGLangConfig = None, fullgraph: bool = True, graph_pool: Any = None, ): @@ -141,7 +141,7 @@ def install_torch_compiled( if backend_factory is None: from sglang.srt.compilation.backend import SGLangBackend - backend_factory = lambda gm, ex: SGLangBackend(compile_config, graph_pool)( + backend_factory = lambda gm, ex: SGLangBackend(sglang_config, graph_pool)( gm, ex ) diff --git a/python/sglang/srt/compilation/compiler_interface.py b/python/sglang/srt/compilation/compiler_interface.py index 8310f75c936c..cfb2fdec677d 100644 
--- a/python/sglang/srt/compilation/compiler_interface.py +++ b/python/sglang/srt/compilation/compiler_interface.py @@ -213,6 +213,7 @@ def compile( current_config["fx_graph_remote_cache"] = False set_inductor_config(current_config, runtime_shape) + set_functorch_config() # inductor can inplace modify the graph, so we need to copy it # see https://github.com/pytorch/pytorch/issues/138980 @@ -477,6 +478,10 @@ def set_inductor_config(config, runtime_shape): config["coordinate_descent_tuning"] = True +def set_functorch_config(): + torch._functorch.config.bundled_autograd_cache = False + + class EagerAdapter(CompilerInterface): name = "eager" diff --git a/python/sglang/srt/compilation/fix_functionalization.py b/python/sglang/srt/compilation/fix_functionalization.py index 8673e3576b00..8cff9fc4702a 100644 --- a/python/sglang/srt/compilation/fix_functionalization.py +++ b/python/sglang/srt/compilation/fix_functionalization.py @@ -9,7 +9,7 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized from sglang.srt.compilation.fx_utils import is_func -from sglang.srt.compilation.inductor_pass import SGLangInductorPass +from sglang.srt.compilation.sglang_inductor_pass import SGLangInductorPass logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/compilation/inductor_pass.py b/python/sglang/srt/compilation/inductor_pass.py index acbde65bf8ab..aa0021acf5e5 100644 --- a/python/sglang/srt/compilation/inductor_pass.py +++ b/python/sglang/srt/compilation/inductor_pass.py @@ -1,18 +1,18 @@ # Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/inductor_pass.py +import functools import hashlib import inspect import json import logging -import time import types from contextlib import contextmanager from typing import Any, Callable, Optional, Union import torch from torch import fx -from torch._dynamo.utils import lazy_format_graph_code from torch._inductor.custom_graph_pass import CustomGraphPass +from 
torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily logger = logging.getLogger(__name__) @@ -111,30 +111,17 @@ def uuid(self) -> Any: return self._uuid -class SGLangInductorPass(InductorPass): - - def __init__( - self, - ): - self.pass_name = self.__class__.__name__ - - def dump_graph(self, graph: torch.fx.Graph, stage: str): - lazy_format_graph_code(stage, graph.owning_module) - - def begin(self): - self._start_time = time.perf_counter_ns() - - def end_and_log(self): - self._end_time = time.perf_counter_ns() - duration_ms = float(self._end_time - self._start_time) / 1.0e6 - logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms) - +def enable_fake_mode(fn: Callable[..., Any]) -> Callable[..., Any]: + """ + Applies a FakeTensorMode context. This is useful when you don't want to + create or run things with real tensors. + """ -class PrinterInductorPass(SGLangInductorPass): + @functools.wraps(fn) + def fn_new(*args, **kwargs) -> Any: + with torch._guards.tracing(None), unset_fake_temporarily(), FakeTensorMode(): + result = fn(*args, **kwargs) - def __init__(self, name: str): - super().__init__() - self.name = name + return result - def __call__(self, graph: torch.fx.Graph): - self.dump_graph(graph, self.name) + return fn_new diff --git a/python/sglang/srt/compilation/matcher_utils.py b/python/sglang/srt/compilation/matcher_utils.py new file mode 100644 index 000000000000..b8161a3d6f58 --- /dev/null +++ b/python/sglang/srt/compilation/matcher_utils.py @@ -0,0 +1,160 @@ +from abc import ABC, abstractmethod + +import torch +from torch._higher_order_ops import auto_functionalized + +from sglang.srt.compilation.sglang_config import get_current_sglang_config +from sglang.srt.layers.layernorm import RMSNorm + + +def _get_default_op(ns: str, name: str): + if not hasattr(torch.ops, ns): + return None + mod = getattr(torch.ops, ns) + if not hasattr(mod, name): + return None + pkt = getattr(mod, name) # OpOverloadPacket + if hasattr(pkt, 
"default"):
+        return pkt.default  # OpOverload (auto_functionalized need)
+    if hasattr(pkt, "overloads") and pkt.overloads():
+        return getattr(pkt, pkt.overloads()[0])
+    return None
+
+
+RMS_OP = _get_default_op("sgl_kernel", "rmsnorm") or _get_default_op("_C", "rmsnorm")
+RMS_ADD_OP = _get_default_op("sgl_kernel", "fused_add_rmsnorm") or _get_default_op(
+    "_C", "fused_add_rmsnorm"
+)
+ENABLE_PDL = False
+
+# RMS_OP schema:
+# sgl_kernel::rmsnorm(
+#   Tensor($0! -> ) output,
+#   Tensor input,
+#   Tensor weight,
+#   float eps,
+#   bool enable_pdl
+# ) -> ()
+#
+# RMS_ADD_OP schema:
+# sgl_kernel::fused_add_rmsnorm(
+#   Tensor($0! -> ) input,
+#   Tensor($1! -> ) residual,
+#   Tensor weight, float eps,
+#   bool enable_pdl
+# ) -> ()
+
+
+class MatcherCustomOp(ABC):
+    def __init__(self, enabled: bool):
+        config = get_current_sglang_config()
+        self.model_dtype = config.model_config.dtype if config.model_config else None
+        self.device = config.device_config.device if config.device_config else None
+
+        self.enabled = enabled
+        self.forward = self.forward_custom if enabled else self.forward_native
+
+    @abstractmethod
+    def forward_custom(self, *args, **kws):
+        pass
+
+    @abstractmethod
+    def forward_native(self, *args, **kws):
+        pass
+
+    def __call__(self, *args, **kws):
+        return self.forward(*args, **kws)
+
+    def empty(self, *args, **kws):
+        return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kws)
+
+    def empty_f32(self, *args, **kws):
+        return torch.empty(*args, dtype=torch.float32, device=self.device, **kws)
+
+    def inputs(self) -> list[torch.Tensor]:
+        """Utility for inputs to the pattern"""
+        raise NotImplementedError
+
+
+class MatcherRMSNorm(MatcherCustomOp):
+    def __init__(self, epsilon: float, enabled: bool | None = None):
+        if enabled is None:
+            enabled = RMSNorm.enabled()
+
+        super().__init__(enabled)
+        self.epsilon = epsilon
+
+    def inputs(self):
+        input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
+        weight = self.empty(16)
+        return [input, weight]
+
+    def forward_custom(
+        self,
+        input: torch.Tensor,
+        weight: torch.Tensor,
+    ) -> torch.Tensor:
+        out = torch.empty_like(input)
+        ret = auto_functionalized(
+            RMS_OP,
+            output=out,
+            input=input,
+            weight=weight,
+            eps=self.epsilon,
+            enable_pdl=ENABLE_PDL,
+        )
+
+        _, out = ret
+        return out
+
+    def forward_native(
+        self,
+        input: torch.Tensor,
+        weight: torch.Tensor,
+    ) -> torch.Tensor:
+        return RMSNorm.forward_static(
+            input, self.epsilon, input.size(-1), self.model_dtype, weight
+        )
+
+
+class MatcherFusedAddRMSNorm(MatcherCustomOp):
+    def __init__(self, epsilon: float, enabled: bool | None = None):
+        if enabled is None:
+            enabled = RMSNorm.enabled()
+
+        super().__init__(enabled)
+        self.epsilon = epsilon
+
+    def inputs(self):
+        input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
+        weight = self.empty(16)
+        residual = self.empty(5, 16)
+        return [input, weight, residual]
+
+    def forward_custom(
+        self,
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        residual: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        ret = auto_functionalized(
+            RMS_ADD_OP,
+            input=input,
+            residual=residual,
+            weight=weight,
+            eps=self.epsilon,
+            enable_pdl=ENABLE_PDL,
+        )
+
+        _, new_input, new_residual = ret
+        return new_input, new_residual
+
+    def forward_native(
+        self,
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        residual: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return RMSNorm.forward_static(
+            input, self.epsilon, input.size(-1), self.model_dtype, weight, residual
+        )
diff --git a/python/sglang/srt/compilation/pass_manager.py b/python/sglang/srt/compilation/pass_manager.py
index 9173976f1878..c936a63c4d23 100644
--- a/python/sglang/srt/compilation/pass_manager.py
+++ b/python/sglang/srt/compilation/pass_manager.py
@@ -4,13 +4,15 @@
 from torch import fx as fx
 
+from sglang.srt.compilation.collective_fusion import AllReduceFusionPass
 from sglang.srt.compilation.fix_functionalization import FixFunctionalizationPass
 from sglang.srt.compilation.inductor_pass import (
     CustomGraphPass,
     InductorPass,
-    SGLangInductorPass,
     get_pass_context,
 )
+from sglang.srt.compilation.sglang_config import SGLangConfig, set_current_sglang_config
+from sglang.srt.compilation.sglang_inductor_pass import SGLangInductorPass
 
 logger = logging.getLogger(__name__)
 
@@ -42,11 +44,13 @@ def __call__(self, graph: fx.Graph):
         # always run fix_functionalization last
         self.fix_functionalization(graph)
 
-    def configure(
-        self,
-    ):
-        self.pass_config = dict()
-        self.fix_functionalization = FixFunctionalizationPass()
+    def configure(self, config: SGLangConfig):
+        self.pass_config = config.compilation_config.pass_config
+
+        with set_current_sglang_config(config, check_compile=False):
+            if self.pass_config.enable_fi_allreduce_fusion:
+                self.passes += [AllReduceFusionPass(config)]
+        self.fix_functionalization = FixFunctionalizationPass(config)
 
     def add(self, pass_: InductorPass):
         assert isinstance(pass_, InductorPass)
diff --git a/python/sglang/srt/compilation/sglang_config.py b/python/sglang/srt/compilation/sglang_config.py
new file mode 100644
index 000000000000..dd5103d888f1
--- /dev/null
+++ b/python/sglang/srt/compilation/sglang_config.py
@@ -0,0 +1,119 @@
+import copy
+import logging
+from contextlib import contextmanager
+from dataclasses import replace
+from functools import lru_cache
+from typing import TYPE_CHECKING, Any
+
+from transformers import PretrainedConfig
+
+from sglang.srt.compilation.compilation_config import CompilationConfig, CompilationMode
+from sglang.srt.configs.device_config import DeviceConfig
+
+if TYPE_CHECKING:
+    from sglang.srt.configs.model_config import ModelConfig
+else:
+    ModelConfig = Any
+
+logger = logging.getLogger(__name__)
+
+
+class SGLangConfig:
+    """Dataclass which contains all sglang-related configuration. This
+    simplifies passing around the distinct configurations in the codebase.
+    """
+
+    def __init__(
+        self,
+        model_config: ModelConfig = None,
+        device_config: DeviceConfig = DeviceConfig,
+        compilation_config: CompilationConfig = CompilationConfig,
+    ):
+        self.model_config = model_config
+        self.device_config = device_config
+        self.compilation_config = compilation_config
+        self.self_id = ""
+
+    def with_hf_config(
+        self,
+        hf_config: PretrainedConfig,
+        architectures: list[str] | None = None,
+    ) -> "SGLangConfig":
+        if architectures is not None:
+            hf_config = copy.deepcopy(hf_config)
+            hf_config.architectures = architectures
+
+        model_config = copy.deepcopy(self.model_config)
+        model_config.hf_config = hf_config
+
+        return replace(self, model_config=model_config)
+
+
+_current_sglang_config: SGLangConfig | None = None
+_current_prefix: str | None = None
+
+
+@contextmanager
+def set_current_sglang_config(
+    sglang_config: SGLangConfig, check_compile=False, prefix: str | None = None
+):
+    """
+    Temporarily set the current SGLang config.
+    Used during model initialization.
+    We save the current SGLang config in a global variable,
+    so that all modules can access it, e.g. custom ops
+    can access the SGLang config to determine how to dispatch.
+    """
+    global _current_sglang_config, _current_prefix
+    old_sglang_config = _current_sglang_config
+    old_prefix = _current_prefix
+
+    try:
+        _current_sglang_config = sglang_config
+        _current_prefix = prefix
+        yield
+    except Exception:
+        raise
+    else:
+        # TODO(): custom op check
+        if check_compile:
+            pass
+
+        if (
+            check_compile
+            # TODO(): compilation mode check
+            and sglang_config.compilation_config.mode == CompilationMode.SGLANG_COMPILE
+        ):
+            # If the model supports compilation,
+            # compilation_counter.num_models_seen should be increased
+            # by at least 1.
+            # If it is not increased, it means the model does not support
+            # compilation (does not have @support_torch_compile decorator).
+            logger.warning(
+                "`torch.compile` is turned on, but the model %s"
+                " does not support it. Please open an issue on GitHub"
+                " https://github.com/sgl-project/sglang/issues/new/choose"
+                " if you want it to be supported.",
+                sglang_config.model_config.model,
+            )
+    finally:
+        _current_sglang_config = old_sglang_config
+        _current_prefix = old_prefix
+        # Clear the compilation config cache when context changes
+        get_cached_compilation_config.cache_clear()
+
+
+@lru_cache(maxsize=1)
+def get_cached_compilation_config():
+    """Cache config to avoid repeated calls to get_current_sglang_config()"""
+    return get_current_sglang_config().compilation_config
+
+
+def get_current_sglang_config() -> SGLangConfig:
+    if _current_sglang_config is None:
+        # in ci, usually when we test custom ops/modules directly,
+        # we don't set the sglang config. In that case, we set a default
+        # config.
+        logger.warning("Current SGLang config is not set.")
+        return SGLangConfig()
+    return _current_sglang_config
diff --git a/python/sglang/srt/compilation/sglang_inductor_pass.py b/python/sglang/srt/compilation/sglang_inductor_pass.py
new file mode 100644
index 000000000000..4dcc9030a56e
--- /dev/null
+++ b/python/sglang/srt/compilation/sglang_inductor_pass.py
@@ -0,0 +1,93 @@
+import functools
+import logging
+import time
+from dataclasses import dataclass
+from typing import ClassVar
+
+import regex as re
+import torch
+from torch._dynamo.utils import lazy_format_graph_code
+
+from sglang.srt.compilation.inductor_pass import InductorPass
+from sglang.srt.compilation.sglang_config import SGLangConfig
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class InductorCompilationConfig:
+    splitting_ops: list[str] | None = None
+    use_inductor_graph_partition: bool = False
+
+
+class SGLangInductorPass(InductorPass):
+    """
+    An inductor pass with access to SGLang PassConfig.
+    It provides timing, logging, and dumping utilities.
+    """
+
+    dump_prefix: ClassVar[int | None] = None
+    """Keep track of pass index for debug dump ordering."""
+
+    def __init__(self, config: SGLangConfig):
+        # Get only the necessary CompilationConfig for the inductor pass, since
+        # full `CompilationConfig` contains pointer to model which is unsafe.
+        self.compilation_config = InductorCompilationConfig(
+            splitting_ops=config.compilation_config.splitting_ops,
+            use_inductor_graph_partition=config.compilation_config.use_inductor_graph_partition,
+        )
+        self.pass_config = config.compilation_config.pass_config
+        self.model_dtype = config.model_config.dtype if config.model_config else None
+        self.device = config.device_config.device if config.device_config else None
+        self.pass_name = self.__class__.__name__
+
+    @staticmethod
+    def time_and_log(call_fn):
+        @functools.wraps(call_fn)
+        def wrapped(self: SGLangInductorPass, graph: torch.fx.Graph):
+            self.begin()
+            self.dump_graph(graph, "before")
+            call_fn(self, graph)
+            self.dump_graph(graph, "after")
+            self.end_and_log()
+
+        return wrapped
+
+    def dump_graph(self, graph: torch.fx.Graph, stage: str):
+        i = SGLangInductorPass.dump_prefix
+        i_str = "" if i is None else f".{i}"
+        lazy_format_graph_code(
+            f"post_grad{i_str}.{self.pass_name}.{stage}", graph.owning_module
+        )
+
+    def begin(self):
+        self._start_time = time.perf_counter_ns()
+
+    def end_and_log(self):
+        self._end_time = time.perf_counter_ns()
+        duration_ms = float(self._end_time - self._start_time) / 1.0e6
+        logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)
+
+
+class SGLangPatternMatcherPass(SGLangInductorPass):
+    """
+    A SGLangInductorPass that uses the Inductor pattern matcher.
+    Its main use is providing the dump_patterns utility that dumps the
+    Inductor pattern matcher patterns into a file, which greatly aids debugging.
+
+    TODO(yuan-luo): move more utilities to this pass.
+    """
+
+    matched_count: int = 0
+    """The number of matched patterns in the pass."""
+
+    _OP_OVERLOAD_PATTERN: ClassVar[re.Pattern] = re.compile(
+        r"<OpOverload\(op='([^']*)', overload='([^']*)'\)>"
+    )
+
+    def _replace_op_overloads(self, string: str) -> str:
+        """Replace <OpOverload(...)> with nicer formulations"""
+        return self._OP_OVERLOAD_PATTERN.sub(
+            lambda m: f"torch.ops.{m.group(1)}.{m.group(2)}",
+            string,
+        )
diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index f807deedbc10..9d678281e248 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -562,6 +562,9 @@ def get_swa_num_kv_heads(self, tensor_parallel_size) -> int:
         total_num_kv_heads = self.hf_text_config.swa_num_key_value_heads
         return max(1, total_num_kv_heads // tensor_parallel_size)
 
+    def get_hidden_size(self) -> int:
+        return getattr(self.hf_text_config, "hidden_size", 0)
+
     # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
     def _parse_quant_hf_config(self):
         quant_cfg = getattr(self.hf_config, "quantization_config", None)
diff --git a/python/sglang/srt/custom_op.py b/python/sglang/srt/custom_op.py
index 90cfa8e12ced..c6f847f263e8 100644
--- a/python/sglang/srt/custom_op.py
+++ b/python/sglang/srt/custom_op.py
@@ -4,8 +4,11 @@
 TODO: Move this to python/sglang/srt/layers/custom_op.py
 """
 
+import logging
+
 from torch import nn
 
+from sglang.srt.compilation.sglang_config import get_cached_compilation_config
 from sglang.srt.utils import (
     cpu_has_amx_support,
     is_cpu,
@@ -22,6 +25,8 @@
 _is_npu = is_npu()
 _is_xpu = is_xpu()
 
+logger = logging.getLogger(__name__)
+
 
 class CustomOp(nn.Module):
     def __init__(self):
@@ -106,3 +111,38 @@ def dispatch_forward(self):
             return self.forward_xpu
         else:
             return self.forward_native
+
+    @classmethod
+    def enabled(cls) -> bool:
+        # if no name, then it was not registered
+        compilation_config = get_cached_compilation_config()
+        custom_ops = compilation_config.custom_ops
+        if not hasattr(cls, "name"):
+            logger.warning_once(
+                "Custom op %s was not registered, which means it won't appear "
+                "in the op registry. It will be enabled/disabled based on the "
+                "global settings.",
+                cls.__name__,
+            )
+            return CustomOp.default_on()
+
+        enabled = f"+{cls.name}" in custom_ops
+        disabled = f"-{cls.name}" in custom_ops
+        assert not (enabled and disabled), f"Cannot enable and disable {cls.name}"
+
+        return (CustomOp.default_on() or enabled) and not disabled
+
+    @staticmethod
+    def default_on() -> bool:
+        """
+        Behavior controlled by `CompilationConfig.custom_ops`: On by default if
+        'all', off by default if 'none'.
+        When PyTorch Inductor is used, 'none' is the default value,
+        otherwise 'all'.
+        """
+        compilation_config = get_cached_compilation_config()
+        count_none = compilation_config.custom_ops.count("none")
+        count_all = compilation_config.custom_ops.count("all")
+        # assert count_none + count_all == 1
+
+        return not count_none > 0 or count_all > 0
diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py
index 68f94a0908a7..23e839d16cbb 100644
--- a/python/sglang/srt/layers/layernorm.py
+++ b/python/sglang/srt/layers/layernorm.py
@@ -222,6 +222,52 @@ def forward_native(
         else:
             return x, residual
 
+    @staticmethod
+    def forward_static(
+        x: torch.Tensor,
+        variance_epsilon: float,
+        hidden_size: int,
+        orig_dtype: torch.dtype,
+        weight: torch.Tensor | None = None,
+        residual: torch.Tensor | None = None,
+        variance_size_override: int | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        """PyTorch-native implementation equivalent to forward()."""
+        x = x.to(torch.float32)
+        if residual is not None:
+            # residual promoted f16->f32 automatically,
+            # otherwise Inductor eliminates the casts to and from f16,
+            # increasing memory usage (and complicating pattern matching)
+            x = x + residual
+            residual = x.to(orig_dtype)
+
+        if x.shape[-1] != hidden_size:
+            raise ValueError(
+                f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}"
+            )
+
+        if variance_size_override is None:
+            x_var = x
+        else:
+            if hidden_size < variance_size_override:
+                raise ValueError(
+                    "Expected hidden_size to be at least "
+                    f"{variance_size_override}, but found: {hidden_size}"
+                )
+
+            x_var = x[:, :, :variance_size_override]
+
+        variance = x_var.pow(2).mean(dim=-1, keepdim=True)
+
+        x = x * torch.rsqrt(variance + variance_epsilon)
+        x = x.to(orig_dtype)
+        if weight is not None:
+            x = x * weight
+        if residual is None:
+            return x
+        else:
+            return x, residual
+
     def forward_cpu(
         self,
         x: torch.Tensor,
diff --git a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py
index 2dcd59ceb133..5da1588fa557 100644
--- a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py
@@ -33,6 +33,8 @@
     set_forward_context,
     set_pcg_capture_stream,
 )
+from sglang.srt.compilation.sglang_config import SGLangConfig
+from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.device_communicators.pynccl_allocator import (
@@ -151,15 +153,23 @@ def __init__(self, model_runner: ModelRunner):
             "inductor",
         ], "By now, only eager and inductor are supported for piecewise cuda graph compiler."
 
         self.compile_config = CompilationConfig(
-            self.model_runner.server_args.piecewise_cuda_graph_tokens,
-            self.model_runner.server_args.piecewise_cuda_graph_compiler,
-            self.model_runner.server_args.enable_torch_compile_debug_mode,
+            capture_sizes=self.model_runner.server_args.piecewise_cuda_graph_tokens,
+            compiler=self.model_runner.server_args.piecewise_cuda_graph_compiler,
+            enable_debug_mode=self.model_runner.server_args.enable_torch_compile_debug_mode,
         )
         if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake():
             self.compile_config.add_split_op(
                 "sglang.moe_forward_piecewise_cuda_graph_impl"
             )
+        self.model_config = self.model_runner.model_config
+        self.sglang_config = SGLangConfig(
+            model_config=self.model_config,
+            device_config=DeviceConfig(
+                self.model_runner.device, self.model_runner.gpu_id
+            ),
+            compilation_config=self.compile_config,
+        )
         self.quant_config = getattr(self.model_runner.model, "quant_config", None)
 
         # Batch sizes to capture
@@ -223,7 +233,7 @@ def __init__(self, model_runner: ModelRunner):
             patched_model,
             fullgraph=True,
             dynamic_arg_dims=None,
-            compile_config=self.compile_config,
+            sglang_config=self.sglang_config,
             graph_pool=get_global_graph_memory_pool(),
         )