22 changes: 18 additions & 4 deletions python/sglang/srt/distributed/parallel_state.py
@@ -44,6 +44,7 @@
    get_bool_env_var,
    is_cuda_alike,
    is_npu,
    is_shm_available,
    supports_custom_op,
)

@@ -222,7 +223,7 @@ def __init__(
        self.local_rank = local_rank
        self.device_group = None
        self.cpu_group = None

        self.local_size = int(os.environ.get("LOCAL_SIZE", "0"))
        for ranks in group_ranks:
            device_group = torch.distributed.new_group(
                ranks, backend=torch_distributed_backend
@@ -440,9 +441,12 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
            return input_

        if input_.is_cpu:
            import intel_extension_for_pytorch as ipex

            ipex.distributed.all_reduce(input_, group=self.device_group)
            if is_shm_available(input_.dtype, self.world_size, self.local_size):
                torch.ops.sgl_kernel.shm_allreduce(
                    input_, torch.distributed.ReduceOp.SUM
                )
            else:
                torch.distributed.all_reduce(input_, group=self.device_group)
            return input_

        if not supports_custom_op():
@@ -562,6 +566,16 @@ def all_gather(
        output_tensor = torch.empty(
            output_size, dtype=input_.dtype, device=input_.device
        )

        if input_.is_cpu:
            if is_shm_available(input_.dtype, self.world_size, self.local_size):
                return torch.ops.sgl_kernel.shm_allgather(input_, dim)
            else:
                torch.distributed.all_gather_into_tensor(
                    output_tensor, input_, group=self.device_group
                )
            return output_tensor

        # All-gather.
        self.all_gather_into_tensor(output_tensor, input_)
        # Reshape
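For context, here is a minimal standalone sketch of the CPU dispatch these hunks add. The names `world_size`, `local_size`, and `device_group` stand in for the corresponding `GroupCoordinator` attributes, and the `torch.ops.sgl_kernel.*` calls assume the sgl-kernel CPU extension has been built, imported, and initialized:

```python
import torch
import torch.distributed as dist

from sglang.srt.utils import is_shm_available


def cpu_all_reduce(input_, world_size, local_size, device_group):
    # Take the shared-memory kernel only when every rank in the group lives on
    # this node and the dtype is supported; otherwise fall back to the regular
    # torch.distributed all-reduce over the CPU process group.
    if is_shm_available(input_.dtype, world_size, local_size):
        torch.ops.sgl_kernel.shm_allreduce(input_, dist.ReduceOp.SUM)
    else:
        dist.all_reduce(input_, group=device_group)
    return input_


def cpu_all_gather(input_, dim, world_size, local_size, device_group):
    # shm_allgather returns the concatenated result directly; the fallback
    # gathers into a preallocated tensor stacked along a new leading dim.
    if is_shm_available(input_.dtype, world_size, local_size):
        return torch.ops.sgl_kernel.shm_allgather(input_, dim)
    output = torch.empty(
        (world_size,) + tuple(input_.shape), dtype=input_.dtype, device=input_.device
    )
    dist.all_gather_into_tensor(output, input_, group=device_group)
    return output
```

The previous CPU path depended on `intel_extension_for_pytorch`; with this change that dependency is dropped, and unsupported dtypes or multi-node layouts simply fall back to `torch.distributed`.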
4 changes: 4 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
@@ -473,6 +473,10 @@ def init_torch_distributed(self):
            ), "init_cpu_threads_env failed since intel amx backend is not available"
            torch.ops.sgl_kernel.init_cpu_threads_env(self.local_omp_cpuid)

            # Set local size to hint SGLang to use shared memory based AllReduce
            os.environ["LOCAL_SIZE"] = str(self.tp_size)
            torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank)

        # Only initialize the distributed environment on the target model worker.
        init_distributed_environment(
            backend=backend,
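The `LOCAL_SIZE` environment variable set here is read back by `GroupCoordinator.__init__` in the first file above. A condensed sketch of that handshake, with `tp_size` and `tp_rank` as placeholder values:

```python
import os

import torch

tp_size, tp_rank = 4, 0  # placeholders for the runner's TP configuration

# Producer (model_runner.py): advertise that all TP ranks are node-local and
# let the CPU kernel extension set up its shared-memory workspace.
os.environ["LOCAL_SIZE"] = str(tp_size)
torch.ops.sgl_kernel.initialize(tp_size, tp_rank)

# Consumer (parallel_state.py): each GroupCoordinator reads the hint back;
# when the variable is absent it defaults to 0, which keeps the
# shared-memory path disabled.
local_size = int(os.environ.get("LOCAL_SIZE", "0"))
```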
9 changes: 9 additions & 0 deletions python/sglang/srt/utils.py
@@ -2477,3 +2477,12 @@ def get_cpu_ids_by_node():

    # ['0,1,2,3', '4,5,6,7', '8,9,10,11', '12,13,14,15', '16,17,18,19', '20,21,22,23']
    return cpu_ids


def is_shm_available(dtype, world_size, local_size):
    return (
        cpu_has_amx_support()
        and dtype in [torch.bfloat16, torch.float]
        and world_size >= 1
        and world_size == local_size
    )
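A few illustrative cases for the new helper, assuming `cpu_has_amx_support()` returns True on the host:

```python
import torch

from sglang.srt.utils import is_shm_available

# Single-node TP=4: all ranks are local and bf16 is supported -> shm path.
assert is_shm_available(torch.bfloat16, world_size=4, local_size=4)

# Two nodes with 4 local ranks each -> fall back to torch.distributed.
assert not is_shm_available(torch.bfloat16, world_size=8, local_size=4)

# Unsupported dtype (fp16) -> fall back as well.
assert not is_shm_available(torch.float16, world_size=4, local_size=4)
```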
55 changes: 5 additions & 50 deletions sgl-kernel/csrc/cpu/interface.cpp
@@ -47,71 +47,26 @@ void initialize(int64_t size, int64_t rank) {
}
}

void shm_allreduce(
torch::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, c10::intrusive_ptr<c10d::ReduceOp> op) {
void shm_allreduce(torch::Tensor& data, int64_t op) {
RECORD_FUNCTION("sgl-kernel::shm_allreduce", std::vector<c10::IValue>({data}));

TORCH_CHECK(op == c10d::ReduceOp::SUM, "Only torch.distributed.ReduceOp.SUM is supported");

auto numel = data.numel();

int data_size = 0;
bool data_type_fallback = false;

switch (data.scalar_type()) {
case c10::ScalarType::BFloat16:
data_size = numel * 2;
break;
case c10::ScalarType::Float:
data_size = numel * 4;
break;
default:
data_type_fallback = true;
}

if (data_type_fallback || !all_ranks_local_p) {
// Fallback to torch distributed allreduce
std::vector<torch::Tensor> tensors = {data};
process_group->allreduce(tensors)->wait();
} else {
all_reduce_outer_loop(data, numel, data_size);
}
int data_size = numel * data.element_size();
all_reduce_outer_loop(data, numel, data_size);

return;
}

torch::Tensor shm_allgather(torch::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int64_t dim) {
torch::Tensor shm_allgather(torch::Tensor& data, int64_t dim) {
RECORD_FUNCTION("sgl-kernel::shm_allgather", std::vector<c10::IValue>({data}));

auto numel = data.numel();

int data_size = 0;
bool data_type_fallback = false;

switch (data.scalar_type()) {
case c10::ScalarType::BFloat16:
data_size = numel * 2;
break;
case c10::ScalarType::Float:
data_size = numel * 4;
break;
default:
data_type_fallback = true;
}
int data_size = numel * data.element_size();
if (dim < 0) {
dim += data.dim();
}
if (data_type_fallback || !all_ranks_local_p) {
// Fallback to torch distributed allreduce
std::vector<std::vector<torch::Tensor>> output_tensors(1);
auto world_size = process_group->getSize();
for (int i = 0; i < world_size; i++) {
output_tensors[0].push_back(torch::empty_like(data));
}
std::vector<torch::Tensor> input_tensors = {data};
process_group->allgather(output_tensors, input_tensors)->wait();
return torch::cat(output_tensors[0], dim).contiguous();
}
std::vector<int64_t> result_shape = data.sizes().vec();
result_shape[dim] *= world_size;
torch::Tensor result_tensor = torch::empty(result_shape, data.options());
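With the per-dtype switch replaced by `data.element_size()`, the byte count covers any dtype the Python-side gate lets through, and the `torch.distributed` fallback that used to live in these functions is now handled by the caller in `parallel_state.py`. A quick check of the size arithmetic from the Python side:

```python
import torch

t = torch.empty(1024, dtype=torch.bfloat16)
# Mirrors the simplified C++ computation: data_size = numel * element_size().
assert t.numel() * t.element_size() == 2048  # bf16 takes 2 bytes per element

t32 = torch.empty(1024, dtype=torch.float)
assert t32.numel() * t32.element_size() == 4096  # fp32 takes 4 bytes per element
```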
13 changes: 5 additions & 8 deletions sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
@@ -212,11 +212,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope_fused_weight(
void initialize(int64_t size, int64_t rank);

// shared memory all_reduce
void shm_allreduce(
at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, c10::intrusive_ptr<c10d::ReduceOp> op);
void shm_allreduce(at::Tensor& data, int64_t op);

// shared memory all_gather
at::Tensor shm_allgather(at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int64_t dim);
at::Tensor shm_allgather(at::Tensor& data, int64_t dim);

// rope
std::tuple<at::Tensor, at::Tensor> rotary_embedding_cpu(
@@ -343,12 +342,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {

// all reduce
m.def("initialize(int size, int rank) -> ()");
m.impl("initialize", torch::kCPU, &initialize);
m.def(
"shm_allreduce(Tensor data, __torch__.torch.classes.c10d.ProcessGroup process_group, "
"__torch__.torch.classes.c10d.ReduceOp reduce_op) -> ()");
m.def("shm_allreduce(Tensor data, int reduce_op) -> ()");
m.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
m.def("shm_allgather(Tensor data, __torch__.torch.classes.c10d.ProcessGroup process_group, int dim) -> Tensor");
m.def("shm_allgather(Tensor data, int dim) -> Tensor");
m.impl("shm_allgather", torch::kCPU, &shm_allgather);

// rope
@@ -363,6 +359,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {

TORCH_LIBRARY_IMPL(sgl_kernel, CatchAll, m) {
m.impl("init_cpu_threads_env", init_cpu_threads_env);
m.impl("initialize", &initialize);
}

REGISTER_EXTENSION(common_ops)
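For reference, a sketch of how the re-registered ops are invoked from Python. It assumes the `sgl_kernel` package (which loads the `common_ops` extension) is installed and that the values passed to `initialize` match the actual TP configuration; in a real deployment these calls are issued by `model_runner.py` and `parallel_state.py` as shown above:

```python
import torch
import torch.distributed as dist

import sgl_kernel  # assumed import name; registers the torch.ops.sgl_kernel ops

# initialize(size, rank) has no Tensor arguments, so it is now registered in
# the CatchAll block above; it must run before any shm_* op is used.
torch.ops.sgl_kernel.initialize(4, 0)  # placeholder (tp_size, tp_rank)

x = torch.ones(16, dtype=torch.bfloat16)
torch.ops.sgl_kernel.shm_allreduce(x, dist.ReduceOp.SUM)  # in-place SUM reduce
y = torch.ops.sgl_kernel.shm_allgather(x, 0)  # concatenates along dim 0
```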