From 9370db16fa904f1156749f80eb02518eff22469c Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 23 May 2025 07:16:05 +0000 Subject: [PATCH 1/9] solve rebase conflict Signed-off-by: Haoyang Li --- CMakeLists.txt | 7 + csrc/ops.h | 11 + csrc/quick_all_reduce.cu | 275 ++++ csrc/quick_all_reduce.cuh | 1120 +++++++++++++++++ csrc/quick_all_reduce.h | 426 +++++++ csrc/torch_bindings.cpp | 16 + docs/usage/usage_stats.md | 3 +- tests/compile/test_full_graph.py | 1 + tests/distributed/test_quick_all_reduce.py | 135 ++ tests/tensorizer_loader/test_tensorizer.py | 4 + vllm/_custom_ops.py | 32 + vllm/config.py | 16 +- .../device_communicators/cuda_communicator.py | 26 +- .../device_communicators/quick_all_reduce.py | 215 ++++ vllm/distributed/parallel_state.py | 10 + vllm/engine/arg_utils.py | 5 + vllm/engine/llm_engine.py | 2 + vllm/entrypoints/llm.py | 4 + vllm/envs.py | 8 + vllm/platforms/cuda.py | 4 + vllm/platforms/interface.py | 7 + vllm/platforms/rocm.py | 7 + vllm/v1/utils.py | 2 + vllm/v1/worker/gpu_worker.py | 4 +- vllm/worker/worker.py | 4 +- 25 files changed, 2338 insertions(+), 6 deletions(-) create mode 100644 csrc/quick_all_reduce.cu create mode 100644 csrc/quick_all_reduce.cuh create mode 100644 csrc/quick_all_reduce.h create mode 100644 tests/distributed/test_quick_all_reduce.py create mode 100644 vllm/distributed/device_communicators/quick_all_reduce.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 6a1ed588749a..848a7841807b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -638,6 +638,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # if CUDA endif endif() +if (VLLM_GPU_LANG STREQUAL "HIP") + # must be rocm + list(APPEND VLLM_EXT_SRC + "csrc/quick_all_reduce.cu" + ) +endif() + message(STATUS "Enabling C extension.") define_gpu_extension_target( _C diff --git a/csrc/ops.h b/csrc/ops.h index 7044b4588b81..b7b031ad8a07 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -345,3 +345,14 @@ std::tuple allocate_shared_buffer_and_handle( int64_t size); int64_t 
open_mem_handle(torch::Tensor& mem_handle); void free_shared_buffer(int64_t buffer); + +#ifdef USE_ROCM +fptr_t init_quick_ar(int64_t world_size, int64_t rank); +torch::Tensor qr_get_comm_handle(fptr_t _fa); +void qr_set_comm_handles(fptr_t _fa, + std::vector const& comm_handles); +void qr_all_reduce(fptr_t _fa, int64_t profile, torch::Tensor const& inp, + torch::Tensor& out); +void qr_destroy(fptr_t _fa); +void is_quickreduce_available(); +#endif \ No newline at end of file diff --git a/csrc/quick_all_reduce.cu b/csrc/quick_all_reduce.cu new file mode 100644 index 000000000000..76e52a19a3ef --- /dev/null +++ b/csrc/quick_all_reduce.cu @@ -0,0 +1,275 @@ +#include +#include +#include + +#include "quick_all_reduce.cuh" + +namespace quickreduce { + +// ============================================================ +// CONTEXT +// ============================================================ +void DeviceComms::init(int world_size, int rank) { + destroy(); + this->world_size = world_size; + this->rank = rank; + + // Allocate buffer size for worst case: Twoshot FP16 2-stage buffer. + long flags_buffer_size = 2 * world_size * kMaxTiles * sizeof(int); + long data_buffer_size = 2 * kMaxProblemSize; + long total_buffer_size = flags_buffer_size + data_buffer_size; + data_offset = flags_buffer_size; + HIP_CHECK(hipExtMallocWithFlags((void**)&dbuffer, total_buffer_size, + hipDeviceMallocUncached)); + + // Clear the flags buffer. + HIP_CHECK(hipMemset(dbuffer, 0, flags_buffer_size)); + + // Device-side list of IPC buffers. + buffer_list.resize(world_size); + HIP_CHECK(hipMalloc(&dbuffer_list, world_size * sizeof(uint8_t*))); + + // Create IPC handles for rank's communication buffer. 
+ all_buffer_ipc_handles.resize(world_size); + HIP_CHECK(hipIpcGetMemHandle(&buffer_ipc_handle, dbuffer)); + + initialized = true; +} + +void DeviceComms::destroy() { + if (initialized) { + // Close peer mappings through the host-side pointer table: dbuffer_list + // is a device allocation and must not be dereferenced on the host. + // Entries that were never opened stay null and are skipped. + for (int i = 0; i < world_size; i++) { + if (i != rank && buffer_list[i] != nullptr) { + hipIpcCloseMemHandle(buffer_list[i]); + } + buffer_list[i] = nullptr; + } + + hipFree(dbuffer); + hipFree(dbuffer_list); + + initialized = false; + } +} + +void DeviceComms::open_ipc_handles( + std::vector const& ipc_handles) { + for (int i = 0; i < world_size; i++) { + all_buffer_ipc_handles[i] = ipc_handles[i]; + } + + // Open device memory access to the IPC communication buffers. + // Note: For our own rank, we do not need to open a handle. + for (int i = 0; i < world_size; i++) { + if (i != rank) { + HIP_CHECK(hipIpcOpenMemHandle((void**)&buffer_list[i], + all_buffer_ipc_handles[i], + hipIpcMemLazyEnablePeerAccess)); + } else { + buffer_list[i] = dbuffer; + } + } + + // Publish the host-side pointer table to the device copy read by kernels. + HIP_CHECK(hipMemcpy(dbuffer_list, buffer_list.data(), + world_size * sizeof(uint8_t*), hipMemcpyHostToDevice)); +} + +// ============================================================ +// KERNEL +// ============================================================ +template +__global__ __quickreduce_launch_bounds__ static void allreduce_prototype( + T const* A, T* B, int N, int num_blocks, int world_size, int rank, + uint8_t** dbuffer_list, long data_offset, int flag_color) { + int block = blockIdx.x; + int grid = gridDim.x; + + while (block < num_blocks) { + AllReduceKenel::run(A, B, N, block, num_blocks, world_size, rank, + dbuffer_list, data_offset, flag_color); + block += grid; + } +} + +// ============================================================ +// DISPATCH +// ============================================================ +#define TWOSHOT_DISPATCH(__codec) \ + if (world_size == 2) { \ + using LineCodec = __codec<2, T>; \ + using AllReduceKernel = AllReduceTwoshot; \ + hipLaunchKernelGGL((allreduce_prototype), dim3(grid), \ + dim3(kBlock), 0, stream, A, B, N, num_blocks, \ + world_size, rank, dbuffer_list, 
data_offset, \ + flag_color); \ + } else if (world_size == 4) { \ + using LineCodec = __codec<4, T>; \ + using AllReduceKernel = AllReduceTwoshot; \ + hipLaunchKernelGGL((allreduce_prototype), dim3(grid), \ + dim3(kBlock), 0, stream, A, B, N, num_blocks, \ + world_size, rank, dbuffer_list, data_offset, \ + flag_color); \ + } else if (world_size == 8) { \ + using LineCodec = __codec<8, T>; \ + using AllReduceKernel = AllReduceTwoshot; \ + hipLaunchKernelGGL((allreduce_prototype), dim3(grid), \ + dim3(kBlock), 0, stream, A, B, N, num_blocks, \ + world_size, rank, dbuffer_list, data_offset, \ + flag_color); \ + } + +template +void DeviceComms::allreduce(int profile, hipStream_t stream, T const* A, T* B, + int N) { + static_assert(sizeof(T) == 2, + "Template parameter T must be 16 bits (2 bytes) in size."); + if (world_size != 2 && world_size != 4 && world_size != 8) { + throw std::runtime_error("All Reduce not supported for world_size = " + + std::to_string(world_size)); + } + + // Configuration. + long msg_size = N * sizeof(T); + int num_blocks = divceil(msg_size, kTileSize); + int grid = min(304 * 4, num_blocks); + + // ------------------------------------------------- + // All reduce dispatch. 
+ QuickReduceProfile dprofile = static_cast(profile); + switch (dprofile) { + case QuickReduceProfile::TWOSHOT_FP8: + TWOSHOT_DISPATCH(TwoshotFP8LineCodec) + break; + case QuickReduceProfile::TWOSHOT_Q8: + TWOSHOT_DISPATCH(TwoshotQ8LineCodec) + break; + case QuickReduceProfile::TWOSHOT_Q6: + TWOSHOT_DISPATCH(TwoshotQ6LineCodec) + break; + case QuickReduceProfile::TWOSHOT_Q4: + TWOSHOT_DISPATCH(TwoshotQ4LineCodec) + break; + case QuickReduceProfile::ONESHOT_F16: + using AllReduceKernel = AllReduceOneshot; + hipLaunchKernelGGL((allreduce_prototype), dim3(grid), + dim3(kBlock), 0, stream, A, B, N, num_blocks, + world_size, rank, dbuffer_list, data_offset, + flag_color); + break; + default: + TWOSHOT_DISPATCH(TwoshotF16LineCodec) + break; + } + + // ------------------------------------------------- + // Rotate the flag color. + flag_color++; +} + +} // namespace quickreduce + +/** + * Make sure tensor t's data lies completely within ((char)t.data_ptr()) + + * t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous() + * because it allows transpose of contiguous slice (i.e. slicing the first + * dimension). Currently, we require this because stride information is not + * passed into the kernels and we treat input tensors as flat. + * + * Examples + * A = torch.zeros(3, 3, 3) + * 1. A: OK + * 2. A[1:]: OK + * 3. A.permute(2, 0, 1): OK + * 4. A[1:].permute(2, 0, 1): OK + * 5. A[None].expand(2, -1, -1, -1): Not OK + * 6. 
A[:, 1:, 1:]: Not OK + */ +bool _is_weak_contiguous(torch::Tensor const& t) { + return t.is_contiguous() || + (t.storage().nbytes() - t.storage_offset() * t.element_size() == + t.numel() * t.element_size()); +} + +fptr_t init_quick_ar(int64_t world_size, int64_t rank) { + if (world_size > 8) + throw std::invalid_argument("world size > 8 is not supported"); + if (world_size == 6) + throw std::invalid_argument("world size == 6 is not supported"); + if (world_size % 2 != 0) + throw std::invalid_argument("Odd num gpus is not supported for now"); + if (rank < 0 || rank >= world_size) + throw std::invalid_argument("invalid rank passed in"); + + quickreduce::DeviceComms* dev_comm = new quickreduce::DeviceComms(); + dev_comm->init(world_size, rank); + return reinterpret_cast(dev_comm); +} + +torch::Tensor qr_get_comm_handle(fptr_t _fa) { + auto fa = reinterpret_cast(_fa); + hipIpcMemHandle_t handle = fa->get_handle(); + auto options = + torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); + torch::Tensor tensor_handle = + torch::empty({static_cast(sizeof(hipIpcMemHandle_t))}, options); + std::memcpy(tensor_handle.data_ptr(), &handle, sizeof(hipIpcMemHandle_t)); + return tensor_handle; +} + +void qr_set_comm_handles(fptr_t _fa, + std::vector const& comm_handles) { + auto fa = reinterpret_cast(_fa); + auto world_size = comm_handles.size(); + std::vector ipc_handles(world_size); + + for (int i = 0; i < world_size; ++i) { + const auto& tensor = comm_handles[i]; + TORCH_CHECK(tensor.device().is_cpu(), "Comm handle tensor must be on CPU"); + TORCH_CHECK(tensor.scalar_type() == torch::kUInt8, + "Comm handle tensor must be of type uint8"); + TORCH_CHECK(tensor.numel() == sizeof(hipIpcMemHandle_t), + "Comm handle tensor must have ", sizeof(hipIpcMemHandle_t), + " elements"); + + std::memcpy(&(ipc_handles[i]), tensor.data_ptr(), + sizeof(hipIpcMemHandle_t)); + } + fa->open_ipc_handles(ipc_handles); +} + +void qr_all_reduce(fptr_t _fa, int64_t profile, torch::Tensor 
const& inp, + torch::Tensor& out) { + quickreduce::DeviceComms* fa = + reinterpret_cast(_fa); + auto stream = c10::cuda::getCurrentCUDAStream().stream(); // hipStream_t + + TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); + TORCH_CHECK_EQ(inp.numel(), out.numel()); + TORCH_CHECK(_is_weak_contiguous(out)); + TORCH_CHECK(_is_weak_contiguous(inp)); + + auto input_size = inp.numel() * inp.element_size(); + + if (out.scalar_type() == at::ScalarType::Half) { + fa->allreduce(profile, stream, + reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), inp.numel()); + } else if (out.scalar_type() == at::ScalarType::BFloat16) { + fa->allreduce( + profile, stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), inp.numel()); + } else { + throw std::runtime_error( + "quick allreduce only supports float16 and bfloat16 for now."); + } +} + +void qr_destroy(fptr_t _fa) { + if (_fa) { + auto fa = reinterpret_cast(_fa); + fa->destroy(); + delete fa; + } +} + +void is_quickreduce_available() {}; \ No newline at end of file diff --git a/csrc/quick_all_reduce.cuh b/csrc/quick_all_reduce.cuh new file mode 100644 index 000000000000..5c994c5321ef --- /dev/null +++ b/csrc/quick_all_reduce.cuh @@ -0,0 +1,1120 @@ +#pragma once + +#include +#include "quick_all_reduce.h" + +namespace quickreduce { + +// ============================================================ +// Twoshot +// ============================================================ +// MARK: FP16 Line Codec +template +struct TwoshotF16LineCodec { + /* + Default FP16 line codec for Twoshot collectives. + No actual compression is involved. + */ + + static int constexpr kAtoms = 8; + static int constexpr kAtomStride = 256; + static int constexpr kWorldSize = world_size; + + // Codec tile size process by this workgroup. + // Each thread processes atoms of fp16x8_t (16B). 
+ static int constexpr kRankAtoms = kAtoms / kWorldSize; + static int constexpr kRankTileSize = 256 * kRankAtoms * sizeof(int32x4_t); + + // Total tile size for the collective communication. + static int constexpr kTileSize = kRankTileSize * kWorldSize; + + int const thread; + int const rank; + + __device_inline__ TwoshotF16LineCodec(int thread, int rank) + : thread(thread), rank(rank) { + static_assert(kRankTileSize % 16 == 0, + "kRankTileSize must be 16B aligned."); + } + + __device_inline__ void send(int32x4_t* __restrict__ send_buffer, + int32x4_t const* __restrict__ data) { + for (int i = 0; i < kRankAtoms; i++) { + __builtin_nontemporal_store(data[i], send_buffer + thread); + send_buffer += kAtomStride; + } + } + + __device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, + int32x4_t* __restrict__ data) { + for (int i = 0; i < kRankAtoms; i++) { + data[i] = __builtin_nontemporal_load(*recv_buffer + thread); + *recv_buffer += kAtomStride; + } + } +}; + +// MARK: FP8 Line Codec +template +struct TwoshotFP8LineCodec { + /* + FP8 Line codec for Twoshot collectives. + We quantize the FP16 data to block-scaled FP8 in blocks of 32. + */ + + static int constexpr kAtoms = 8; + static int constexpr kAtomStride = 256; + static int constexpr kWorldSize = world_size; + + // Codec tile size process by this workgroup. + // Each threads processes a fragment of fp16x8_t (16B), + // into a fp8x8_t (8B) and a fp16 scale shared among 32 values. + static int constexpr kRankAtoms = kAtoms / kWorldSize; + static int constexpr kRankTileStride = 2176; + static int constexpr kRankTileScaleOffset = 2048; + static int constexpr kRankTileSize = kRankTileStride * kRankAtoms; + + static int constexpr kRankBufferTileStride = + kRankTileStride / sizeof(int32x4_t); + + // Total tile size for the collective communication. 
+ static int constexpr kTileSize = kRankTileSize * kWorldSize; + + // FP8 Maximum value (on AMD Instinct MI300X - float8_e4m3fnuz) + static float constexpr kFP8Max = 240.0f; + static int constexpr kScaleFactor = + Quantfp8Const::kScaleFactor; // {1/240.0h, 1/240.0h} + static int constexpr kScaleEpsilon = + Quantfp8Const::kScaleEpsilon; // {1e-7, 1e-7} + + int const thread; + int const rank; + int const group_leader; + + __device_inline__ TwoshotFP8LineCodec(int thread, int rank) + : thread(thread), rank(rank), group_leader((threadIdx.x / 8) * 8) { + static_assert(kRankTileSize % 16 == 0, + "kRankTileSize must be 16B aligned."); + set_fp16_ovfl(true); + } + + __device_inline__ void send(int32x4_t* __restrict__ send_buffer, + int32x4_t const* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + int32x4_t const atom = data[k]; + + // abs(w) + int32x4_t w; + { + T const* x = reinterpret_cast(&atom); + T* y = reinterpret_cast(&w); + for (int i = 0; i < 8; i++) { + y[i] = __habs(x[i]); + } + } + + // max(w) + int wmax; + { + int a, b; + int* dw = reinterpret_cast(&w); + a = pk_max(dw[0], dw[1]); + b = pk_max(dw[2], dw[3]); + wmax = pk_max(a, b); + + // Reduce the max among a group of 8 threads + // Note: This is basically 2 blocks of 32 values setup as the + // upper/lower halves of the fp16x2_t + for (int i = 1; i < 8; i <<= 1) { + int x = __shfl_down(wmax, i); + wmax = pk_max(wmax, x); + } + + // Share with the cohort + wmax = __shfl(wmax, group_leader); + } + + // Derive scales + int decoding_scale = pk_mul(wmax, kScaleFactor); + int encoding_scale = pk_add(decoding_scale, kScaleEpsilon); + encoding_scale = pk_hcp(encoding_scale); + + // Apply scales to get quantized values + for (int i = 0; i < 4; i++) { + w[i] = pk_mul(atom[i], encoding_scale); + } + + // Convert to packed FP8 + fp32x8_t wf; + { + if constexpr (std::is_same::value) { + half2 const* x = reinterpret_cast(&w); + float2* y = reinterpret_cast(&wf); + for (int i = 0; i < 4; i++) { + y[i] = 
__half22float2(x[i]); + } + } else { + nv_bfloat162 const* x = reinterpret_cast(&w); + float2* y = reinterpret_cast(&wf); + for (int i = 0; i < 4; i++) { + y[i] = __bfloat1622float2(x[i]); + } + } + } + + int32x2_t qw; + qw[0] = __builtin_amdgcn_cvt_pk_fp8_f32(wf[0], wf[1], qw[0], 0); + qw[0] = __builtin_amdgcn_cvt_pk_fp8_f32(wf[2], wf[3], qw[0], 1); + qw[1] = __builtin_amdgcn_cvt_pk_fp8_f32(wf[4], wf[5], qw[1], 0); + qw[1] = __builtin_amdgcn_cvt_pk_fp8_f32(wf[6], wf[7], qw[1], 1); + + // Write quantized atom to send_buffer + // note: only the group leader stores the scale + uint8_t* atom_ptr = + reinterpret_cast(send_buffer + k * kRankBufferTileStride); + int32x2_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + + __builtin_nontemporal_store(qw, qw_ptr); + if (threadIdx.x == group_leader) { + __builtin_nontemporal_store(decoding_scale, qs_ptr); + } + } + } + + __device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, + int32x4_t* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + // Directly read quantized atom from recv_buffer + uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); + int32x2_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + + int32x2_t qw = __builtin_nontemporal_load(qw_ptr); + int qs = __builtin_nontemporal_load(qs_ptr); + + *recv_buffer += kRankBufferTileStride; + + // Unpack FP8 + int32x4_t w; + { + if constexpr (std::is_same::value) { + for (int i = 0; i < 2; i++) { + fp32x2_t wf0 = __builtin_amdgcn_cvt_pk_f32_fp8(qw[i], 0); + fp32x2_t wf1 = __builtin_amdgcn_cvt_pk_f32_fp8(qw[i], 1); + + asm volatile("v_cvt_pkrtz_f16_f32 %0, %1, %2" + : "=v"(w[i * 2 + 0]) + : "v"(wf0[0]), "v"(wf0[1])); + asm volatile("v_cvt_pkrtz_f16_f32 %0, %1, %2" + : "=v"(w[i * 2 + 1]) + : "v"(wf1[0]), "v"(wf1[1])); + } + } else { + nv_bfloat16* wbf = reinterpret_cast(&w); + for (int i = 
0; i < 2; i++) { + fp32x2_t wf0_vec = __builtin_amdgcn_cvt_pk_f32_fp8(qw[i], 0); + fp32x2_t wf1_vec = __builtin_amdgcn_cvt_pk_f32_fp8(qw[i], 1); + wbf[i * 4 + 0] = __float2bfloat16(wf0_vec[0]); + wbf[i * 4 + 1] = __float2bfloat16(wf0_vec[1]); + wbf[i * 4 + 2] = __float2bfloat16(wf1_vec[0]); + wbf[i * 4 + 3] = __float2bfloat16(wf1_vec[1]); + } + } + } + + // Apply decoding scales + for (int i = 0; i < 4; i++) { + w[i] = pk_mul(w[i], qs); + } + + // That's pretty much it... + data[k] = w; + } + } +}; + +// MARK: Q4 Line Codec +template +struct TwoshotQ4LineCodec { + /* + Int4-blocking Line codec for Twoshot collectives. + We quantize the FP16 data to block-scaled Int4 in blocks of 32. + */ + + static int constexpr kAtoms = 8; + static int constexpr kAtomStride = 256; + static int constexpr kWorldSize = world_size; + + // Codec tile size process by this workgroup. + // Each threads processes a fragment of fp16x8_t (16B), + // into a int4x8_t (4B) and a fp16 scale shared among 32 values. + static int constexpr kRankAtoms = kAtoms / kWorldSize; + static int constexpr kRankTileStride = 1152; + static int constexpr kRankTileScaleOffset = 1024; + static int constexpr kRankTileSize = kRankTileStride * kRankAtoms; + + static int constexpr kRankBufferTileStride = + kRankTileStride / sizeof(int32x4_t); + + // Total tile size for the collective communication. 
+ static int constexpr kTileSize = kRankTileSize * kWorldSize; + + // Q4 configuration + static int constexpr kScaleFactor = + Quant4Const::kScaleFactor; // {-1/8.0h, -1/8.0h}, fp16x2_t + static int constexpr kScaleEpsilon = Quant4Const::kScaleEpsilon; + ; // {1e-7, 1e-7}, fp16x2_t + static int constexpr kRangeMin = Quant4Const::kRangeMin; + ; // {-8, -8}, fp16x2_t + static int constexpr kRangeMax = Quant4Const::kRangeMax; + ; // {+7, +7}, fp16x2_t + static int constexpr kRangeBias = 0x00080008; // {+8, +8}, int16x2_t + + int const thread; + int const rank; + int const group_leader; + + __device_inline__ TwoshotQ4LineCodec(int thread, int rank) + : thread(thread), rank(rank), group_leader((threadIdx.x / 8) * 8) { + static_assert(kRankTileSize % 16 == 0, + "kRankTileSize must be 16B aligned."); + set_fp16_ovfl(true); + } + + __device_inline__ void send(int32x4_t* __restrict__ send_buffer, + int32x4_t const* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + int32x4_t const atom = data[k]; + + // max(w), min(w) + int wmax, wmin, wblockmax; + { + int a, b; + a = pk_max(atom[0], atom[1]); + b = pk_max(atom[2], atom[3]); + wmax = pk_max(a, b); + + a = pk_min(atom[0], atom[1]); + b = pk_min(atom[2], atom[3]); + wmin = pk_min(a, b); + + // Reduce the max among a group of 8 threads + // Note: This is basically 2 blocks of 32 values setup as the + // upper/lower halves of the fp16x2_t + for (int i = 1; i < 8; i <<= 1) { + int x = __shfl_down(wmax, i); + wmax = pk_max(wmax, x); + + int y = __shfl_down(wmin, i); + wmin = pk_min(wmin, y); + } + + wblockmax = pk_max_abs(wmax, wmin); + // Share with the cohort + wblockmax = __shfl(wblockmax, group_leader); + } + + // Derive scales + int decoding_scale = pk_mul(wblockmax, kScaleFactor); + int encoding_scale = pk_add(decoding_scale, kScaleEpsilon); + encoding_scale = pk_hcp(encoding_scale); + + // Apply scales to get quantized values + int32x4_t w; + for (int i = 0; i < 4; i++) { + w[i] = pk_mul(atom[i], 
encoding_scale); + w[i] = pk_max(w[i], kRangeMin); + w[i] = pk_min(w[i], kRangeMax); + } + + // Convert from fp16x2_t to uint16x2_t + int32x4_t q; + { + int16_t* qi = reinterpret_cast(&q); + T* wh = reinterpret_cast(&w); + for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(T2float(wh[i])); + + for (int i = 0; i < 4; i++) { + asm volatile("v_pk_add_i16 %0, %1, %2" + : "=v"(q[i]) + : "v"(q[i]), "v"(kRangeBias)); + } + } + + // Pack 8 x q4 into int32_t + int qw = q[0] | (q[1] << 4) | (q[2] << 8) | (q[3] << 12); + + // Write quantized atom to send_buffer + // note: only the group leader stores the scale + uint8_t* atom_ptr = + reinterpret_cast(send_buffer + k * kRankBufferTileStride); + int32_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + + __builtin_nontemporal_store(qw, qw_ptr); + if (threadIdx.x == group_leader) { + __builtin_nontemporal_store(decoding_scale, qs_ptr); + } + } + } + + __device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, + int32x4_t* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + // Directly read quantized atom from recv_buffer + uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); + int32_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + + int32_t qw = __builtin_nontemporal_load(qw_ptr); + int qs = __builtin_nontemporal_load(qs_ptr); + + *recv_buffer += kRankBufferTileStride; + + // Unpack q4 into fp16x8_t + int32x4_t w; + { + static uint constexpr kMask000F = 0x000F000F; + static uint constexpr kHalf2_1024 = + 0x64006400; // {1024.0, 1024.0}, fp16x2_t + static uint constexpr kHalf2_1032 = + 0xE408E408; // {-1032.0, -1032.0}, fp16x2_t + + for (int i = 0; i < 4; i++) { + if constexpr (std::is_same::value) { + int32_t q4 = ((qw >> (i * 4)) & kMask000F) | kHalf2_1024; + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(w[i]) + : "v"(q4), "v"(kHalf2_1032)); 
+ } else { + int32_t int16_2 = (qw >> (i * 4)) & kMask000F; + int16_t low = static_cast(int16_2 & 0xFFFF); + int16_t high = static_cast((int16_2 >> 16) & 0xFFFF); + + nv_bfloat16 bf_low = __float2bfloat16(static_cast(low)); + nv_bfloat16 bf_high = __float2bfloat16(static_cast(high)); + + nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high); + int32_t packed_bf16 = *reinterpret_cast(&bf2); + w[i] = pk_add(packed_bf16, kRangeMin); + } + } + } + + // Apply decoding scales + for (int i = 0; i < 4; i++) { + w[i] = pk_mul(w[i], qs); + } + + // That's pretty much it... + data[k] = w; + } + } +}; + +// MARK: Q6 Line Codec +template +struct TwoshotQ6LineCodec { + /* + Int6-blocking Line codec for Twoshot collectives. + We quantize the FP16 data to block-scaled Int6 in blocks of 32. + */ + + static int constexpr kAtoms = 8; + static int constexpr kAtomStride = 256; + static int constexpr kWorldSize = world_size; + + // Codec tile size processed by this workgroup. + // Each thread processes a fragment of fp16x8_t (16B), + // into a int6x8_t (4B + 2B) and a fp16 scale shared among 32 values. + static int constexpr kRankAtoms = kAtoms / kWorldSize; + static int constexpr kRankTileStride = 1664; + static int constexpr kRankTileQ2Offset = 1024; + static int constexpr kRankTileScaleOffset = 1536; + static int constexpr kRankTileSize = kRankTileStride * kRankAtoms; + + static int constexpr kRankBufferTileStride = + kRankTileStride / sizeof(int32x4_t); + + // Total tile size for the collective communication. 
+ static int constexpr kTileSize = kRankTileSize * kWorldSize; + + // Q6 configuration + static int constexpr kScaleFactor = + Quant6Const::kScaleFactor; // {-1/32.0h, -1/32.0h}, fp16x2_t + static int constexpr kScaleEpsilon = + Quant6Const::kScaleEpsilon; // {1e-7, 1e-7}, fp16x2_t + static int constexpr kRangeMin = + Quant6Const::kRangeMin; // {-32, -32}, fp16x2_t + static int constexpr kRangeMax = + Quant6Const::kRangeMax; // {+31, +31}, fp16x2_t + static int constexpr kRangeBias = 0x00200020; // {+32, +32}, int16x2_t + + int const thread; + int const rank; + int const group_leader; + + __device_inline__ TwoshotQ6LineCodec(int thread, int rank) + : thread(thread), rank(rank), group_leader((threadIdx.x / 8) * 8) { + static_assert(kRankTileSize % 16 == 0, + "kRankTileSize must be 16B aligned."); + set_fp16_ovfl(true); + } + + __device_inline__ void send(int32x4_t* __restrict__ send_buffer, + int32x4_t const* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + int32x4_t const atom = data[k]; + + // max(w), min(w) + int wmax, wmin, wblockmax; + { + int a, b; + a = pk_max(atom[0], atom[1]); + b = pk_max(atom[2], atom[3]); + wmax = pk_max(a, b); + + a = pk_min(atom[0], atom[1]); + b = pk_min(atom[2], atom[3]); + wmin = pk_min(a, b); + + // Reduce the max among a group of 8 threads + // Note: This is basically 2 blocks of 32 values setup as the + // upper/lower halves of the fp16x2_t + for (int i = 1; i < 8; i <<= 1) { + int x = __shfl_down(wmax, i); + wmax = pk_max(wmax, x); + + int y = __shfl_down(wmin, i); + wmin = pk_min(wmin, y); + } + + wblockmax = pk_max_abs(wmax, wmin); + + // Share with the cohort + wblockmax = __shfl(wblockmax, group_leader); + } + + // Derive scales + int decoding_scale = pk_mul(wblockmax, kScaleFactor); + int encoding_scale = pk_add(decoding_scale, kScaleEpsilon); + encoding_scale = pk_hcp(encoding_scale); + + // Apply scales to get quantized values + int32x4_t w; + for (int i = 0; i < 4; i++) { + w[i] = pk_mul(atom[i], 
encoding_scale); + w[i] = pk_max(w[i], kRangeMin); + w[i] = pk_min(w[i], kRangeMax); + } + + // Convert from fp16x2_t to uint16x2_t + int32x4_t q; + { + int16_t* qi = reinterpret_cast(&q); + T* wh = reinterpret_cast(&w); + for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(T2float(wh[i])); + + for (int i = 0; i < 4; i++) { + asm volatile("v_pk_add_i16 %0, %1, %2" + : "=v"(q[i]) + : "v"(q[i]), "v"(kRangeBias)); + } + } + + // Pack 8 x q6 into int32_t + int16_t + uint32_t q4w; + uint16_t q2w = 0; + q4w = (q[0] & 0x000F000F) | ((q[1] & 0x000F000F) << 4) | + ((q[2] & 0x000F000F) << 8) | ((q[3] & 0x000F000F) << 12); + { + int16_t* tw = reinterpret_cast(&q); +#pragma unroll + for (int i = 0; i < 8; i++) { + q2w |= (tw[i] >> 4) << (i * 2); + } + } + + // Write quantized atom to send_buffer + // note: only the group leader stores the scale + uint8_t* atom_ptr = + reinterpret_cast(send_buffer + k * kRankBufferTileStride); + uint32_t* q4w_ptr = reinterpret_cast(atom_ptr) + thread; + uint16_t* q2w_ptr = + reinterpret_cast(atom_ptr + kRankTileQ2Offset) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + + __builtin_nontemporal_store(q4w, q4w_ptr); + __builtin_nontemporal_store(q2w, q2w_ptr); + if (threadIdx.x == group_leader) { + __builtin_nontemporal_store(decoding_scale, qs_ptr); + } + } + } + + __device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, + int32x4_t* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + // Directly read quantized atom from recv_buffer + uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); + uint32_t* q4w_ptr = reinterpret_cast(atom_ptr) + thread; + uint16_t* q2w_ptr = + reinterpret_cast(atom_ptr + kRankTileQ2Offset) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + + uint32_t q4w = __builtin_nontemporal_load(q4w_ptr); + uint16_t q2w = __builtin_nontemporal_load(q2w_ptr); + int qs = __builtin_nontemporal_load(qs_ptr); + + *recv_buffer += 
kRankBufferTileStride; + + // Unpack q6 into fp16x8_t + int32x4_t w; + { + static uint constexpr kMask000F = 0x000F000F; + static uint constexpr kMask00FF = 0x00FF00FF; + static uint constexpr kHalf2_1024 = + 0x64006400; // {1024.0, 1024.0}, fp16x2_t + static uint constexpr kHalf2_1056 = + 0xE420E420; // {-1056.0, -1056.0}, fp16x2_t + +#pragma unroll + for (int i = 0; i < 4; i++) { + int32_t q4 = q4w & kMask000F; + int32_t q2 = (q2w & 0x3) | ((q2w & 0xC) << 14); + q4w >>= 4; + q2w >>= 4; + if constexpr (std::is_same::value) { + int32_t q6 = q4 | (q2 << 4) | kHalf2_1024; + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(w[i]) + : "v"(q6), "v"(kHalf2_1056)); + } else { + int32_t int16_2 = q4 | (q2 << 4); + int16_t low = static_cast(int16_2 & 0xFFFF); + int16_t high = static_cast((int16_2 >> 16) & 0xFFFF); + + nv_bfloat16 bf_low = __float2bfloat16(static_cast(low)); + nv_bfloat16 bf_high = __float2bfloat16(static_cast(high)); + nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high); + int32_t packed_bf16 = *reinterpret_cast(&bf2); + w[i] = pk_add(packed_bf16, kRangeMin); + } + } + } + + // Apply decoding scales + for (int i = 0; i < 4; i++) { + w[i] = pk_mul(w[i], qs); + } + + // That's pretty much it... + data[k] = w; + } + } +}; + +// MARK: Q8 Line Codec +template +struct TwoshotQ8LineCodec { + /* + Int8-blocking Line codec for Twoshot collectives. + We quantize the FP16 data to block-scaled Int8 in blocks of 32. + */ + + static int constexpr kAtoms = 8; + static int constexpr kAtomStride = 256; + static int constexpr kWorldSize = world_size; + + // Codec tile size process by this workgroup. + // Each threads processes a fragment of fp16x8_t (16B), + // into a int8x8_t (8B) and a fp16 scale shared among 32 values. 
+ static int constexpr kRankAtoms = kAtoms / kWorldSize; + static int constexpr kRankTileStride = 2176; + static int constexpr kRankTileScaleOffset = 2048; + static int constexpr kRankTileSize = kRankTileStride * kRankAtoms; + + static int constexpr kRankBufferTileStride = + kRankTileStride / sizeof(int32x4_t); + + // Total tile size for the collective communication. + static int constexpr kTileSize = kRankTileSize * kWorldSize; + + // Q8 configuration + static int constexpr kScaleFactor = + Quant8Const::kScaleFactor; // {-1/128.0h, -1/128.0h}, fp16x2_t + static int constexpr kScaleEpsilon = + Quant8Const::kScaleEpsilon; // {1e-7, 1e-7}, fp16x2_t + static int constexpr kRangeMin = + Quant8Const::kRangeMin; // {-128, -128}, fp16x2_t + static int constexpr kRangeMax = + Quant8Const::kRangeMax; // {+127, +127}, fp16x2_t + static constexpr int kRangeBias = 0x00800080; + // {+128, +128}, int16x2_t + + int const thread; + int const rank; + int const group_leader; + + __device_inline__ TwoshotQ8LineCodec(int thread, int rank) + : thread(thread), rank(rank), group_leader((threadIdx.x / 8) * 8) { + static_assert(kRankTileSize % 16 == 0, + "kRankTileSize must be 16B aligned."); + set_fp16_ovfl(true); + } + + __device_inline__ void send(int32x4_t* __restrict__ send_buffer, + int32x4_t const* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + int32x4_t const atom = data[k]; + + // max(w), min(w) + int wmax, wmin, wblockmax; + { + int a, b; + a = pk_max(atom[0], atom[1]); + b = pk_max(atom[2], atom[3]); + wmax = pk_max(a, b); + + a = pk_min(atom[0], atom[1]); + b = pk_min(atom[2], atom[3]); + wmin = pk_min(a, b); + + // Reduce the max among a group of 8 threads + // Note: This is basically 2 blocks of 32 values setup as the + // upper/lower halves of the fp16x2_t + for (int i = 1; i < 8; i <<= 1) { + int x = __shfl_down(wmax, i); + wmax = pk_max(wmax, x); + + int y = __shfl_down(wmin, i); + wmin = pk_min(wmin, y); + } + + wblockmax = pk_max_abs(wmax, wmin); + // 
Share with the cohort + wblockmax = __shfl(wblockmax, group_leader); + } + + // Derive scales + int decoding_scale = pk_mul(wblockmax, kScaleFactor); + int encoding_scale = pk_add(decoding_scale, kScaleEpsilon); + encoding_scale = pk_hcp(encoding_scale); + + // Apply scales to get quantized values + int32x4_t w; + for (int i = 0; i < 4; i++) { + w[i] = pk_mul(atom[i], encoding_scale); + w[i] = pk_max(w[i], kRangeMin); + w[i] = pk_min(w[i], kRangeMax); + } + + // Convert from fp16x2_t to uint16x2_t + int32x4_t q; + { + int16_t* qi = reinterpret_cast(&q); + T* wh = reinterpret_cast(&w); + for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(T2float(wh[i])); + + for (int i = 0; i < 4; i++) { + asm volatile("v_pk_add_i16 %0, %1, %2" + : "=v"(q[i]) + : "v"(q[i]), "v"(kRangeBias)); // shared + } + } + + // Pack 8 x q8 into int32x2_t + int32x2_t qw; + qw[0] = q[0] | (q[1] << 8); + qw[1] = q[2] | (q[3] << 8); + + // Write quantized atom to send_buffer + // note: only the group leader stores the scale + uint8_t* atom_ptr = + reinterpret_cast(send_buffer + k * kRankBufferTileStride); + int32x2_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + + __builtin_nontemporal_store(qw, qw_ptr); + if (threadIdx.x == group_leader) { + __builtin_nontemporal_store(decoding_scale, qs_ptr); + } + } + } + + __device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, + int32x4_t* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + // Directly read quantized atom from recv_buffer + uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); + int32x2_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + + int32x2_t qw = __builtin_nontemporal_load(qw_ptr); + int qs = __builtin_nontemporal_load(qs_ptr); + + *recv_buffer += kRankBufferTileStride; + + // Unpack q8 into fp16x8_t + int32x4_t w; + { + static uint constexpr 
kMask00FF = 0x00FF00FF; + static uint constexpr kHalf2_1024 = + 0x64006400; // {1024.0, 1024.0}, fp16x2_t + static uint constexpr kHalf2_1152 = + 0xE480E480; // {-1152.0, -1152.0}, fp16x2_t + +#pragma unroll + for (int i = 0; i < 4; i++) { + if constexpr (std::is_same::value) { + int32_t q8 = + ((qw[i / 2] >> ((i % 2) * 8)) & kMask00FF) | kHalf2_1024; + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(w[i]) + : "v"(q8), "v"(kHalf2_1152)); + + } else { + int32_t int16_2 = (qw[i / 2] >> ((i % 2) * 8)) & kMask00FF; + int16_t low = static_cast(int16_2 & 0xFFFF); + int16_t high = static_cast((int16_2 >> 16) & 0xFFFF); + + nv_bfloat16 bf_low = __float2bfloat16(static_cast(low)); + nv_bfloat16 bf_high = __float2bfloat16(static_cast(high)); + + nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high); + int32_t packed_bf16 = *reinterpret_cast(&bf2); + w[i] = pk_add(packed_bf16, kRangeMin); + } + } + } + + // Apply decoding scales + for (int i = 0; i < 4; i++) { + w[i] = pk_mul(w[i], qs); + } + + // That's pretty much it... + data[k] = w; + } + } +}; + +// MARK: Twoshot All Reduce +template +struct AllReduceTwoshot { + // Fixed magic implementation. + // We will use a workgroup of 256 threads (standard kBlock) across 8 atoms of + // work. + static int constexpr kAtoms = 8; + + // Size and atom stride of source/destination data that the workgroup will + // process. 
+ static int constexpr kTileSize = 256 * kAtoms * sizeof(int32x4_t); + static int constexpr kAtomStride = 256; + + static int constexpr kWorldSize = LineCodec::kWorldSize; + + __device__ static void run( + T const* __restrict__ A, // input + T* __restrict__ B, // output + int const N, // number of elements + int const block, // block index + int const num_blocks, // number of blocks + int const world_size, // unused - only kept around for API consistency + int const rank, // rank index + uint8_t** __restrict__ buffer_list, // communication buffers + long const data_offset, // offset to start of the data buffer + int flag_color) { + // Topology + int thread = threadIdx.x + threadIdx.y * kWavefront; + uint8_t* rank_buffer = buffer_list[rank]; + + LineCodec codec(thread, rank); + + // -------------------------------------------------------- + // Read A into registers + int32x4_t tA[kAtoms]; + + BufferResource src_buffer(const_cast(A), N * sizeof(T)); + int src_offset = block * kTileSize + thread * sizeof(int32x4_t); + + for (int i = 0; i < kAtoms; i++) { + tA[i] = buffer_load_dwordx4(src_buffer.descriptor, src_offset, 0, 0); + src_offset += kAtomStride * sizeof(int32x4_t); + } + + // -------------------------------------------------------- + // Phase-1A: Write segment data into the communication buffer of the target + // rank responsible for this segment. 
+ long comm_data0_offset = data_offset + block * LineCodec::kTileSize; + long comm_data1_offset = + num_blocks * LineCodec::kTileSize + comm_data0_offset; + + long comm_flags0_offset = block * (kWorldSize * sizeof(int)); + long comm_flags1_offset = + num_blocks * (kWorldSize * sizeof(int)) + comm_flags0_offset; + + for (int r = 0; r < kWorldSize; r++) { + int32x4_t* send_buffer = reinterpret_cast( + buffer_list[r] + comm_data0_offset + rank * LineCodec::kRankTileSize); + codec.send(send_buffer, &tA[r * LineCodec::kRankAtoms]); + } + + __syncthreads(); + if (thread < kWorldSize) { + int r = thread; + int* flag_ptr = reinterpret_cast( + buffer_list[r] + comm_flags0_offset + rank * sizeof(int)); + __atomic_store_n(flag_ptr, flag_color, __ATOMIC_RELEASE); + } + + // -------------------------------------------------------- + // Phase-1B: Reduce the segment data from the communication buffers. + int32x4_t tR[LineCodec::kRankAtoms] = {}; + { + // Read the data from the communication buffer. + int32x4_t* recv_buffer = + reinterpret_cast(rank_buffer + comm_data0_offset); + int* flag_ptr = reinterpret_cast(rank_buffer + comm_flags0_offset); + + for (int r = 0; r < kWorldSize; r++) { + // Wait for the flags to be set. + if (thread == 0) { + while (__atomic_load_n(&flag_ptr[r], __ATOMIC_RELAXED) != + flag_color) { + } + } + __syncthreads(); + + // note: we reuse tA as temp buffer here + codec.recv(&recv_buffer, tA); + + for (int i = 0; i < LineCodec::kRankAtoms; i++) { + int32x4_t& tA_fragment = tA[i]; + int32x4_t& tR_fragment = tR[i]; + tR_fragment[0] = pk_add(tR_fragment[0], tA_fragment[0]); + tR_fragment[1] = pk_add(tR_fragment[1], tA_fragment[1]); + tR_fragment[2] = pk_add(tR_fragment[2], tA_fragment[2]); + tR_fragment[3] = pk_add(tR_fragment[3], tA_fragment[3]); + } + } + } + + // -------------------------------------------------------- + // Phase-2: Write the reduced segment to every other rank + // This is basically an all-gather. 
+ for (int r = 0; r < kWorldSize; r++) { + int32x4_t* send_buffer = reinterpret_cast( + buffer_list[r] + comm_data1_offset + rank * LineCodec::kRankTileSize); + codec.send(send_buffer, tR); + } + + __syncthreads(); + if (thread < kWorldSize) { + int r = thread; + int* flag_ptr = reinterpret_cast( + buffer_list[r] + comm_flags1_offset + rank * sizeof(int)); + __atomic_store_n(flag_ptr, flag_color, __ATOMIC_RELEASE); + } + + // -------------------------------------------------------- + // Phase-2: Read the gather segments from the rank's communication buffer. + { + // Read the data from the communication buffer. + int32x4_t* recv_buffer = + reinterpret_cast(rank_buffer + comm_data1_offset); + int* flag_ptr = reinterpret_cast(rank_buffer + comm_flags1_offset); + + for (int r = 0; r < kWorldSize; r++) { + // Wait for the flags to be set. + if (thread == 0) { + while (__atomic_load_n(&flag_ptr[r], __ATOMIC_RELAXED) != + flag_color) { + } + } + __syncthreads(); + + // Gather all reduced and final rank segments into tA. + codec.recv(&recv_buffer, &tA[r * LineCodec::kRankAtoms]); + } + } + + // -------------------------------------------------------- + // Write the result to B. + BufferResource dst_buffer(B, N * sizeof(T)); + int dst_offset = block * kTileSize + thread * sizeof(int32x4_t); + + for (int i = 0; i < kAtoms; i++) { + buffer_store_dwordx4(tA[i], dst_buffer.descriptor, dst_offset, 0, 0); + dst_offset += kAtomStride * sizeof(int32x4_t); + } + } +}; + +// ============================================================ +// Oneshot +// ============================================================ +// MARK: Oneshot All Reduce +template +struct AllReduceOneshot { + // Fixed magic implementation. + // We will use a workgroup of 256 threads (standard kBlock) across 8 atoms of + // work. + static int constexpr kAtoms = 8; + + // Size and atom stride of data that the workgroup will process. 
+ static int constexpr kTileSize = 256 * kAtoms * sizeof(int32x4_t); + static int constexpr kAtomStride = 256; + + __device__ static void run( + T const* __restrict__ A, // input + T* __restrict__ B, // output + int const N, // number of elements + int const block, // this block's index + int const num_blocks, // total number of blocks + int const world_size, // total number of ranks + int const rank, // this rank's index + uint8_t** __restrict__ buffer_list, // communication buffers + long const data_offset, // offset to start of the data buffer + int flag_color // Flag color for the network barrier + ) { + // Topology + int thread = threadIdx.x + threadIdx.y * kWavefront; + + long data_stride = num_blocks * kTileSize; + long flags_stride = num_blocks * sizeof(int); + + uint8_t* rank_buffer = buffer_list[rank]; + + // -------------------------------------------------------- + // Read A into registers + int32x4_t tA[kAtoms]; + + BufferResource src_buffer(const_cast(A), N * sizeof(T)); + int src_offset = block * kTileSize + thread * sizeof(int32x4_t); + + for (int i = 0; i < kAtoms; i++) { + tA[i] = buffer_load_dwordx4(src_buffer.descriptor, src_offset, 0, 0); + src_offset += kAtomStride * sizeof(int32x4_t); + } + + // -------------------------------------------------------- + // Write rank data into this rank segment of every rank's communication + // buffer. 
+ long comm_data_offset = + data_offset + rank * data_stride + block * kTileSize; + long comm_flags_offset = rank * flags_stride + block * sizeof(int); + + if (thread < world_size) { + int r = thread; + int* flag_ptr = + reinterpret_cast(buffer_list[r] + comm_flags_offset); + while (__atomic_load_n(flag_ptr, __ATOMIC_RELAXED) != flag_color - 1) { + } + } + __syncthreads(); + + for (int r = 0; r < world_size; r++) { + int32x4_t* send_buffer = + reinterpret_cast(buffer_list[r] + comm_data_offset); + for (int i = 0; i < kAtoms; i++) { + __builtin_nontemporal_store(tA[i], send_buffer + thread); + send_buffer += kAtomStride; + } + } + + // Inform the other ranks that the data has been posted. + __syncthreads(); + if (thread < world_size) { + int r = thread; + int* flag_ptr = + reinterpret_cast(buffer_list[r] + comm_flags_offset); + __atomic_store_n(flag_ptr, flag_color, __ATOMIC_RELEASE); + } + + // -------------------------------------------------------- + // Read and reduce the data from this rank's communication buffer. + int32x4_t tB[kAtoms]; + + { + int r = 0; + + // Wait for the flags to be set. + int* flag_ptr = reinterpret_cast(rank_buffer + r * flags_stride + + block * sizeof(int)); + if (thread == 0) { + while (__atomic_load_n(flag_ptr, __ATOMIC_RELAXED) != flag_color) { + } + } + __syncthreads(); + + // Read posted data from the rank's communication buffer. + int32x4_t* recv_buffer = reinterpret_cast( + rank_buffer + data_offset + r * data_stride + block * kTileSize); + + for (int i = 0; i < kAtoms; i++) { + tB[i] = __builtin_nontemporal_load(recv_buffer + thread); + recv_buffer += kAtomStride; + } + } + + for (int r = 1; r < world_size; r++) { + // Wait for the flags to be set. + int* flag_ptr = reinterpret_cast(rank_buffer + r * flags_stride + + block * sizeof(int)); + if (thread == 0) { + while (__atomic_load_n(flag_ptr, __ATOMIC_RELAXED) != flag_color) { + } + } + __syncthreads(); + + // Read posted data from the rank's communication buffer.
+ int32x4_t* recv_buffer = reinterpret_cast( + rank_buffer + data_offset + r * data_stride + block * kTileSize); + + for (int i = 0; i < kAtoms; i++) { + tA[i] = __builtin_nontemporal_load(recv_buffer + thread); + recv_buffer += kAtomStride; + } + + // Reduce. + for (int i = 0; i < kAtoms; i++) { + int32x4_t& tA_fragment = tA[i]; + int32x4_t& tB_fragment = tB[i]; + tB_fragment[0] = pk_add(tB_fragment[0], tA_fragment[0]); + tB_fragment[1] = pk_add(tB_fragment[1], tA_fragment[1]); + tB_fragment[2] = pk_add(tB_fragment[2], tA_fragment[2]); + tB_fragment[3] = pk_add(tB_fragment[3], tA_fragment[3]); + } + } + + __syncthreads(); + if (thread < world_size) { + int r = thread; + int* flag_ptr = reinterpret_cast(rank_buffer + r * flags_stride + + block * sizeof(int)); + __atomic_store_n(flag_ptr, flag_color, __ATOMIC_RELAXED); + } + + // -------------------------------------------------------- + // Write the result to B. + BufferResource dst_buffer(B, N * sizeof(T)); + int dst_offset = block * kTileSize + thread * sizeof(int32x4_t); + + for (int i = 0; i < kAtoms; i++) { + buffer_store_dwordx4(tB[i], dst_buffer.descriptor, dst_offset, 0, 0); + dst_offset += kAtomStride * sizeof(int32x4_t); + } + } +}; + +} // namespace quickreduce \ No newline at end of file diff --git a/csrc/quick_all_reduce.h b/csrc/quick_all_reduce.h new file mode 100644 index 000000000000..a55b55f511fa --- /dev/null +++ b/csrc/quick_all_reduce.h @@ -0,0 +1,426 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include + +typedef __hip_bfloat16 nv_bfloat16; +typedef __hip_bfloat162 nv_bfloat162; + +#define __device_inline__ __device__ __forceinline__ +#define __quickreduce_launch_bounds__ __launch_bounds__(256, 4) + +// Setup acquire-release semantics for vector memory reads (mubuf instruction) +// as per architecture. 
+#if defined(__gfx942__) + // CDNA3: Scope bits sc0, sc1 + #define MUBUF_ACQUIRE 16 + #define MUBUF_RELEASE 16 +#elif (defined(__gfx908__) || defined(__gfx90a__)) + // CDNA1 and CDNA2 - glc bit + #define MUBUF_ACQUIRE 1 + #define MUBUF_RELEASE 0 +#endif + +namespace quickreduce { + +// Vector types +using int8x8_t = __attribute__((__vector_size__(8 * sizeof(int8_t)))) int8_t; + +using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; +using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; +using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int; +using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int; + +using fp8_t = uint8_t; +using fp8x8_t = __attribute__((__vector_size__(8 * sizeof(uint8_t)))) uint8_t; + +using fp16x4_t = __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16; +using fp16x8_t = __attribute__((__vector_size__(8 * sizeof(__fp16)))) __fp16; +using fp16x16_t = __attribute__((__vector_size__(16 * sizeof(__fp16)))) __fp16; + +using fp32x2_t = __attribute__((__vector_size__(2 * sizeof(float)))) float; +using fp32x4_t = __attribute__((__vector_size__(4 * sizeof(float)))) float; +using fp32x8_t = __attribute__((__vector_size__(8 * sizeof(float)))) float; +using fp32x16_t = __attribute__((__vector_size__(16 * sizeof(float)))) float; + +// Standard CDNA wavefront size. +static int constexpr kWavefront = 64; + +// 256 thread, 4 wavefronts. +static dim3 constexpr kBlock = {64, 4, 1}; + +// Methods +__device_inline__ __host__ int divceil(int x, int y) { + return ((x + y - 1) / y); +} + +__device_inline__ __host__ constexpr int divceil_constexpr(int const x, + int const y) { + return ((x + y - 1) / y); +} + +/* +=============================================================== +Desc: + Utility container to describe the Buffer Resource used in VMEM operations. + +Operation: + BufferResource can be initialized to tensor base address and range/size (in +bytes). 
The range is used for OOB checks. For example the range for a MxK +dtype=fp16 tensor would have a range of [M * K * sizeof(half)]. + + The last dword of the buffer resource description is to a default config +with DFMT=32b. + + Instructions that used the buffer resource (buffer_load/store_dwordx4) wait +on the `vmcnt`. +*/ + +union BufferResource { + __device_inline__ constexpr BufferResource() : config(0x00020000U) {} + + __device_inline__ constexpr BufferResource(void* buffer_address, + uint32_t buffer_size) + : address(buffer_address), range(buffer_size), config(0x00020000U) {} + + int32x4_t descriptor; + struct { + void* address; // 8B, out of which first 48b is address, and 16b is stride + // (unused) + uint32_t range; // Byte range for the buffer resource + uint32_t config; // Constant, DFMT=32b + }; +}; + +__device_inline__ static int32x4_t buffer_load_dwordx4( + int32x4_t srsrc, int32_t voffset, int32_t soffset, + int32_t aux) __asm("llvm.amdgcn.raw.buffer.load.v4i32"); + +__device_inline__ static void buffer_store_dwordx4( + int32x4_t data, int32x4_t srsrc, int32_t voffset, int32_t soffset, + int32_t aux) __asm("llvm.amdgcn.raw.buffer.store.v4i32"); + +__device_inline__ static void set_fp16_ovfl(bool const value) { + // short size = 0b00001; // Specifies the bit size to modify + // const short offset = 0b10111; // Corrected offset to 23, which is the bit + // position of FP16_OVFL const short hwRegId = 0b000001; // HW register ID for + // MODE const short simm16 = (size << 11) | (offset << 6) | hwRegId; simm16 = + // 0xdc1 + +#if defined(__gfx942__) + if (value) { + asm volatile("s_setreg_imm32_b32 0xdc1, 1;" ::); + } else { + asm volatile("s_setreg_imm32_b32 0xdc1, 0;" ::); + } +#endif +} + +#define HIP_CHECK(err) \ + do { \ + hipError_t err_ = (err); \ + if (err_ != hipSuccess) { \ + std::printf("HIP error %d at %s:%d\n", err_, __FILE__, __LINE__); \ + throw std::runtime_error("HIP error"); \ + } \ + } while (0) + +enum struct QuickReduceProfile { + 
ONESHOT_F16 = 0, + TWOSHOT_F16 = 1, + TWOSHOT_FP8 = 2, + TWOSHOT_Q8 = 3, + TWOSHOT_Q6 = 4, + TWOSHOT_Q4 = 5, +}; + +/* +=============================================================== +Desc: + Device Comms Handle +*/ +struct DeviceComms { + // Workgroup scope = Tile = (256 threads x 16B x 8 atoms) + static long constexpr kTileSize = 256 * 16 * 8; + + // Max problem size is 512MB (in bytes) + static long constexpr kMaxProblemSize = 536870912; + static long constexpr kMaxTiles = kMaxProblemSize / kTileSize; + + // Max TP-8 + static int constexpr kMaxWorldSize = 8; + + bool initialized = false; + int flag_color = 1; + int world_size; + int rank; + + uint8_t* dbuffer; + uint8_t** dbuffer_list; + hipIpcMemHandle_t buffer_ipc_handle; + std::vector all_buffer_ipc_handles; + std::vector buffer_list; + + long data_offset; + + DeviceComms() : initialized(false), world_size(1), rank(0) {} + ~DeviceComms() { destroy(); } + + void init(int world_size, int rank); + int get_world_size() { return world_size; } + int get_rank() { return rank; } + bool status() { return initialized; } + void destroy(); + + hipIpcMemHandle_t const get_handle() { return buffer_ipc_handle; } + void open_ipc_handles(std::vector const& ipc_handles); + template + void allreduce(int profile, hipStream_t stream, T const* A, T* B, int N); + torch::Tensor qr_get_comm_handle(); +}; + +// Function Template for two dtypes +union bf162_int_union { + int i; + nv_bfloat162 bf2; +}; + +// packed add +template +__device_inline__ int pk_add(int a, int b); + +template <> +__device_inline__ int pk_add(int a, int b) { + int res; + asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(res) : "v"(a), "v"(b)); + return res; +} + +template <> +__device_inline__ int pk_add(int a, int b) { + bf162_int_union A, B, R; + A.i = a; + B.i = b; + R.bf2 = __hadd2(A.bf2, B.bf2); + return R.i; +} + +// packed max +template +__device_inline__ int pk_max(int a, int b); + +template <> +__device_inline__ int pk_max(int a, int b) { + int res; + asm 
volatile("v_pk_max_f16 %0, %1, %2" : "=v"(res) : "v"(a), "v"(b)); + return res; +} + +template <> +__device_inline__ int pk_max(int a, int b) { + bf162_int_union A, B, R; + A.i = a; + B.i = b; + R.bf2 = __hmax2(A.bf2, B.bf2); + return R.i; +} + +// packed min +template +__device_inline__ int pk_min(int a, int b); + +template <> +__device_inline__ int pk_min(int a, int b) { + int res; + asm volatile("v_pk_min_f16 %0, %1, %2" : "=v"(res) : "v"(a), "v"(b)); + return res; +} + +template <> +__device_inline__ int pk_min(int a, int b) { + bf162_int_union A, B, R; + A.i = a; + B.i = b; + R.bf2 = __hmin2(A.bf2, B.bf2); + return R.i; +} + +// pk_max_abs +template +__device_inline__ int pk_max_abs(int a, int b); + +template <> +__device_inline__ int pk_max_abs(int a, int b) { + half2 wmaxh2 = __builtin_bit_cast(half2, a); + half2 wminh2 = __builtin_bit_cast(half2, b); + half2 wblockmaxh2; + + wblockmaxh2.x = + __half2float(__habs(wmaxh2.x)) > __half2float(__habs(wminh2.x)) + ? wmaxh2.x + : wminh2.x; + wblockmaxh2.y = + __half2float(__habs(wmaxh2.y)) > __half2float(__habs(wminh2.y)) + ? wmaxh2.y + : wminh2.y; + return __builtin_bit_cast(int, wblockmaxh2); +} + +template <> +__device_inline__ int pk_max_abs(int a, int b) { + bf162_int_union A, B, R; + A.i = a; + B.i = b; + R.bf2.x = + __bfloat162float(__habs(A.bf2.x)) > __bfloat162float(__habs(B.bf2.x)) + ? A.bf2.x + : B.bf2.x; + R.bf2.y = + __bfloat162float(__habs(A.bf2.y)) > __bfloat162float(__habs(B.bf2.y)) + ? 
A.bf2.y + : B.bf2.y; + return R.i; +} + +// pk_mul +template +__device_inline__ int pk_mul(int a, int b); + +template <> +__device_inline__ int pk_mul(int a, int b) { + int res; + asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(res) : "v"(a), "v"(b)); + return res; +} + +template <> +__device_inline__ int pk_mul(int a, int b) { + nv_bfloat162* tA = reinterpret_cast(&a); + nv_bfloat162* tB = reinterpret_cast(&b); + nv_bfloat162 tR = __hmul2(*tA, *tB); + return *(reinterpret_cast(&tR)); +} + +// pk_hcp +template +__device_inline__ int pk_hcp(int a); + +template <> +__device_inline__ int pk_hcp(int a) { + return __builtin_bit_cast(int, h2rcp(__builtin_bit_cast(half2, a))); +} + +template <> +__device_inline__ int pk_hcp(int a) { + bf162_int_union A, R; + A.i = a; + R.bf2 = h2rcp(A.bf2); + return R.i; +} + +// changes dtype +__device_inline__ float T2float(half a) { return __half2float(a); } + +__device_inline__ float T2float(nv_bfloat16 a) { return __bfloat162float(a); } + +// const Q8 +template +struct Quant8Const; + +template <> +struct Quant8Const { + static constexpr int kScaleFactor = + 0xA000A000; // {-1/128.0h, -1/128.0h}, fp16x2_t + static constexpr int kScaleEpsilon = 0x00010001; // {1e-7, 1e-7}, fp16x2_t + static constexpr int kRangeMin = 0xD800D800; // {-128, -128}, fp16x2_t + static constexpr int kRangeMax = 0x57F057F0; // {+127, +127}, fp16x2_t +}; + +template <> +struct Quant8Const { + static constexpr int kScaleFactor = + 0xBC00BC00; // {-1/128.0h, -1/128.0h}, fp16x2_t + static constexpr int kScaleEpsilon = 0x33D733D7; // {1e-7, 1e-7}, fp16x2_t + static constexpr int kRangeMin = 0xC300C300; // {-128, -128}, fp16x2_t + static constexpr int kRangeMax = 0x42FE42FE; // {+127, +127}, fp16x2_t +}; +// const Q6 +template +struct Quant6Const; + +template <> +struct Quant6Const { + static int constexpr kScaleFactor = + 0xA800A800; // {-1/32.0h, -1/32.0h}, fp16x2_t + static int constexpr kScaleEpsilon = 0x00010001; // {1e-7, 1e-7}, fp16x2_t + static int constexpr 
kRangeMin = 0xD000D000; // {-32, -32}, fp16x2_t + static int constexpr kRangeMax = 0x4FC04FC0; // {+31, +31}, fp16x2_t +}; + +template <> +struct Quant6Const { + static int constexpr kScaleFactor = + 0xBD00BD00; // {-1/32.0h, -1/32.0h}, fp16x2_t + static int constexpr kScaleEpsilon = 0x33D733D7; // {1e-7, 1e-7}, fp16x2_t + static int constexpr kRangeMin = 0xC200C200; // {-32, -32}, fp16x2_t + static int constexpr kRangeMax = 0x41F841F8; // {+31, +31}, fp16x2_t +}; + +// const Q4 +template +struct Quant4Const; + +template <> +struct Quant4Const { + static int constexpr kScaleFactor = + 0xB000B000; // {-1/8.0h, -1/8.0h}, fp16x2_t + static int constexpr kScaleEpsilon = 0x00010001; // {1e-7, 1e-7}, fp16x2_t + static int constexpr kRangeMin = 0xC800C800; // {-8, -8}, fp16x2_t + static int constexpr kRangeMax = 0x47004700; // {+7, +7}, fp16x2_t +}; + +template <> +struct Quant4Const { + static int constexpr kScaleFactor = + 0xBE00BE00; // {-1/8.0h, -1/8.0h}, fp16x2_t + static int constexpr kScaleEpsilon = 0x33D733D7; // {1e-7, 1e-7}, fp16x2_t + static int constexpr kRangeMin = 0xC100C100; // {-8, -8}, fp16x2_t + static int constexpr kRangeMax = 0x40E040E0; // {+7, +7}, fp16x2_t +}; + +// const fp8 +template +struct Quantfp8Const; + +template <> +struct Quantfp8Const { + static int constexpr kScaleFactor = 0x1C441C44; // {1/240.0h, 1/240.0h} + static int constexpr kScaleEpsilon = 0x00010001; // {1e-7, 1e-7} +}; + +template <> +struct Quantfp8Const { + static int constexpr kScaleFactor = 0x3B883B88; // {1/240.0h, 1/240.0h} bf + static int constexpr kScaleEpsilon = 0x33D733D7; // {1e-7, 1e-7} bf +}; +} // namespace quickreduce + +// /* +// =============================================================== +// API +// */ +using fptr_t = int64_t; +static_assert(sizeof(void*) == sizeof(fptr_t)); +fptr_t init_quick_ar(int64_t world_size, int64_t rank); +torch::Tensor qr_get_comm_handle(fptr_t _fa); +void qr_set_comm_handles(fptr_t _fa, + std::vector const& comm_handles); +void 
qr_all_reduce(fptr_t _fa, int64_t profile, torch::Tensor const& inp, + torch::Tensor& out); +void qr_destroy(fptr_t _fa); +void is_quickreduce_available(); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 4eda1aaccc6b..cb97ed3612c7 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -703,4 +703,20 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { custom_ar.def("free_shared_buffer", &free_shared_buffer); } +TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _quick_ar), quick_ar) { + // quick all reduce kernels + quick_ar.def("init_quick_ar(int world_size, int rank) -> int"); + quick_ar.def("qr_get_comm_handle(int _fa) -> Tensor"); + quick_ar.def("qr_set_comm_handles(int _fa, Tensor[] handles) -> ()"); + quick_ar.def( + "qr_all_reduce(int _fa, int profile, Tensor inp, Tensor out) -> ()"); + quick_ar.def("qr_destroy(int _fa) -> ()"); + quick_ar.def("is_quickreduce_available() -> ()"); + quick_ar.impl("init_quick_ar", &init_quick_ar); + quick_ar.impl("qr_get_comm_handle", &qr_get_comm_handle); + quick_ar.impl("qr_set_comm_handles", &qr_set_comm_handles); + quick_ar.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce); + quick_ar.impl("qr_destroy", &qr_destroy); + quick_ar.impl("is_quickreduce_available", &is_quickreduce_available); +} REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/docs/usage/usage_stats.md b/docs/usage/usage_stats.md index 750cba7ed9ce..eaf586efe0ac 100644 --- a/docs/usage/usage_stats.md +++ b/docs/usage/usage_stats.md @@ -37,7 +37,8 @@ Here is an example as of v0.4.0: "enable_lora": false, "enable_prefix_caching": false, "enforce_eager": false, - "disable_custom_all_reduce": true + "disable_custom_all_reduce": true, + "disable_quick_all_reduce": true } ``` diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 397517b8665b..2dcc99143d11 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -136,6 +136,7 @@ def
run_model(compile_config: Union[int, CompilationConfig], model: str, enforce_eager=True, tensor_parallel_size=1, disable_custom_all_reduce=True, + disable_quick_all_reduce=True, compilation_config=compile_config, **model_kwargs, ) diff --git a/tests/distributed/test_quick_all_reduce.py b/tests/distributed/test_quick_all_reduce.py new file mode 100644 index 000000000000..9e92a8f2bbd6 --- /dev/null +++ b/tests/distributed/test_quick_all_reduce.py @@ -0,0 +1,135 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +import random + +import pytest +import ray +import torch +import torch.distributed as dist + +from vllm.distributed.communication_op import ( # noqa + tensor_model_parallel_all_reduce) +from vllm.distributed.parallel_state import (get_tensor_model_parallel_group, + get_tp_group, graph_capture) + +from ..utils import init_test_distributed_environment, multi_process_parallel + +random.seed(42) +test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)] +for i, v in enumerate(test_sizes): + test_sizes[i] -= v % 8 + + +# Only enable QuickReduce +@ray.remote(num_gpus=1, max_calls=1) +def graph_allreduce( + monkeypatch: pytest.MonkeyPatch, + tp_size, + pp_size, + rank, + distributed_init_port, +): + with monkeypatch.context() as m: + m.delenv("CUDA_VISIBLE_DEVICES", raising=False) + os.environ["VLLM_QUICK_ALLREDUCE"] = "1" + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + init_test_distributed_environment(tp_size, pp_size, rank, + distributed_init_port) + + group = get_tensor_model_parallel_group().device_group + + # A small all_reduce for warmup. + # this is needed because device communicators might be created lazily + # (e.g. NCCL). This will ensure that the communicator is initialized + # before any communication happens, so that this group can be used for + # graph capture immediately. 
+ data = torch.zeros(1) + data = data.to(device=device) + torch.distributed.all_reduce(data, group=group) + torch.cuda.synchronize() + del data + + # we use the first group to communicate once + # and the second group to communicate twice + # and so on + # this is used to demonstrate that each group can + # communicate independently + num_communication = rank // tp_size + 1 + + for sz in test_sizes: + for dtype in [torch.float16, torch.bfloat16]: + with graph_capture(device=device) as graph_capture_context: + # use integers so result matches NCCL exactly + inp1 = torch.randint(1, + 16, (sz, ), + dtype=dtype, + device=torch.cuda.current_device()) + inp2 = torch.randint(1, + 16, (sz, ), + dtype=dtype, + device=torch.cuda.current_device()) + torch.cuda.synchronize() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, + stream=graph_capture_context.stream): + for i in range(num_communication): + out1 = tensor_model_parallel_all_reduce(inp1) + # the input buffer is immediately modified to test + # synchronization + dist.all_reduce(inp1, group=group) + out2 = tensor_model_parallel_all_reduce(inp2) + dist.all_reduce(inp2, group=group) + graph.replay() + torch.testing.assert_close(out1, inp1) + torch.testing.assert_close(out2, inp2) + + +@ray.remote(num_gpus=1, max_calls=1) +def eager_quick_allreduce( + monkeypatch: pytest.MonkeyPatch, + tp_size, + pp_size, + rank, + distributed_init_port, +): + with monkeypatch.context() as m: + m.delenv("CUDA_VISIBLE_DEVICES", raising=False) + os.environ["VLLM_QUICK_ALLREDUCE"] = "1" + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + init_test_distributed_environment(tp_size, pp_size, rank, + distributed_init_port) + # we use the first group to communicate once + # and the second group to communicate twice + # and so on + # this is used to demonstrate that each group can + # communicate independently + num_communication = rank // tp_size + 1 + sz = 1024 * 1024 * 16 + fa = 
get_tp_group().device_communicator.qr_comm + inp = torch.ones(sz, dtype=torch.float16, device=device) + out = inp + for _ in range(num_communication): + out = fa.all_reduce(out) + torch.testing.assert_close(out, inp * (tp_size**num_communication)) + + inp = torch.ones(sz * 2, dtype=torch.bfloat16, device=device) + out = inp + for _ in range(num_communication): + out = fa.all_reduce(out) + torch.testing.assert_close(out, inp * (tp_size**num_communication)) + + +@pytest.mark.parametrize("tp_size", [2]) +@pytest.mark.parametrize("pipeline_parallel_size", [1]) +@pytest.mark.parametrize("test_target", + [eager_quick_allreduce, graph_allreduce]) +def test_quick_reduce_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size, + pipeline_parallel_size, test_target): + world_size = tp_size * pipeline_parallel_size + if world_size > torch.cuda.device_count(): + pytest.skip("Not enough GPUs to run the test.") + multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, + test_target) diff --git a/tests/tensorizer_loader/test_tensorizer.py b/tests/tensorizer_loader/test_tensorizer.py index b6286e148397..f6d8156a02d8 100644 --- a/tests/tensorizer_loader/test_tensorizer.py +++ b/tests/tensorizer_loader/test_tensorizer.py @@ -201,6 +201,7 @@ def test_tensorizer_with_tp_path_without_template(vllm_runner, capfd): ), tensor_parallel_size=2, disable_custom_all_reduce=True, + disable_quick_all_reduce=True, ) except RuntimeError: out, err = capfd.readouterr() @@ -219,6 +220,7 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs( with vllm_runner( model_ref, disable_custom_all_reduce=True, + disable_quick_all_reduce=True, enforce_eager=True, ) as base_model: outputs = base_model.generate(prompts, sampling_params) @@ -237,6 +239,7 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs( model=model_ref, tensor_parallel_size=2, disable_custom_all_reduce=True, + disable_quick_all_reduce=True, enforce_eager=True, ), tensorizer_config=tensorizer_config, 
@@ -249,6 +252,7 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs( tensor_parallel_size=2, load_format="tensorizer", disable_custom_all_reduce=True, + disable_quick_all_reduce=True, enforce_eager=True, model_loader_extra_config=tensorizer_config) as loaded_vllm_model: deserialized_outputs = loaded_vllm_model.generate( diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 3c8e6b95ce76..3b87a8f8349d 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1669,6 +1669,38 @@ def free_shared_buffer(ptr: int) -> None: torch.ops._C_custom_ar.free_shared_buffer(ptr) +# quick ar +def init_quick_ar(world_size: int, rank: int) -> int: + """Initialize the QuickReduce environment and return a Device Comms Handle api.""" + return torch.ops._C_quick_ar.init_quick_ar(world_size, rank) + + +def qr_get_comm_handle(fa: int) -> torch.Tensor: + """Return a Tensor handle""" + return torch.ops._C_quick_ar.qr_get_comm_handle(fa) + + +def qr_set_comm_handles(fa: int, handles: list[torch.Tensor]) -> None: + """Set the communication handle list.""" + torch.ops._C_quick_ar.qr_set_comm_handles(fa, handles) + + +def qr_all_reduce(fa: int, profile: int, inp: torch.Tensor, + out: torch.Tensor) -> None: + """Perform all-reduce across devices with optional profile.""" + torch.ops._C_quick_ar.qr_all_reduce(fa, profile, inp, out) + + +def qr_destroy(fa: int) -> None: + """Clean up and destroy the Device Comms Handle.""" + torch.ops._C_quick_ar.qr_destroy(fa) + + +def is_quickreduce_available() -> None: + """Only used to test whether module was properly imported.""" + torch.ops._C_quick_ar.is_quickreduce_available() + + def get_flash_mla_metadata( cache_seqlens: torch.Tensor, num_heads_per_head_k: int, diff --git a/vllm/config.py b/vllm/config.py index c0671d2524ec..ddde11844cb1 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1702,6 +1702,12 @@ class ParallelConfig: disable_custom_all_reduce: bool = False """Disable the custom all-reduce kernel and fall back to 
NCCL.""" + disable_quick_all_reduce: bool = False + """Disable the quick all-reduce kernel and fall back to custom all-reduce + or NCCL. Quick all-reduce supports fp8, Q8, Q6 and Q4 quantization of bf16 + and fp16; see envs.VLLM_QUICK_ALLREDUCE for the quantization level. + Only supported on AMD (ROCm). + """ tokenizer_pool_config: Optional[TokenizerPoolConfig] = None """This parameter is deprecated and will be removed in a future release. @@ -1903,6 +1909,11 @@ def _verify_args(self) -> None: logger.info( "Disabled the custom all-reduce kernel because it is not " "supported on current platform.") + if not current_platform.use_quick_allreduce(): + self.disable_quick_all_reduce = True + logger.info( + "Disabled the quick all-reduce kernel because it is not " + "supported on current platform.") if self.ray_workers_use_nsight and not self.use_ray: raise ValueError("Unable to use nsight profiling unless workers " "run with Ray.") @@ -2678,6 +2689,8 @@ def create_draft_parallel_config( max_parallel_loading_workers, disable_custom_all_reduce=target_parallel_config. disable_custom_all_reduce, + disable_quick_all_reduce=target_parallel_config. + disable_quick_all_reduce, ray_workers_use_nsight=target_parallel_config. ray_workers_use_nsight, placement_group=target_parallel_config.placement_group, @@ -4479,6 +4492,7 @@ def __str__(self): f"tensor_parallel_size={self.parallel_config.tensor_parallel_size}," f" pipeline_parallel_size={self.parallel_config.pipeline_parallel_size}, " # noqa f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, " # noqa + f"disable_quick_all_reduce={self.parallel_config.disable_quick_all_reduce}, " # noqa f"quantization={self.model_config.quantization}, " f"enforce_eager={self.model_config.enforce_eager}, " f"kv_cache_dtype={self.cache_config.cache_dtype}, "
diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index a05a13f51d4b..ff64a3a402aa 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -7,6 +7,7 @@ import vllm.envs as envs from vllm.logger import init_logger +from vllm.platforms import current_platform from .base_device_communicator import DeviceCommunicatorBase @@ -24,10 +25,12 @@ def __init__(self, if "tp" not in unique_name: # only tp uses custom allreduce use_custom_allreduce = False + use_quick_allreduce = False else: from vllm.distributed.parallel_state import ( - _ENABLE_CUSTOM_ALL_REDUCE) + _ENABLE_CUSTOM_ALL_REDUCE, _ENABLE_QUICK_ALL_REDUCE) use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE + use_quick_allreduce = _ENABLE_QUICK_ALL_REDUCE # ep does not use pynccl use_pynccl = "ep" not in unique_name @@ -40,6 +43,8 @@ def __init__(self, CustomAllreduce) from vllm.distributed.device_communicators.pynccl import ( PyNcclCommunicator) + from vllm.distributed.device_communicators.quick_all_reduce import ( + QuickAllreduce) self.pynccl_comm: Optional[PyNcclCommunicator] = None if use_pynccl and self.world_size > 1: @@ -55,7 +60,14 @@ def __init__(self, group=self.cpu_group, device=self.device, ) - + self.qr_comm: Optional[QuickAllreduce] = None + if use_quick_allreduce and self.world_size > 1 and \ + current_platform.is_rocm(): + # Initialize a custom fast all-reduce implementation.
+ self.qr_comm = QuickAllreduce( + group=self.cpu_group, + device=self.device, + ) if self.use_all2all: all2all_backend = envs.VLLM_ALL2ALL_BACKEND if all2all_backend == "naive": @@ -69,9 +81,17 @@ def __init__(self, else: raise ValueError(f"Unknown all2all backend: {all2all_backend}") + def all_reduce(self, input_): # always try custom allreduce first, # and then pynccl. + # if rocm, try quick allreduce first, then custom ar and pynccl. + qr_comm = self.qr_comm + if qr_comm is not None and not qr_comm.disabled and \ + qr_comm.should_quick_ar(input_): + out = qr_comm.quick_all_reduce(input_) + assert out is not None + return out ca_comm = self.ca_comm if ca_comm is not None and not ca_comm.disabled and \ ca_comm.should_custom_ar(input_): @@ -149,6 +169,8 @@ def destroy(self): self.pynccl_comm = None if self.ca_comm is not None: self.ca_comm = None + if self.qr_comm is not None: + self.qr_comm = None if self.all2all_manager is not None: self.all2all_manager.destroy() self.all2all_manager = None diff --git a/vllm/distributed/device_communicators/quick_all_reduce.py b/vllm/distributed/device_communicators/quick_all_reduce.py new file mode 100644 index 000000000000..98820151bd1e --- /dev/null +++ b/vllm/distributed/device_communicators/quick_all_reduce.py @@ -0,0 +1,215 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Union + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.distributed.parallel_state import in_the_same_node_as +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils import cuda_device_count_stateless + +try: + ops.is_quickreduce_available() + quick_ar = True +except Exception: + # For CPUs + quick_ar = False + +logger = init_logger(__name__) + + +def is_weak_contiguous(inp: torch.Tensor): + return inp.is_contiguous() or (inp.storage().nbytes() - + inp.storage_offset() * 
inp.element_size() + == inp.numel() * inp.element_size()) + + +''' +quantization level & int +close qr = 0 +TwoShotF16 = 1 +TwoShotFP8 = 2 +TwoShotQ8 = 3 +TwoShotQ6 = 4 +TwoShotQ4 = 5 +OneShotQ4 = 6 +''' + + +class QuickAllreduce: + + _SUPPORTED_WORLD_SIZES = [2, 4, 8] + _SUPPORTED_DTYPES = [torch.float16, + torch.bfloat16] # TODO: support torch.float32 + _SUPPORTED_LEVEL = [0, 1, 2, 3, 4, 5] + + def __init__(self, + group: ProcessGroup, + device: Union[int, str, torch.device], + max_size=512 * 1024 * 1024, + min_size=32 * 1024) -> None: + """ + Args: + group: the process group to work on. If None, it will use the + default process group. + device: the device to bind the QuickAllreduce to. If None, + it will be bind to f"cuda:{local_rank}". + It is the caller's responsibility to make sure each communicator + is bind to a unique device, and all communicators in this group + are in the same node. + """ + self._IS_CAPTURING = False + self.disabled = True + assert \ + envs.VLLM_QUICK_ALLREDUCE in QuickAllreduce._SUPPORTED_LEVEL, ( + "quick allreduce level must be in [0, 1, 2, 3, 4, 5], " + f"but got {envs.VLLM_QUICK_ALLREDUCE}" + ) + + if not quick_ar: + # disable because of missing quick allreduce library + # e.g. in a non-GPU environment + logger.info("Quick allreduce is disabled because " + "of missing quick allreduce library") + return + + self.group = group + + assert dist.get_backend(group) != dist.Backend.NCCL, ( + "QuickAllreduce should be attached to a non-NCCL group.") + + if not all(in_the_same_node_as(group, source_rank=0)): + # No need to initialize quick allreduce for multi-node case. 
+ logger.warning( + "Quick allreduce is disabled because this process group" + " spans across nodes.") + return + + rank = dist.get_rank(group=self.group) + self.rank = rank + world_size = dist.get_world_size(group=self.group) + + if world_size not in QuickAllreduce._SUPPORTED_WORLD_SIZES: + logger.warning( + "Quick allreduce is disabled due to an unsupported world" + " size: %d. Supported world sizes: %s. To silence this " + "warning, specify disable_quick_all_reduce=0 explicitly.", + world_size, str(QuickAllreduce._SUPPORTED_WORLD_SIZES)) + return + + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + # now `device` is a `torch.device` object + assert isinstance(device, torch.device) + self.device = device + + cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + if cuda_visible_devices: + device_ids = list(map(int, cuda_visible_devices.split(","))) + else: + device_ids = list(range(cuda_device_count_stateless())) + + physical_device_id = device_ids[device.index] + tensor = torch.tensor([physical_device_id], + dtype=torch.int, + device="cpu") + gather_list = [ + torch.tensor([0], dtype=torch.int, device="cpu") + for _ in range(world_size) + ] + dist.all_gather(gather_list, tensor, group=self.group) + physical_device_ids = [t.item() for t in gather_list] + + # test nvlink first, this will filter out most of the cases + # where quick allreduce is not supported + # this checks hardware and driver support for NVLink + assert current_platform.is_cuda_alike() + fully_connected = current_platform.is_fully_connected( + physical_device_ids) + if not fully_connected: + logger.warning( + "Quick allreduce is disabled because it's not supported on" + " more than two PCIe-only GPUs. 
To silence this warning, " + "specify disable_quick_all_reduce=0 explicitly.") + return + + self.max_size = max_size + self.min_size = min_size + self._ptr = ops.init_quick_ar(world_size, rank) + my_handle = ops.qr_get_comm_handle(self._ptr) + + all_handles = [[None] for _ in range(world_size)] + all_handles[rank][0] = my_handle + + for src in range(world_size): + dist.broadcast_object_list(all_handles[src], src=src) + comm_handles = [h[0] for h in all_handles] + ops.qr_set_comm_handles(self._ptr, comm_handles) + self.disabled = False + + def should_quick_ar(self, inp: torch.Tensor): + ''' + Check if quickreduce is available + ''' + if self.disabled: + return False + inp_size = inp.numel() * inp.element_size() + # quick allreduce requires input byte size to be multiples of 16 + if inp_size % 16 != 0: + return False + if not is_weak_contiguous(inp): + return False + if inp.dtype in QuickAllreduce._SUPPORTED_DTYPES: + return inp_size < self.max_size # and inp_size > self.min_size + return False + + def all_reduce(self, + inp: torch.Tensor, + *, + out: torch.Tensor = None, + registered: bool = False): + """Performs an out-of-place all reduce. + + If registered is True, this assumes inp's pointer is already + IPC-registered. Otherwise, inp is first copied into a pre-registered + buffer. + """ + if out is None: + out = torch.empty_like(inp) + if registered: + ops.all_reduce(self._ptr, inp, out, 0, 0) + else: + # print("qr") + ops.qr_all_reduce(self._ptr, envs.VLLM_QUICK_ALLREDUCE, inp, out) + return out + + def quick_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: + """The main allreduce API that provides support for cuda graph.""" + # When quick allreduce is disabled, this will be None. 
+ if self.disabled or not self.should_quick_ar(input): + return None + if self._IS_CAPTURING: + if torch.cuda.is_current_stream_capturing(): + return self.all_reduce(input, registered=True) + else: + # If warm up, mimic the allocation pattern since quick + # allreduce is out-of-place. + return torch.empty_like(input) + else: + return self.all_reduce(input, registered=False) + + def close(self): + '''del self._ptr and del buffer''' + if not self.disabled and self._ptr: + ops.qr_destroy(self._ptr) + self._ptr = 0 + + def __del__(self): + self.close() diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index b674d05a7771..15cc90e9228d 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -900,6 +900,7 @@ def graph_capture(device: torch.device): logger = init_logger(__name__) _ENABLE_CUSTOM_ALL_REDUCE = True +_ENABLE_QUICK_ALL_REDUCE = True def set_custom_all_reduce(enable: bool): @@ -907,6 +908,15 @@ def set_custom_all_reduce(enable: bool): _ENABLE_CUSTOM_ALL_REDUCE = enable +def set_quick_all_reduce(enable: bool): + ''' + qr brings acceleration through quantization, + but may reduce accuracy. + ''' + global _ENABLE_QUICK_ALL_REDUCE + _ENABLE_QUICK_ALL_REDUCE = enable + + def init_distributed_environment( world_size: int = -1, rank: int = -1, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 3b90880167dc..1e35633f1574 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -324,6 +324,7 @@ class EngineArgs: enforce_eager: bool = ModelConfig.enforce_eager max_seq_len_to_capture: int = ModelConfig.max_seq_len_to_capture disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce + disable_quick_all_reduce: bool = ParallelConfig.disable_quick_all_reduce # The following three fields are deprecated and will be removed in a future # release. Setting them will have no effect. Please remove them from your # configurations. 
@@ -630,6 +631,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parallel_group.add_argument( "--disable-custom-all-reduce", **parallel_kwargs["disable_custom_all_reduce"]) + parallel_group.add_argument( + "--disable-quick-all-reduce", + **parallel_kwargs["disable_quick_all_reduce"]) parallel_group.add_argument("--worker-cls", **parallel_kwargs["worker_cls"]) parallel_group.add_argument("--worker-extension-cls", @@ -1070,6 +1074,7 @@ def create_engine_config( enable_expert_parallel=self.enable_expert_parallel, max_parallel_loading_workers=self.max_parallel_loading_workers, disable_custom_all_reduce=self.disable_custom_all_reduce, + disable_quick_all_reduce=self.disable_quick_all_reduce, ray_workers_use_nsight=self.ray_workers_use_nsight, placement_group=placement_group, distributed_executor_backend=self.distributed_executor_backend, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 5ca3ebe91d12..0fd7d7ca3f13 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -302,6 +302,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: self.model_config.enforce_eager, "disable_custom_all_reduce": self.parallel_config.disable_custom_all_reduce, + "disable_quick_all_reduce": + self.parallel_config.disable_quick_all_reduce, }) self.cached_scheduler_outputs = [ diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index f818e1737975..0185156af497 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -118,6 +118,8 @@ class LLM: back to the eager mode. disable_custom_all_reduce: See [ParallelConfig][vllm.config.ParallelConfig]. + disable_quick_all_reduce: See + [ParallelConfig][vllm.config.ParallelConfig]. disable_async_output_proc: Disable async output processing. This may result in lower performance. 
hf_token: The token to use as HTTP bearer authorization for remote files @@ -182,6 +184,7 @@ def __init__( enforce_eager: bool = False, max_seq_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, + disable_quick_all_reduce: bool = False, disable_async_output_proc: bool = False, hf_token: Optional[Union[bool, str]] = None, hf_overrides: Optional[HfOverrides] = None, @@ -237,6 +240,7 @@ def __init__( enforce_eager=enforce_eager, max_seq_len_to_capture=max_seq_len_to_capture, disable_custom_all_reduce=disable_custom_all_reduce, + disable_quick_all_reduce=disable_quick_all_reduce, disable_async_output_proc=disable_async_output_proc, hf_token=hf_token, hf_overrides=hf_overrides, diff --git a/vllm/envs.py b/vllm/envs.py index b007bf8c59b7..cd88e2c79116 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -118,6 +118,7 @@ VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557 VLLM_ALL2ALL_BACKEND: str = "naive" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840 + VLLM_QUICK_ALLREDUCE: int = 1 def get_default_cache_root(): @@ -577,6 +578,13 @@ def get_vllm_port() -> Optional[int]: "VLLM_SKIP_P2P_CHECK": lambda: os.getenv("VLLM_SKIP_P2P_CHECK", "0") == "1", + # use quick allreduce or not. 0 for 1stage with no quant, + # 1 for 2stage f16, 2 for 2stage fp8, 3 for 2stage Q8, + # 4 for 2stage Q6, 5 for 2stage Q4. + # limit this value to less than or equal to 5 at the time of use. + "VLLM_QUICK_ALLREDUCE": + lambda: int(os.getenv("VLLM_QUICK_ALLREDUCE", "1")), + # List of quantization kernels that should be disabled, used for testing # and performance comparisons. 
Currently only affects MPLinearKernel # selection diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 8bb3dfe7457a..f3e782704f14 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -312,6 +312,10 @@ def supports_v1(cls, model_config: "ModelConfig") -> bool: def use_custom_allreduce(cls) -> bool: return True + @classmethod + def use_quick_allreduce(cls) -> bool: + return False + @classmethod def get_piecewise_backend_cls(cls) -> str: return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 646faa944565..08f96e40ab82 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -453,6 +453,13 @@ def use_custom_allreduce(cls) -> bool: """ return False + @classmethod + def use_quick_allreduce(cls) -> bool: + """ + Returns if quick allreduce is supported on the current platform + """ + return False + @classmethod def validate_request( cls, diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index e1dcd9870b6c..ce7074ce8b04 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -374,6 +374,13 @@ def use_custom_allreduce(cls) -> bool: supported_archs = ['gfx94', 'gfx95'] return any(gfx in gcn_arch for gfx in supported_archs) + @classmethod + def use_quick_allreduce(cls) -> bool: + # We only enable quick allreduce for MI300 series + gcn_arch = torch.cuda.get_device_properties(0).gcnArchName + supported_archs = ['gfx94', 'gfx95'] + return any(gfx in gcn_arch for gfx in supported_archs) + @classmethod def get_cu_count(cls, device_id: int = 0) -> int: return torch.cuda.get_device_properties( diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 0758747a83cc..a5cba05e80a8 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -291,4 +291,6 @@ def report_usage_stats( vllm_config.model_config.enforce_eager, "disable_custom_all_reduce": vllm_config.parallel_config.disable_custom_all_reduce, + 
"disable_quick_all_reduce": + vllm_config.parallel_config.disable_quick_all_reduce, }) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index bce5cbb5f9d0..8e15f9f062a4 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -13,7 +13,7 @@ from vllm.device_allocator.cumem import CuMemAllocator from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, - set_custom_all_reduce) + set_custom_all_reduce, set_quick_all_reduce) from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.logger import init_logger @@ -344,6 +344,8 @@ def init_worker_distributed_environment( parallel_config = vllm_config.parallel_config set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) + set_quick_all_reduce(not parallel_config.disable_quick_all_reduce) + init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 6e45b8423e5e..50bc07eeedd6 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -12,7 +12,7 @@ from vllm.device_allocator.cumem import CuMemAllocator from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, - set_custom_all_reduce) + set_custom_all_reduce, set_quick_all_reduce) from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -526,6 +526,8 @@ def init_worker_distributed_environment( parallel_config = vllm_config.parallel_config set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) + set_quick_all_reduce(not parallel_config.disable_quick_all_reduce) + init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank) ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
From e40020133120eea5808d7f6f9a02c578e0a26b6c Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 22 May 2025 16:40:15 +0000 Subject: [PATCH 2/9] add quickreduce Signed-off-by: Haoyang Li --- vllm/distributed/device_communicators/quick_all_reduce.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/device_communicators/quick_all_reduce.py b/vllm/distributed/device_communicators/quick_all_reduce.py index 98820151bd1e..cdf01d141bf6 100644 --- a/vllm/distributed/device_communicators/quick_all_reduce.py +++ b/vllm/distributed/device_communicators/quick_all_reduce.py @@ -65,11 +65,9 @@ def __init__(self, """ self._IS_CAPTURING = False self.disabled = True - assert \ - envs.VLLM_QUICK_ALLREDUCE in QuickAllreduce._SUPPORTED_LEVEL, ( + assert envs.VLLM_QUICK_ALLREDUCE in QuickAllreduce._SUPPORTED_LEVEL, ( "quick allreduce level must be in [0, 1, 2, 3, 4, 5], " - f"but got {envs.VLLM_QUICK_ALLREDUCE}" - ) + f"but got {envs.VLLM_QUICK_ALLREDUCE}") if not quick_ar: # disable because of missing quick allreduce library From fad70a50c1e8efeb949a452aea4d355f0833fbd1 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 22 May 2025 16:40:23 +0000 Subject: [PATCH 3/9] boundary condition Signed-off-by: Haoyang Li --- csrc/quick_all_reduce.cu | 2 -- .../device_communicators/quick_all_reduce.py | 34 +++++-------------- 2 files changed, 8 insertions(+), 28 deletions(-) diff --git a/csrc/quick_all_reduce.cu b/csrc/quick_all_reduce.cu index 76e52a19a3ef..341c93f77114 100644 --- a/csrc/quick_all_reduce.cu +++ b/csrc/quick_all_reduce.cu @@ -159,8 +159,6 @@ void DeviceComms::allreduce(int profile, hipStream_t stream, T const* A, T* B, break; } -#endif - // ------------------------------------------------- // Rotate the flag color. 
flag_color++; diff --git a/vllm/distributed/device_communicators/quick_all_reduce.py b/vllm/distributed/device_communicators/quick_all_reduce.py index cdf01d141bf6..4d0ff74cb22d 100644 --- a/vllm/distributed/device_communicators/quick_all_reduce.py +++ b/vllm/distributed/device_communicators/quick_all_reduce.py @@ -52,13 +52,15 @@ def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device], max_size=512 * 1024 * 1024, - min_size=32 * 1024) -> None: + min_size=128 * 1024) -> None: """ Args: group: the process group to work on. If None, it will use the default process group. device: the device to bind the QuickAllreduce to. If None, it will be bind to f"cuda:{local_rank}". + max_size: max supported size. + min_size: Less than this size, custom_allreduce is better. It is the caller's responsibility to make sure each communicator is bind to a unique device, and all communicators in this group are in the same node. @@ -168,24 +170,11 @@ def should_quick_ar(self, inp: torch.Tensor): return inp_size < self.max_size # and inp_size > self.min_size return False - def all_reduce(self, - inp: torch.Tensor, - *, - out: torch.Tensor = None, - registered: bool = False): - """Performs an out-of-place all reduce. - - If registered is True, this assumes inp's pointer is already - IPC-registered. Otherwise, inp is first copied into a pre-registered - buffer. 
- """ + def all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None): + """Performs an out-of-place all reduce.""" if out is None: out = torch.empty_like(inp) - if registered: - ops.all_reduce(self._ptr, inp, out, 0, 0) - else: - # print("qr") - ops.qr_all_reduce(self._ptr, envs.VLLM_QUICK_ALLREDUCE, inp, out) + ops.qr_all_reduce(self._ptr, envs.VLLM_QUICK_ALLREDUCE, inp, out) return out def quick_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: @@ -193,15 +182,8 @@ def quick_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: # When quick allreduce is disabled, this will be None. if self.disabled or not self.should_quick_ar(input): return None - if self._IS_CAPTURING: - if torch.cuda.is_current_stream_capturing(): - return self.all_reduce(input, registered=True) - else: - # If warm up, mimic the allocation pattern since quick - # allreduce is out-of-place. - return torch.empty_like(input) - else: - return self.all_reduce(input, registered=False) + + return self.all_reduce(input) def close(self): '''del self._ptr and del buffer''' From f66d957ea1213a053e5bc9cd7297e79faacd7ecd Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 22 May 2025 16:27:40 +0000 Subject: [PATCH 4/9] Platform inspection Signed-off-by: Haoyang Li --- csrc/torch_bindings.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index cb97ed3612c7..1ee8b43d4f65 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -702,7 +702,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { custom_ar.def("free_shared_buffer", &free_shared_buffer); } - +#ifdef USE_ROCM TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _quick_ar), quick_ar) { // quick all reduce kernels quick_ar.def("init_quick_ar(int world_size, int rank) -> int"); @@ -719,4 +719,5 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _quick_ar), quick_ar) { quick_ar.impl("qr_destroy", &qr_destroy); 
quick_ar.impl("is_quickreduce_available", &is_quickreduce_available); } +#endif REGISTER_EXTENSION(TORCH_EXTENSION_NAME) From 0d97a5643b315befcf681f8528a3c0a76c5a954b Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 23 May 2025 07:21:41 +0000 Subject: [PATCH 5/9] precommit Signed-off-by: Haoyang Li --- vllm/platforms/cuda.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index f3e782704f14..faf293cc3d4a 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -315,7 +315,7 @@ def use_custom_allreduce(cls) -> bool: @classmethod def use_quick_allreduce(cls) -> bool: return False - + @classmethod def get_piecewise_backend_cls(cls) -> str: return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa From 910d366c6e0861ac602d374129c92517c03103e0 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 23 May 2025 14:33:06 +0000 Subject: [PATCH 6/9] adjust qr condition Signed-off-by: Haoyang Li --- .../device_communicators/quick_all_reduce.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/distributed/device_communicators/quick_all_reduce.py b/vllm/distributed/device_communicators/quick_all_reduce.py index 4d0ff74cb22d..09b782aee531 100644 --- a/vllm/distributed/device_communicators/quick_all_reduce.py +++ b/vllm/distributed/device_communicators/quick_all_reduce.py @@ -31,13 +31,12 @@ def is_weak_contiguous(inp: torch.Tensor): ''' quantization level & int -close qr = 0 -TwoShotF16 = 1 -TwoShotFP8 = 2 -TwoShotQ8 = 3 -TwoShotQ6 = 4 -TwoShotQ4 = 5 -OneShotQ4 = 6 +ONESHOT_F16 = 0, +TWOSHOT_F16 = 1, +TWOSHOT_FP8 = 2, +TWOSHOT_Q8 = 3, +TWOSHOT_Q6 = 4, +TWOSHOT_Q4 = 5, ''' @@ -61,6 +60,7 @@ def __init__(self, it will be bind to f"cuda:{local_rank}". max_size: max supported size. min_size: Less than this size, custom_allreduce is better. 
+ (custom_allreduce is available when less than 16MB) It is the caller's responsibility to make sure each communicator is bind to a unique device, and all communicators in this group are in the same node. @@ -167,7 +167,7 @@ def should_quick_ar(self, inp: torch.Tensor): if not is_weak_contiguous(inp): return False if inp.dtype in QuickAllreduce._SUPPORTED_DTYPES: - return inp_size < self.max_size # and inp_size > self.min_size + return inp_size < self.max_size and inp_size > self.min_size return False def all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None): From 780e3554427e8ce6304ba56ab9bcaec93c482c64 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Sat, 24 May 2025 06:45:14 +0000 Subject: [PATCH 7/9] rebase Signed-off-by: Haoyang Li --- vllm/distributed/device_communicators/cuda_communicator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index ff64a3a402aa..26aa4ac366bf 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -81,7 +81,6 @@ def __init__(self, else: raise ValueError(f"Unknown all2all backend: {all2all_backend}") - def all_reduce(self, input_): # always try custom allreduce first, # and then pynccl. 
From bb8f5138ed6c73dd1df9c4f8deca2a0c7267d415 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Sun, 25 May 2025 16:55:36 +0000 Subject: [PATCH 8/9] adjust condition Signed-off-by: Haoyang Li --- .../device_communicators/quick_all_reduce.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/vllm/distributed/device_communicators/quick_all_reduce.py b/vllm/distributed/device_communicators/quick_all_reduce.py index 09b782aee531..8530662dc88c 100644 --- a/vllm/distributed/device_communicators/quick_all_reduce.py +++ b/vllm/distributed/device_communicators/quick_all_reduce.py @@ -43,15 +43,13 @@ def is_weak_contiguous(inp: torch.Tensor): class QuickAllreduce: _SUPPORTED_WORLD_SIZES = [2, 4, 8] - _SUPPORTED_DTYPES = [torch.float16, - torch.bfloat16] # TODO: support torch.float32 _SUPPORTED_LEVEL = [0, 1, 2, 3, 4, 5] def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device], - max_size=512 * 1024 * 1024, - min_size=128 * 1024) -> None: + max_size=1024 * 1024 * 512, + min_size=1024 * 1024) -> None: """ Args: group: the process group to work on. 
If None, it will use the @@ -93,6 +91,7 @@ def __init__(self, rank = dist.get_rank(group=self.group) self.rank = rank world_size = dist.get_world_size(group=self.group) + self.world_size = world_size if world_size not in QuickAllreduce._SUPPORTED_WORLD_SIZES: logger.warning( @@ -140,7 +139,7 @@ def __init__(self, "specify disable_quick_all_reduce=0 explicitly.") return - self.max_size = max_size + self.max_size = max_size if envs.VLLM_QUICK_ALLREDUCE > 0 else max_size / self.world_size * 2 self.min_size = min_size self._ptr = ops.init_quick_ar(world_size, rank) my_handle = ops.qr_get_comm_handle(self._ptr) @@ -166,8 +165,11 @@ def should_quick_ar(self, inp: torch.Tensor): return False if not is_weak_contiguous(inp): return False - if inp.dtype in QuickAllreduce._SUPPORTED_DTYPES: - return inp_size < self.max_size and inp_size > self.min_size + if inp.dtype == torch.float16: + return inp_size <= self.max_size and inp_size > self.min_size + elif inp.dtype == torch.bfloat16: + return inp_size <= self.max_size and inp_size > 1024 * 1024 * 16 \ + and self.world_size == 2 return False def all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None): From 3458abca3c1790899896be78c2f1854db802e189 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Sun, 25 May 2025 17:01:36 +0000 Subject: [PATCH 9/9] rebase Signed-off-by: Haoyang Li --- vllm/distributed/device_communicators/quick_all_reduce.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/device_communicators/quick_all_reduce.py b/vllm/distributed/device_communicators/quick_all_reduce.py index 8530662dc88c..3fdfdc71ae05 100644 --- a/vllm/distributed/device_communicators/quick_all_reduce.py +++ b/vllm/distributed/device_communicators/quick_all_reduce.py @@ -139,7 +139,8 @@ def __init__(self, "specify disable_quick_all_reduce=0 explicitly.") return - self.max_size = max_size if envs.VLLM_QUICK_ALLREDUCE > 0 else max_size / self.world_size * 2 + self.max_size = (max_size if 
envs.VLLM_QUICK_ALLREDUCE > 0 else + max_size / self.world_size * 2) self.min_size = min_size self._ptr = ops.init_quick_ar(world_size, rank) my_handle = ops.qr_get_comm_handle(self._ptr)