22 changes: 19 additions & 3 deletions python/sglang/srt/distributed/parallel_state.py
@@ -42,8 +42,10 @@
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
+    get_int_env_var,
     is_cuda_alike,
     is_npu,
+    is_shm_available,
     supports_custom_op,
 )
 
@@ -222,6 +224,7 @@ def __init__(
         self.local_rank = local_rank
         self.device_group = None
         self.cpu_group = None
+        self.local_size = get_int_env_var("LOCAL_SIZE", 0)
 
         for ranks in group_ranks:
             device_group = torch.distributed.new_group(
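
Note: get_int_env_var is new in this import list, but its implementation is not part of the diff. A plausible sketch only, assuming it mirrors the existing get_bool_env_var helper (signature inferred from the call site above):

import os

def get_int_env_var(name: str, default: int = 0) -> int:
    # Read an integer from the environment, falling back to `default`
    # when the variable is unset or not a valid integer.
    value = os.getenv(name)
    try:
        return int(value) if value is not None else default
    except ValueError:
        return default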
@@ -440,9 +443,12 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
             return input_
 
         if input_.is_cpu:
-            import intel_extension_for_pytorch as ipex
-
-            ipex.distributed.all_reduce(input_, group=self.device_group)
+            if is_shm_available(input_.dtype, self.world_size, self.local_size):
+                torch.ops.sgl_kernel.shm_allreduce(
+                    input_, torch.distributed.ReduceOp.SUM
+                )
+            else:
+                torch.distributed.all_reduce(input_, group=self.device_group)
             return input_
 
         if not supports_custom_op():
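
Note: the shared-memory op has to agree numerically with the fallback it replaces. A per-rank sanity check, as a sketch only (assumes shm_allreduce reduces in place with SUM semantics and that torch.ops.sgl_kernel.initialize has already been called on every rank, e.g. under torchrun):

import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo")  # reference path
x = torch.randn(8, 16, dtype=torch.bfloat16)
expected = x.clone()
dist.all_reduce(expected)                                  # reference SUM
torch.ops.sgl_kernel.shm_allreduce(x, dist.ReduceOp.SUM)   # shared-memory SUM
# Tolerate reduction-order rounding differences in bf16.
assert torch.allclose(x.float(), expected.float(), atol=1e-2)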
@@ -570,6 +576,16 @@ def all_gather(
         output_tensor = torch.empty(
             output_size, dtype=input_.dtype, device=input_.device
         )
+
+        if input_.is_cpu:
+            if is_shm_available(input_.dtype, self.world_size, self.local_size):
+                return torch.ops.sgl_kernel.shm_allgather(input_, dim)
+            else:
+                torch.distributed.all_gather_into_tensor(
+                    output_tensor, input_, group=self.device_group
+                )
+            return output_tensor
+
         # All-gather.
         self.all_gather_into_tensor(output_tensor, input_)
         # Reshape
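
Note: a sketch of the expected shm_allgather semantics, on the assumption (not confirmed by the diff) that it returns a new tensor with each rank's shard concatenated along dim, matching what the fallback produces:

import torch
import torch.distributed as dist

# Per-rank snippet, e.g. under `torchrun --nproc-per-node=4` with LOCAL_SIZE=4.
dist.init_process_group(backend="gloo")
rank, world = dist.get_rank(), dist.get_world_size()

shard = torch.full((2, 3), float(rank), dtype=torch.bfloat16)
gathered = torch.ops.sgl_kernel.shm_allgather(shard, 0)
# Expected: shape (2 * world, 3), with rows 2r..2r+1 coming from rank r.
assert gathered.shape == (2 * world, 3)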
6 changes: 5 additions & 1 deletion python/sglang/srt/model_executor/model_runner.py
@@ -506,9 +506,13 @@ def init_torch_distributed(self):
         if _is_cpu_amx_available:
             # Bind OpenMP threads to CPU cores
             torch.ops.sgl_kernel.init_cpu_threads_env(self.local_omp_cpuid)
+
+            # Set local size to hint SGLang to use shared-memory-based AllReduce
+            os.environ["LOCAL_SIZE"] = str(self.tp_size)
+            torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank)
         else:
             logger.warning(
-                "init_cpu_threads_env is skipped since intel amx backend is not available"
+                "init_cpu_threads_env and shared-memory-based AllReduce are disabled since intel amx backend is not available"
             )
 
         # Only initialize the distributed environment on the target model worker.
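
Note: the LOCAL_SIZE variable set here is the producer half of the handshake with GroupCoordinator.__init__ above, which reads it back through get_int_env_var("LOCAL_SIZE", 0). A condensed sketch of the flow, with tp_size/tp_rank standing in for the runner's attributes and the comment on initialize being an assumption:

import os
import torch

tp_size, tp_rank = 4, 0  # illustrative values

# model_runner.py side: only reached when the AMX backend is available.
os.environ["LOCAL_SIZE"] = str(tp_size)            # "all tp_size ranks are local"
torch.ops.sgl_kernel.initialize(tp_size, tp_rank)  # presumably sets up shm buffers

# parallel_state.py side: if the env var was never set, local_size stays 0,
# world_size == local_size can never hold, and the shm path is silently skipped.
local_size = int(os.environ.get("LOCAL_SIZE", "0"))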
9 changes: 9 additions & 0 deletions python/sglang/srt/utils.py
@@ -2612,3 +2612,12 @@ def get_cpu_ids_by_node():
 
     # ['0,1,2,3', '4,5,6,7', '8,9,10,11', '12,13,14,15', '16,17,18,19', '20,21,22,23']
     return cpu_ids
+
+
+def is_shm_available(dtype, world_size, local_size):
+    return (
+        cpu_has_amx_support()
+        and dtype in [torch.bfloat16, torch.float]
+        and world_size >= 1
+        and world_size == local_size
+    )
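
Note: for illustration, how this predicate gates the fast path. The dtype list and the world_size == local_size check come straight from the body above; the True result additionally assumes cpu_has_amx_support() holds:

import torch
from sglang.srt.utils import is_shm_available

is_shm_available(torch.bfloat16, 4, 4)   # True on a single AMX node (TP=4)
is_shm_available(torch.bfloat16, 8, 4)   # False: world_size != local_size (multi-node)
is_shm_available(torch.bfloat16, 4, 0)   # False: LOCAL_SIZE was never set
is_shm_available(torch.float16, 4, 4)    # False: fp16 not in [bfloat16, float]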