diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py
index cc2ba95a614..54f45c02b34 100644
--- a/python/sglang/srt/distributed/parallel_state.py
+++ b/python/sglang/srt/distributed/parallel_state.py
@@ -44,6 +44,7 @@
     get_bool_env_var,
     is_cuda_alike,
     is_npu,
+    is_shm_available,
     supports_custom_op,
 )
@@ -222,7 +223,7 @@ def __init__(
         self.local_rank = local_rank
         self.device_group = None
         self.cpu_group = None
-
+        self.local_size = int(os.environ.get("LOCAL_SIZE", "0"))
         for ranks in group_ranks:
             device_group = torch.distributed.new_group(
                 ranks, backend=torch_distributed_backend
@@ -440,9 +441,12 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
             return input_
 
         if input_.is_cpu:
-            import intel_extension_for_pytorch as ipex
-
-            ipex.distributed.all_reduce(input_, group=self.device_group)
+            if is_shm_available(input_.dtype, self.world_size, self.local_size):
+                torch.ops.sgl_kernel.shm_allreduce(
+                    input_, torch.distributed.ReduceOp.SUM
+                )
+            else:
+                torch.distributed.all_reduce(input_, group=self.device_group)
             return input_
 
         if not supports_custom_op():
@@ -562,6 +566,16 @@ def all_gather(
         output_tensor = torch.empty(
             output_size, dtype=input_.dtype, device=input_.device
         )
+
+        if input_.is_cpu:
+            if is_shm_available(input_.dtype, self.world_size, self.local_size):
+                return torch.ops.sgl_kernel.shm_allgather(input_, dim)
+            else:
+                torch.distributed.all_gather_into_tensor(
+                    output_tensor, input_, group=self.device_group
+                )
+                return output_tensor
+
         # All-gather.
         self.all_gather_into_tensor(output_tensor, input_)
         # Reshape
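For reference, the CPU branch added to GroupCoordinator.all_reduce above boils down to the gating sketched below. This is a minimal standalone sketch, not the class method itself: `cpu_all_reduce`, `group`, and `world_size` are placeholders, it assumes the sgl_kernel CPU extension is built and a gloo process group exists, and it omits the `cpu_has_amx_support()` check that `is_shm_available` also performs.

```python
import os
import torch
import torch.distributed as dist


def cpu_all_reduce(t: torch.Tensor, group, world_size: int) -> torch.Tensor:
    """Sketch of the CPU all-reduce dispatch introduced in the hunks above."""
    local_size = int(os.environ.get("LOCAL_SIZE", "0"))
    shm_ok = (
        t.dtype in (torch.bfloat16, torch.float)  # dtypes the shm kernel accepts
        and world_size >= 1
        and world_size == local_size  # every rank must live on this node
    )
    if shm_ok:
        # In-place shared-memory all-reduce registered by the sgl_kernel CPU build.
        torch.ops.sgl_kernel.shm_allreduce(t, dist.ReduceOp.SUM)
    else:
        # Fall back to the regular torch.distributed (gloo) all-reduce.
        dist.all_reduce(t, group=group)
    return t
```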
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 999daa2abe5..3dde43c8c60 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -473,6 +473,10 @@ def init_torch_distributed(self):
             ), "init_cpu_threads_env failed since intel amx backend is not available"
             torch.ops.sgl_kernel.init_cpu_threads_env(self.local_omp_cpuid)
 
+            # Set local size to hint SGLang to use shared memory based AllReduce
+            os.environ["LOCAL_SIZE"] = str(self.tp_size)
+            torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank)
+
         # Only initialize the distributed environment on the target model worker.
         init_distributed_environment(
             backend=backend,
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index eadead9239b..c15b6d7d156 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -2477,3 +2477,12 @@ def get_cpu_ids_by_node():
     # ['0,1,2,3', '4,5,6,7', '8,9,10,11', '12,13,14,15', '16,17,18,19', '20,21,22,23']
     return cpu_ids
 
+
+def is_shm_available(dtype, world_size, local_size):
+    return (
+        cpu_has_amx_support()
+        and dtype in [torch.bfloat16, torch.float]
+        and world_size >= 1
+        and world_size == local_size
+    )
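The model_runner.py hunk above must run before the first shm collective is issued: it records LOCAL_SIZE for the GroupCoordinator and sets up the kernel-side workspace. A condensed sketch of that ordering, with `init_cpu_shm`, `tp_size`, and `tp_rank` as placeholders for the runner's own fields:

```python
import os
import torch


def init_cpu_shm(tp_size: int, tp_rank: int) -> None:
    """Prepare the shared-memory collectives before the first shm op runs."""
    # LOCAL_SIZE is read back by GroupCoordinator.__init__ as self.local_size and
    # compared against world_size inside is_shm_available().
    os.environ["LOCAL_SIZE"] = str(tp_size)
    # Allocates/attaches the per-rank shared-memory workspace inside sgl_kernel.
    torch.ops.sgl_kernel.initialize(tp_size, tp_rank)
```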
diff --git a/sgl-kernel/csrc/cpu/interface.cpp b/sgl-kernel/csrc/cpu/interface.cpp
index 969a6bad4ed..9057a50f4a2 100644
--- a/sgl-kernel/csrc/cpu/interface.cpp
+++ b/sgl-kernel/csrc/cpu/interface.cpp
@@ -47,71 +47,26 @@ void initialize(int64_t size, int64_t rank) {
   }
 }
 
-void shm_allreduce(
-    torch::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, c10::intrusive_ptr<c10d::ReduceOp> op) {
+void shm_allreduce(torch::Tensor& data, int64_t op) {
   RECORD_FUNCTION("sgl-kernel::shm_allreduce", std::vector<c10::IValue>({data}));
 
   TORCH_CHECK(op == c10d::ReduceOp::SUM, "Only torch.distributed.ReduceOp.SUM is supported");
 
   auto numel = data.numel();
-
-  int data_size = 0;
-  bool data_type_fallback = false;
-
-  switch (data.scalar_type()) {
-    case c10::ScalarType::BFloat16:
-      data_size = numel * 2;
-      break;
-    case c10::ScalarType::Float:
-      data_size = numel * 4;
-      break;
-    default:
-      data_type_fallback = true;
-  }
-
-  if (data_type_fallback || !all_ranks_local_p) {
-    // Fallback to torch distributed allreduce
-    std::vector<torch::Tensor> tensors = {data};
-    process_group->allreduce(tensors)->wait();
-  } else {
-    all_reduce_outer_loop(data, numel, data_size);
-  }
+  int data_size = numel * data.element_size();
+  all_reduce_outer_loop(data, numel, data_size);
 
   return;
 }
 
-torch::Tensor shm_allgather(torch::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int64_t dim) {
+torch::Tensor shm_allgather(torch::Tensor& data, int64_t dim) {
   RECORD_FUNCTION("sgl-kernel::shm_allgather", std::vector<c10::IValue>({data}));
 
   auto numel = data.numel();
-
-  int data_size = 0;
-  bool data_type_fallback = false;
-
-  switch (data.scalar_type()) {
-    case c10::ScalarType::BFloat16:
-      data_size = numel * 2;
-      break;
-    case c10::ScalarType::Float:
-      data_size = numel * 4;
-      break;
-    default:
-      data_type_fallback = true;
-  }
+  int data_size = numel * data.element_size();
 
   if (dim < 0) {
     dim += data.dim();
   }
-  if (data_type_fallback || !all_ranks_local_p) {
-    // Fallback to torch distributed allreduce
-    std::vector<std::vector<torch::Tensor>> output_tensors(1);
-    auto world_size = process_group->getSize();
-    for (int i = 0; i < world_size; i++) {
-      output_tensors[0].push_back(torch::empty_like(data));
-    }
-    std::vector<torch::Tensor> input_tensors = {data};
-    process_group->allgather(output_tensors, input_tensors)->wait();
-    return torch::cat(output_tensors[0], dim).contiguous();
-  }
   std::vector<int64_t> result_shape = data.sizes().vec();
   result_shape[dim] *= world_size;
   torch::Tensor result_tensor = torch::empty(result_shape, data.options());
diff --git a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
index 17e2f824c8f..44257dec5e0 100644
--- a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
+++ b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
@@ -212,11 +212,10 @@ std::tuple qkv_proj_with_rope_fused_weight(
 void initialize(int64_t size, int64_t rank);
 
 // shared mmeory all_reduce
-void shm_allreduce(
-    at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, c10::intrusive_ptr<c10d::ReduceOp> op);
+void shm_allreduce(at::Tensor& data, int64_t op);
 
 // shared memory all_gather
-at::Tensor shm_allgather(at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int64_t dim);
+at::Tensor shm_allgather(at::Tensor& data, int64_t dim);
 
 // rope
 std::tuple rotary_embedding_cpu(
@@ -343,12 +342,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
 
   // all reduce
   m.def("initialize(int size, int rank) -> ()");
-  m.impl("initialize", torch::kCPU, &initialize);
-  m.def(
-      "shm_allreduce(Tensor data, __torch__.torch.classes.c10d.ProcessGroup process_group, "
-      "__torch__.torch.classes.c10d.ReduceOp reduce_op) -> ()");
+  m.def("shm_allreduce(Tensor data, int reduce_op) -> ()");
   m.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
-  m.def("shm_allgather(Tensor data, __torch__.torch.classes.c10d.ProcessGroup process_group, int dim) -> Tensor");
+  m.def("shm_allgather(Tensor data, int dim) -> Tensor");
   m.impl("shm_allgather", torch::kCPU, &shm_allgather);
 
   // rope
@@ -363,6 +359,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
 
 TORCH_LIBRARY_IMPL(sgl_kernel, CatchAll, m) {
   m.impl("init_cpu_threads_env", init_cpu_threads_env);
+  m.impl("initialize", &initialize);
 }
 
 REGISTER_EXTENSION(common_ops)
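Taken together, the registration changes reduce both op schemas to tensor-plus-int arguments, with the dtype/locality fallback moved out of the kernel and into the Python caller. A hedged sketch of what a call site looks like after this change; `shm_collectives`, `t`, and `dim` are placeholders, and it assumes torch.ops.sgl_kernel.initialize has already been called on every rank:

```python
import torch
import torch.distributed as dist


def shm_collectives(t: torch.Tensor, dim: int) -> torch.Tensor:
    # Old schema (removed): the ops took a ProcessGroup plus a ReduceOp object and
    # performed their own dtype/locality fallback through that ProcessGroup.
    # New schema: the Python side (is_shm_available) decides whether the shm path
    # applies, so the ops only need the tensor plus plain integer arguments.
    torch.ops.sgl_kernel.shm_allreduce(t, dist.ReduceOp.SUM)  # reduces in place
    return torch.ops.sgl_kernel.shm_allgather(t, dim)  # returns the gathered tensor
```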