55 changes: 5 additions & 50 deletions sgl-kernel/csrc/cpu/interface.cpp
@@ -47,71 +47,26 @@ void initialize(int64_t size, int64_t rank) {
   }
 }
 
-void shm_allreduce(
-    torch::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, c10::intrusive_ptr<c10d::ReduceOp> op) {
+void shm_allreduce(torch::Tensor& data, int64_t op) {
   RECORD_FUNCTION("sgl-kernel::shm_allreduce", std::vector<c10::IValue>({data}));
 
   TORCH_CHECK(op == c10d::ReduceOp::SUM, "Only torch.distributed.ReduceOp.SUM is supported");
 
   auto numel = data.numel();
-
-  int data_size = 0;
-  bool data_type_fallback = false;
-
-  switch (data.scalar_type()) {
-    case c10::ScalarType::BFloat16:
-      data_size = numel * 2;
-      break;
-    case c10::ScalarType::Float:
-      data_size = numel * 4;
-      break;
-    default:
-      data_type_fallback = true;
-  }
-
-  if (data_type_fallback || !all_ranks_local_p) {
-    // Fallback to torch distributed allreduce
-    std::vector<torch::Tensor> tensors = {data};
-    process_group->allreduce(tensors)->wait();
-  } else {
-    all_reduce_outer_loop(data, numel, data_size);
-  }
+  int data_size = numel * data.element_size();
+  all_reduce_outer_loop(data, numel, data_size);
 
   return;
 }
 
-torch::Tensor shm_allgather(torch::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int64_t dim) {
+torch::Tensor shm_allgather(torch::Tensor& data, int64_t dim) {
   RECORD_FUNCTION("sgl-kernel::shm_allgather", std::vector<c10::IValue>({data}));
 
   auto numel = data.numel();
-
-  int data_size = 0;
-  bool data_type_fallback = false;
-
-  switch (data.scalar_type()) {
-    case c10::ScalarType::BFloat16:
-      data_size = numel * 2;
-      break;
-    case c10::ScalarType::Float:
-      data_size = numel * 4;
-      break;
-    default:
-      data_type_fallback = true;
-  }
+  int data_size = numel * data.element_size();
   if (dim < 0) {
     dim += data.dim();
   }
-  if (data_type_fallback || !all_ranks_local_p) {
-    // Fallback to torch distributed allreduce
-    std::vector<std::vector<torch::Tensor>> output_tensors(1);
-    auto world_size = process_group->getSize();
-    for (int i = 0; i < world_size; i++) {
-      output_tensors[0].push_back(torch::empty_like(data));
-    }
-    std::vector<torch::Tensor> input_tensors = {data};
-    process_group->allgather(output_tensors, input_tensors)->wait();
-    return torch::cat(output_tensors[0], dim).contiguous();
-  }
   std::vector<int64_t> result_shape = data.sizes().vec();
   result_shape[dim] *= world_size;
   torch::Tensor result_tensor = torch::empty(result_shape, data.options());
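The deleted dtype switch covered only BFloat16 and Float and pushed everything else (or any non-local rank set) onto the torch.distributed fallback; `numel * element_size()` yields the same byte count for every dtype in one line. A minimal standalone sketch of that equivalence (illustrative only, not part of this PR):

// Sketch: Tensor::element_size() returns bytes per element, so
// numel() * element_size() reproduces the byte counts the removed
// switch hard-coded and extends them to the dtypes it fell back on.
#include <torch/torch.h>
#include <cassert>

int main() {
  auto bf16 = torch::empty({8}, torch::dtype(torch::kBFloat16));
  auto fp32 = torch::empty({8}, torch::dtype(torch::kFloat));
  auto fp16 = torch::empty({8}, torch::dtype(torch::kHalf));
  assert(bf16.numel() * bf16.element_size() == 8 * 2);  // old case BFloat16: numel * 2
  assert(fp32.numel() * fp32.element_size() == 8 * 4);  // old case Float: numel * 4
  assert(fp16.numel() * fp16.element_size() == 8 * 2);  // old code: data_type_fallback
  return 0;
}

Note the behavioral change that comes with this: the fallback branch is gone, so shm_allreduce/shm_allgather now take the shared-memory path unconditionally and assume all ranks are local (all_ranks_local_p is no longer consulted).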
20 changes: 12 additions & 8 deletions sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
@@ -212,11 +212,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope_fused_weight(
 void initialize(int64_t size, int64_t rank);
 
 // shared mmeory all_reduce
-void shm_allreduce(
-    at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, c10::intrusive_ptr<c10d::ReduceOp> op);
+void shm_allreduce(at::Tensor& data, int64_t op);
 
 // shared memory all_gather
-at::Tensor shm_allgather(at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int64_t dim);
+at::Tensor shm_allgather(at::Tensor& data, int64_t dim);
 
 // rope
 std::tuple<at::Tensor, at::Tensor> rotary_embedding_cpu(
@@ -340,19 +339,24 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
 
   // all reduce
   m.def("initialize(int size, int rank) -> ()");
-  m.impl("initialize", torch::kCPU, &initialize);
-  m.def(
-      "shm_allreduce(Tensor data, __torch__.torch.classes.c10d.ProcessGroup process_group, "
-      "__torch__.torch.classes.c10d.ReduceOp reduce_op) -> ()");
+  m.def("shm_allreduce(Tensor data, int reduce_op) -> ()");
   m.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
-  m.def("shm_allgather(Tensor data, __torch__.torch.classes.c10d.ProcessGroup process_group, int dim) -> Tensor");
+  m.def("shm_allgather(Tensor data, int dim) -> Tensor");
   m.impl("shm_allgather", torch::kCPU, &shm_allgather);
 
   // rope
   m.def(
       "rotary_embedding_cpu(Tensor positions, Tensor query, Tensor key, int head_size, Tensor cos_sin_cache, "
       "bool is_neox) -> (Tensor, Tensor)");
   m.impl("rotary_embedding_cpu", torch::kCPU, &rotary_embedding_cpu);
 
   // CPU and memory binding
   m.def("init_cpu_threads_env(str cpu_ids) -> str");
 }
 
 TORCH_LIBRARY_IMPL(sgl_kernel, CatchAll, m) {
   m.impl("init_cpu_threads_env", init_cpu_threads_env);
+  m.impl("initialize", &initialize);
 }
 
 REGISTER_EXTENSION(common_ops)
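With the ProcessGroup and ReduceOp custom classes gone from the schemas, both ops take only builtin schema types (`Tensor`, `int`). A hedged sketch of calling the new shm_allreduce schema from C++ through the dispatcher (assumes shared memory was already set up via sgl_kernel::initialize; the helper name is mine, not from this PR):

#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/torch.h>

// Hypothetical caller: looks up the op registered above and invokes it with
// ReduceOp SUM passed as a plain integer, matching the new `int reduce_op` slot.
void allreduce_sum(at::Tensor data) {
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("sgl_kernel::shm_allreduce", "")
                       .typed<void(at::Tensor, int64_t)>();
  op.call(data, static_cast<int64_t>(c10d::ReduceOp::SUM));  // SUM == 0 in c10d
}

From Python the op would be reachable as torch.ops.sgl_kernel.shm_allreduce(t, reduce_op) with reduce_op passed as an int (SUM is 0), so no process-group object crosses the boundary anymore.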