diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py
index 24731d2b878..aa26d56d05c 100644
--- a/python/sglang/srt/distributed/parallel_state.py
+++ b/python/sglang/srt/distributed/parallel_state.py
@@ -42,8 +42,10 @@ from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
+    get_int_env_var,
     is_cuda_alike,
     is_npu,
+    is_shm_available,
     supports_custom_op,
 )
@@ -222,6 +224,7 @@ def __init__(
         self.local_rank = local_rank
         self.device_group = None
         self.cpu_group = None
+        self.local_size = get_int_env_var("LOCAL_SIZE", 0)
 
         for ranks in group_ranks:
             device_group = torch.distributed.new_group(
@@ -440,9 +443,12 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
             return input_
 
         if input_.is_cpu:
-            import intel_extension_for_pytorch as ipex
-
-            ipex.distributed.all_reduce(input_, group=self.device_group)
+            if is_shm_available(input_.dtype, self.world_size, self.local_size):
+                torch.ops.sgl_kernel.shm_allreduce(
+                    input_, torch.distributed.ReduceOp.SUM
+                )
+            else:
+                torch.distributed.all_reduce(input_, group=self.device_group)
             return input_
 
         if not supports_custom_op():
@@ -570,6 +576,16 @@ def all_gather(
         output_tensor = torch.empty(
             output_size, dtype=input_.dtype, device=input_.device
         )
+
+        if input_.is_cpu:
+            if is_shm_available(input_.dtype, self.world_size, self.local_size):
+                return torch.ops.sgl_kernel.shm_allgather(input_, dim)
+            else:
+                torch.distributed.all_gather_into_tensor(
+                    output_tensor, input_, group=self.device_group
+                )
+                return output_tensor
+
         # All-gather.
         self.all_gather_into_tensor(output_tensor, input_)
         # Reshape
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index c8fe4639354..43502447537 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -506,9 +506,13 @@ def init_torch_distributed(self):
             if _is_cpu_amx_available:
                 # Bind OpenMP threads to CPU cores
                 torch.ops.sgl_kernel.init_cpu_threads_env(self.local_omp_cpuid)
+
+                # Set local size to hint SGLang to use shared memory based AllReduce
+                os.environ["LOCAL_SIZE"] = str(self.tp_size)
+                torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank)
             else:
                 logger.warning(
-                    "init_cpu_threads_env is skipped since intel amx backend is not available"
+                    "init_cpu_threads_env and shared memory based AllReduce are disabled since intel amx backend is not available"
                 )
 
         # Only initialize the distributed environment on the target model worker.
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 67851c8c3c8..72cba9aac88 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -2612,3 +2612,12 @@ def get_cpu_ids_by_node():
 
     # ['0,1,2,3', '4,5,6,7', '8,9,10,11', '12,13,14,15', '16,17,18,19', '20,21,22,23']
     return cpu_ids
+
+
+def is_shm_available(dtype, world_size, local_size):
+    return (
+        cpu_has_amx_support()
+        and dtype in [torch.bfloat16, torch.float]
+        and world_size >= 1
+        and world_size == local_size
+    )
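
Reviewer note: the sketch below is a minimal, standalone illustration of the gating logic this patch introduces, written so it runs without an sgl_kernel build. cpu_has_amx_support is stubbed out, cpu_all_reduce is a hypothetical wrapper (not a function in the patch), and the single-process gloo group exists only to make the fallback branch executable; on a real multi-rank CPU deployment the fast path would call torch.ops.sgl_kernel.shm_allreduce as in the parallel_state.py hunk above.

# Minimal sketch (assumptions noted above): mirrors is_shm_available and the
# CPU all_reduce branch from this patch, with a plain torch.distributed fallback.
import os
from typing import Optional

import torch
import torch.distributed as dist


def cpu_has_amx_support() -> bool:
    # Stub for the AMX check used by sglang.srt.utils; the real helper queries the CPU ISA.
    return False


def is_shm_available(dtype: torch.dtype, world_size: int, local_size: int) -> bool:
    # Same conditions as the helper added to python/sglang/srt/utils.py:
    # AMX support, a supported dtype, and all ranks living on one node.
    return (
        cpu_has_amx_support()
        and dtype in (torch.bfloat16, torch.float)
        and world_size >= 1
        and world_size == local_size
    )


def cpu_all_reduce(t: torch.Tensor, group: Optional[dist.ProcessGroup] = None) -> torch.Tensor:
    # LOCAL_SIZE is the hint written by model_runner.init_torch_distributed in this patch.
    world_size = dist.get_world_size(group)
    local_size = int(os.environ.get("LOCAL_SIZE", "0"))
    if is_shm_available(t.dtype, world_size, local_size):
        # Fast path in the patch: torch.ops.sgl_kernel.shm_allreduce(t, dist.ReduceOp.SUM)
        raise RuntimeError("sgl_kernel is not available in this sketch")
    # Fallback path: regular all-reduce over the CPU (gloo) process group.
    dist.all_reduce(t, group=group)
    return t


if __name__ == "__main__":
    # Single-process gloo group so the fallback branch can be exercised anywhere.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    x = torch.ones(4, dtype=torch.float)
    print(cpu_all_reduce(x))  # sums across ranks; identity with world_size == 1
    dist.destroy_process_group()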