8 changes: 8 additions & 0 deletions CMakeLists.txt
@@ -638,6 +638,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# if CUDA endif
endif()

if (VLLM_GPU_LANG STREQUAL "HIP")
# Add QuickReduce kernels
list(APPEND VLLM_EXT_SRC
"csrc/custom_quickreduce.cu"
)
# if ROCM endif
endif()

message(STATUS "Enabling C extension.")
define_gpu_extension_target(
_C
86 changes: 86 additions & 0 deletions csrc/custom_quickreduce.cu
@@ -0,0 +1,86 @@
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/all.h>

#ifdef USE_ROCM

#include "quickreduce/quick_reduce.h"

quickreduce::fptr_t init_custom_qr(int64_t rank, int64_t world_size) {
if (world_size > 8)
throw std::invalid_argument("world size > 8 is not supported");
if (world_size == 6)
throw std::invalid_argument("world size == 6 is not supported");
if (world_size % 2 != 0)
throw std::invalid_argument("Odd num gpus is not supported for now");
if (rank < 0 || rank >= world_size)
throw std::invalid_argument("invalid rank passed in");
quickreduce::DeviceComms* fptr = new quickreduce::DeviceComms();
fptr->init(world_size, rank);
return (quickreduce::fptr_t)fptr;
}

void qr_destroy(quickreduce::fptr_t _fa) {
if (_fa) {
auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
fa->destroy();
delete fa;
}
}

torch::Tensor qr_get_handle(quickreduce::fptr_t _fa) {
auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
hipIpcMemHandle_t handle = fa->get_handle();
auto device_index = c10::cuda::current_device();
auto options =
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
auto data_handle =
torch::empty({static_cast<int64_t>(sizeof(hipIpcMemHandle_t))}, options);
std::memcpy(data_handle.data_ptr(), &handle, sizeof(hipIpcMemHandle_t));
return data_handle;
}

void qr_open_handles(quickreduce::fptr_t _fa,
const std::vector<torch::Tensor>& handles) {
auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
std::vector<hipIpcMemHandle_t> ipc_handles;
ipc_handles.reserve(handles.size());
for (auto& handle : handles) {
    // Copy the raw hipIpcMemHandle_t bytes out of the CPU uint8 tensor.
hipIpcMemHandle_t ipc_handle;
std::memcpy(&ipc_handle, handle.data_ptr(), sizeof(hipIpcMemHandle_t));
ipc_handles.push_back(ipc_handle);
}
fa->open_ipc_handles(ipc_handles);
}

void qr_all_reduce(quickreduce::fptr_t _fa, torch::Tensor& inp,
torch::Tensor& out, int64_t quant_level) {
auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
auto stream = at::cuda::getCurrentHIPStreamMasqueradingAsCUDA();

TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
TORCH_CHECK_EQ(inp.numel(), out.numel());
TORCH_CHECK_LE(out.numel(), quickreduce::DeviceComms::kMaxProblemSize);
if (out.scalar_type() == at::ScalarType::Half) {
fa->allreduce<half>(reinterpret_cast<half*>(inp.data_ptr()),
reinterpret_cast<half*>(out.data_ptr()), out.numel(),
quant_level, stream);
} else if (out.scalar_type() == at::ScalarType::BFloat16) {
fa->allreduce<quickreduce::nv_bfloat16>(
reinterpret_cast<quickreduce::nv_bfloat16*>(inp.data_ptr()),
reinterpret_cast<quickreduce::nv_bfloat16*>(out.data_ptr()),
out.numel(), quant_level, stream);
} else {
throw std::runtime_error(
"quick allreduce only supports float16 and bfloat16");
}
}

int64_t qr_max_size() {
return static_cast<int64_t>(quickreduce::DeviceComms::kMaxProblemSize);
}

#endif // USE_ROCM
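
For orientation, a minimal sketch of how these wrappers are meant to be sequenced on each rank. It assumes a ROCm build with the rank's HIP device already selected; run_quick_allreduce and gather_handles are illustrative names, not part of this PR, and quant_level 0 is assumed here to mean full precision (the real levels live in quickreduce/quick_reduce.h, which this diff does not show). The handle exchange itself is sketched after the csrc/ops.h hunk below.

#include <vector>
#include <torch/all.h>

#include "ops.h"  // qr_* declarations added by this PR

// Hypothetical helper: gathers one handle tensor per rank (see the MPI
// sketch further down).
std::vector<torch::Tensor> gather_handles(const torch::Tensor& local,
                                          int64_t world_size);

void run_quick_allreduce(int64_t rank, int64_t world_size,
                         torch::Tensor& inp, torch::Tensor& out) {
  // One communicator per rank; init_custom_qr validates rank/world_size.
  fptr_t fa = init_custom_qr(rank, world_size);
  // Export this rank's hipIpcMemHandle_t and open every rank's handle so
  // the kernels can address peer buffers directly.
  qr_open_handles(fa, gather_handles(qr_get_handle(fa), world_size));
  // inp and out must match in dtype (fp16/bf16) and element count, and
  // stay within the fixed problem-size cap.
  TORCH_CHECK_LE(out.numel(), qr_max_size());
  qr_all_reduce(fa, inp, out, /*quant_level=*/0);
  qr_destroy(fa);
}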
10 changes: 10 additions & 0 deletions csrc/ops.h
@@ -360,3 +360,13 @@ std::tuple<int64_t, torch::Tensor> allocate_shared_buffer_and_handle(
int64_t size);
int64_t open_mem_handle(torch::Tensor& mem_handle);
void free_shared_buffer(int64_t buffer);

#ifdef USE_ROCM
fptr_t init_custom_qr(int64_t rank, int64_t world_size);
void qr_destroy(fptr_t _fa);
torch::Tensor qr_get_handle(fptr_t _fa);
void qr_open_handles(fptr_t _fa, const std::vector<torch::Tensor>& handles);
void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
int64_t quant_level);
int64_t qr_max_size();
#endif
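
The handles returned by qr_get_handle are opaque byte tensors, so any rank-to-rank transport can distribute them. A sketch using MPI follows; vLLM would normally perform this exchange from its Python layer, so MPI and the exchange_and_open helper here are purely illustrative.

#include <mpi.h>

#include <vector>
#include <torch/all.h>

#include "ops.h"

// Allgather the raw handle bytes from every rank, rebuild one CPU uint8
// tensor per rank, and register them with qr_open_handles.
void exchange_and_open(fptr_t fa, int world_size) {
  torch::Tensor local = qr_get_handle(fa);  // CPU uint8 tensor
  const int64_t nbytes = local.numel();     // sizeof(hipIpcMemHandle_t)
  std::vector<uint8_t> all(static_cast<size_t>(nbytes) * world_size);
  MPI_Allgather(local.data_ptr(), static_cast<int>(nbytes), MPI_BYTE,
                all.data(), static_cast<int>(nbytes), MPI_BYTE,
                MPI_COMM_WORLD);
  std::vector<torch::Tensor> handles;
  for (int r = 0; r < world_size; ++r) {
    // clone() so each tensor owns its storage after `all` goes out of scope.
    handles.push_back(
        torch::from_blob(all.data() + r * nbytes, {nbytes}, torch::kUInt8)
            .clone());
  }
  qr_open_handles(fa, handles);
}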