diff --git a/python/sglang/srt/compilation/backend.py b/python/sglang/srt/compilation/backend.py
index 38f3479992c6..c2aa24b50f25 100644
--- a/python/sglang/srt/compilation/backend.py
+++ b/python/sglang/srt/compilation/backend.py
@@ -300,7 +300,6 @@ def call_module(
                 runtime_shape=None,
             )
         )
-
         self.module.__dict__[target] = CUDAPiecewiseBackend(
             submod,
             self.compile_config,
@@ -395,8 +394,11 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
             local_cache_dir, disable_cache=False, prefix=""
         )
         compilation_counter.num_graphs_seen += 1
-
-        assert not self._called, "SGLangBackend can only be called once"
+
+        # The backend may now be invoked more than once (e.g. on Dynamo
+        # recompilation); reuse the already-split graph instead of asserting.
+        if self._called:
+            return self.split_gm
 
         self.graph = graph
         self.configure_post_pass()
diff --git a/python/sglang/srt/compilation/compile.py b/python/sglang/srt/compilation/compile.py
index e0ec9a1f327e..53e139c1c083 100644
--- a/python/sglang/srt/compilation/compile.py
+++ b/python/sglang/srt/compilation/compile.py
@@ -136,12 +136,13 @@ def install_torch_compiled(
 
     dyn_map = dynamic_arg_dims or _infer_dynamic_arg_dims_from_annotations(unbound_fwd)
 
+    # Create the backend instance once and reuse it
+    backend_instance = None
     if backend_factory is None:
         from sglang.srt.compilation.backend import SGLangBackend
 
-        backend_factory = lambda gm, ex: SGLangBackend(compile_config, graph_pool)(
-            gm, ex
-        )
+        backend_instance = SGLangBackend(compile_config, graph_pool)
+        backend_factory = lambda gm, ex: backend_instance(gm, ex)
 
     compiled_codes: list[type(original_code)] = []
     state = {"compiled": False, "compiled_callable": None}
diff --git a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py
index 2322ebb22f32..d95b6b63d2ec 100644
--- a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py
@@ -197,9 +197,6 @@ def __init__(self, model_runner: ModelRunner):
             graph_pool=get_global_graph_memory_pool(),
         )
 
-        with set_compiled(True):
-            self.warmup_and_capture()
-
         # Capture
         try:
             self.capture()