diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py
index fb0264a6ea9..279f9e27b64 100644
--- a/python/sglang/srt/managers/data_parallel_controller.py
+++ b/python/sglang/srt/managers/data_parallel_controller.py
@@ -30,6 +30,7 @@
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket
 from sglang.utils import get_exception_traceback
 
@@ -174,6 +175,10 @@ def launch_tensor_parallel_group(
         if not server_args.enable_dp_attention:
             logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")
 
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=server_args.enable_memory_saver
+        )
+
         # Launch tensor parallel scheduler processes
         scheduler_pipe_readers = []
         tp_size_per_node = server_args.tp_size // server_args.nnodes
@@ -208,7 +213,8 @@
                 target=run_scheduler_process,
                 args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
             )
-            proc.start()
+            with memory_saver_adapter.configure_subprocess():
+                proc.start()
             self.scheduler_procs.append(proc)
             scheduler_pipe_readers.append(reader)
 
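
Note on the change above: the diff only shows the call sites, so it implies that TorchMemorySaverAdapter.create() returns an object whose configure_subprocess() works as a context manager wrapping mp.Process.start(), and that it should behave like a no-op when --enable-memory-saver is off. The following is a minimal illustrative sketch of that adapter shape, not the actual implementation in sglang/srt/torch_memory_saver_adapter.py; all internals here are assumptions.

from contextlib import contextmanager


class MemorySaverAdapterSketch:
    """Hypothetical stand-in for TorchMemorySaverAdapter (illustration only)."""

    def __init__(self, enable: bool):
        self.enable = enable

    @classmethod
    def create(cls, enable: bool):
        # Assumed factory: the real adapter presumably selects a
        # torch_memory_saver-backed implementation when enable=True.
        return cls(enable=enable)

    @contextmanager
    def configure_subprocess(self):
        if not self.enable:
            # Disabled: wrapping proc.start() changes nothing.
            yield
            return
        # Enabled: a real adapter would prepare the child-process environment
        # for the memory-saver library before the subprocess starts
        # (details intentionally omitted in this sketch).
        yield

Used as in the diff, the context manager keeps the scheduler launch path identical whether or not the memory saver is enabled; only the environment seen by the child process changes when it is on.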