python/sglang/srt/compilation/backend.py (112 changes: 88 additions & 24 deletions)
@@ -1,4 +1,4 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/backend.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/backends.py


import ast
@@ -19,8 +19,10 @@
from sglang.srt.compilation.compilation_counter import compilation_counter
from sglang.srt.compilation.compiler_interface import EagerAdapter, InductorAdaptor
from sglang.srt.compilation.cuda_piecewise_backend import CUDAPiecewiseBackend
+from sglang.srt.compilation.inductor_pass import InductorPass
from sglang.srt.compilation.npu_piecewise_backend import NPUPiecewiseBackend
from sglang.srt.compilation.pass_manager import PostGradPassManager
+from sglang.srt.compilation.sglang_config import SGLangConfig
from sglang.srt.utils.common import is_npu, rank0_log

logger = logging.getLogger(__name__)
@@ -103,21 +105,24 @@ def load(
graph_index: int,
runtime_shape: Optional[int] = None,
) -> Optional[Callable]:
-        handle = self.cache[(runtime_shape, graph_index, self.compiler.name)]
+        key = (runtime_shape, graph_index, self.compiler.name)
+        handle = self.cache.get(key, None)
if handle is None:
return None

compiled_graph = self.compiler.load(
handle, graph, example_inputs, graph_index, runtime_shape
)
if runtime_shape is None:
logger.debug(
"Directly load the %s-th graph for dynamic shape from %s via "
"handle %s",
"Directly load the %s-th graph for dynamic shape from %s via handle %s",
graph_index,
self.compiler.name,
handle,
)
else:
logger.debug(
"Directly load the %s-th graph for shape %s from %s via " "handle %s",
"Directly load the %s-th graph for shape %s from %s via handle %s",
graph_index,
str(runtime_shape),
self.compiler.name,
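A note on the lookup change in `load()` above: replacing direct indexing with `dict.get` turns a cache miss into a `None` return instead of an uncaught `KeyError`, so the caller can fall back to compiling. A minimal sketch of the difference, with hypothetical cache contents:

```python
# Minimal sketch of the CompilerManager.load lookup change; the cache
# maps (runtime_shape, graph_index, compiler_name) tuples to opaque
# handles. The contents below are hypothetical.
cache = {(None, 0, "inductor"): "handle-0"}
key = (16, 0, "inductor")  # a shape that has not been compiled yet

# Old behavior: direct indexing raises KeyError on a miss.
try:
    handle = cache[key]
except KeyError:
    handle = None

# New behavior: .get() returns None on a miss, letting load()
# return early and compile() fall through to a fresh build.
handle = cache.get(key, None)
assert handle is None
```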
@@ -130,6 +135,7 @@ def compile(
graph: fx.GraphModule,
example_inputs,
inductor_config: dict[str, Any],
+        compilation_config: CompilationConfig,
graph_index: int = 0,
num_graphs: int = 1,
runtime_shape: Optional[int] = None,
@@ -143,7 +149,29 @@

compiled_graph = None

-        # TODO(Yuwei): support cache loading
+        # try to load from the cache
+        compiled_graph = self.load(graph, example_inputs, graph_index, runtime_shape)
+        if compiled_graph is not None:
+            if graph_index == num_graphs - 1:
+                # after loading the last graph for this shape, record the time.
+                # there can be multiple graphs due to piecewise compilation.
+                now = time.time()
+                elapsed = now - compilation_start_time
+                compilation_config.compilation_time += elapsed
+                if runtime_shape is None:
+                    logger.info(
+                        "Directly load the compiled graph(s) for dynamic shape "
+                        "from the cache, took %.3f s",
+                        elapsed,
+                    )
+                else:
+                    logger.info(
+                        "Directly load the compiled graph(s) for shape %s "
+                        "from the cache, took %.3f s",
+                        str(runtime_shape),
+                        elapsed,
+                    )
+            return compiled_graph

# no compiler cached the graph, or the cache is disabled,
# we need to compile it
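The cache-hit block added above mirrors the timing bookkeeping of the compile path below it: elapsed time is accumulated into `compilation_config.compilation_time` only when the last piecewise graph of a shape is handled, so each shape is counted once rather than once per subgraph. A rough sketch of that accounting, with a stand-in config object:

```python
import time

# Stand-in for the CompilationConfig.compilation_time accumulator.
class FakeConfig:
    compilation_time: float = 0.0

config = FakeConfig()
compilation_start_time = time.time()  # set when the shape's compilation begins

num_graphs = 4  # piecewise compilation yields several subgraphs per shape
for graph_index in range(num_graphs):
    # ... load each piecewise graph from the cache (or compile it) ...
    if graph_index == num_graphs - 1:
        # Record once, after the last graph of this shape.
        config.compilation_time += time.time() - compilation_start_time
```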
@@ -190,6 +218,7 @@ def compile(
if graph_index == num_graphs - 1:
now = time.time()
elapsed = now - compilation_start_time
+            compilation_config.compilation_time += elapsed
if runtime_shape is None:
logger.info("Compiling a graph for dynamic shape takes %.2f s", elapsed)
else:
@@ -256,20 +285,27 @@ def split_graph(
return split_gm, outputs


# we share the global graph pool among all the backends
global_graph_pool = None

compilation_start_time = 0.0


class PiecewiseCompileInterpreter(torch.fx.Interpreter):
"""Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
It runs the given graph with fake inputs, and compile some
submodules specified by `compile_submod_names` with the given
compilation configs.

+    NOTE: the order in `compile_submod_names` matters, because
+    it will be used to determine the order of the compiled piecewise
+    graphs. The first graph will handle logging, and the last graph
+    has some special cudagraph output handling.
"""

def __init__(
self,
module: torch.fx.GraphModule,
compile_submod_names: list[str],
-        inductor_config: dict[str, Any],
graph_pool,
-        compile_config: CompilationConfig,
+        sglang_config: SGLangConfig,
sglang_backend: "SGLangBackend",
):
super().__init__(module)
@@ -281,8 +317,8 @@ def __init__(
self.sglang_backend = sglang_backend
# When True, it annoyingly dumps the torch.fx.Graph on errors.
self.extra_traceback = False
-        self.inductor_config = inductor_config
-        self.compile_config = compile_config
+        self.sglang_config = sglang_config
+        self.compilation_config = sglang_config.compilation_config

def run(self, *args):
fake_args = [
@@ -312,7 +348,8 @@ def call_module(
self.sglang_backend.compiler_manager.compile(
submod,
args,
-                self.inductor_config,
+                self.sglang_backend.inductor_config,
+                self.compilation_config,
graph_index=index,
num_graphs=len(self.compile_submod_names),
runtime_shape=None,
@@ -321,8 +358,8 @@

self.module.__dict__[target] = make_backend(
submod,
-                self.compile_config,
-                self.inductor_config,
+                self.compilation_config,
+                self.sglang_backend.inductor_config,
self.graph_pool,
index,
len(self.compile_submod_names),
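For context on the `call_module` changes here: `PiecewiseCompileInterpreter` uses the standard `torch.fx.Interpreter` technique of intercepting submodule calls and swapping a compiled replacement into the owning `GraphModule`, which is what the `self.module.__dict__[target] = make_backend(...)` line above does. A stripped-down sketch of the pattern (class and names are illustrative, not the SGLang API):

```python
import torch
import torch.fx as fx

class SwappingInterpreter(fx.Interpreter):
    """Run a GraphModule once, replacing selected submodules."""

    def __init__(self, module: fx.GraphModule, targets: list[str]):
        super().__init__(module)
        self.targets = targets

    def call_module(self, target, args, kwargs):
        output = super().call_module(target, args, kwargs)
        if target in self.targets:
            submod = self.fetch_attr(target)
            # Install a "compiled" replacement under the same name,
            # so later calls to the graph hit the compiled version.
            self.module.__dict__[target] = torch.compile(submod)
        return output
```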
@@ -355,7 +392,19 @@ def set_model_tag(tag: str):


class SGLangBackend:
"""The compilation backend for `torch.compile` with SGLang.
It is used for compilation mode of `CompilationMode.SGLANG_COMPILE`,
where we customize the compilation.

The major work of this backend is to split the graph into
piecewise graphs, and pass them to the piecewise backend.

This backend also adds the PostGradPassManager to Inductor config,
which handles the post-grad passes.
"""

+    sglang_config: SGLangConfig
+    compilation_config: CompilationConfig
graph_pool: Any
_called: bool = False
# the graph we compiled
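As the new docstring says, `SGLangBackend` is a custom backend for `torch.compile`. The backend contract is just a callable that takes an `fx.GraphModule` plus example inputs and returns a callable, which matches the `__call__` signature further down. A toy illustration of that contract (not SGLang code):

```python
import torch
import torch.fx as fx

def toy_backend(gm: fx.GraphModule, example_inputs):
    # A real backend would rewrite or split the graph here;
    # SGLangBackend splits it into piecewise subgraphs instead.
    print(f"captured a graph with {len(gm.graph.nodes)} nodes")
    return gm.forward  # run the captured graph unchanged

@torch.compile(backend=toy_backend)
def f(x):
    return torch.relu(x) + 1

f(torch.randn(4))  # first call triggers capture and invokes the backend
```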
@@ -372,7 +421,7 @@ class SGLangBackend:

def __init__(
self,
-        config: CompilationConfig,
+        sglang_config: SGLangConfig,
graph_pool: Any,
):
rank0_log(f"Initializing SGLangBackend")
@@ -383,15 +432,31 @@ def __init__(
self.sym_tensor_indices = []
self.input_buffers = []

-        self.compiler_manager = CompilerManager(config)
+        self.sglang_config = sglang_config
+        self.compilation_config = sglang_config.compilation_config
+
+        self.compiler_manager = CompilerManager(self.compilation_config)
self.inductor_config = {
"enable_auto_functionalized_v2": False,
}
-        self.compile_config = config

def configure_post_pass(self):
-        self.post_grad_pass_manager.configure()
-        self.inductor_config["post_grad_custom_post_pass"] = self.post_grad_pass_manager
+        config = self.compilation_config
+        self.post_grad_pass_manager.configure(self.sglang_config)
+
+        # Post-grad custom passes are run using the post_grad_custom_post_pass
+        # hook. If a pass for that hook exists, add it to the pass manager.
+        inductor_config = config.inductor_compile_config
+        PASS_KEY = "post_grad_custom_post_pass"
+        if PASS_KEY in inductor_config:
+            if isinstance(inductor_config[PASS_KEY], PostGradPassManager):
+                # PassManager already added to config
+                pass
+            else:
+                # Config should automatically wrap all inductor passes
+                assert isinstance(inductor_config[PASS_KEY], InductorPass)
+                self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
+        inductor_config[PASS_KEY] = self.post_grad_pass_manager

def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
rank0_log(f"SGLangBackend __call__")
@@ -423,7 +488,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:

self.split_gm, self.piecewise_graphs = split_graph(
graph,
-            self.compile_config.split_ops,
+            self.compilation_config.split_ops,
)
from torch._dynamo.utils import lazy_format_graph_code

@@ -443,9 +508,8 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
PiecewiseCompileInterpreter(
self.split_gm,
submod_names_to_compile,
-            self.inductor_config,
self.graph_pool,
-            self.compile_config,
+            self.sglang_config,
self,
).run(*example_inputs)
