From b69929be0bef349d1a2f0488d907c60960f81c2f Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 17 Apr 2025 21:55:00 +0000 Subject: [PATCH 01/28] Add quickreduce as alternative to custom allreduce Signed-off-by: ilmarkov --- CMakeLists.txt | 2 + csrc/custom_quickreduce.cu | 73 + csrc/ops.h | 10 + csrc/quickreduce/base.h | 92 ++ csrc/quickreduce/quick_reduce.cu | 167 ++ csrc/quickreduce/quick_reduce.h | 72 + csrc/quickreduce/quick_reduce_impl.cuh | 1395 +++++++++++++++++ csrc/torch_bindings.cpp | 16 + vllm/_custom_ops.py | 26 + .../device_communicators/cuda_communicator.py | 20 + .../device_communicators/quick_all_reduce.py | 127 ++ 11 files changed, 2000 insertions(+) create mode 100644 csrc/custom_quickreduce.cu create mode 100644 csrc/quickreduce/base.h create mode 100644 csrc/quickreduce/quick_reduce.cu create mode 100644 csrc/quickreduce/quick_reduce.h create mode 100644 csrc/quickreduce/quick_reduce_impl.cuh create mode 100644 vllm/distributed/device_communicators/quick_all_reduce.py diff --git a/CMakeLists.txt b/CMakeLists.txt index d75f0d321247..a352de1d69bc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -253,6 +253,8 @@ set(VLLM_EXT_SRC "csrc/cuda_utils_kernels.cu" "csrc/prepare_inputs/advance_step.cu" "csrc/custom_all_reduce.cu" + "csrc/custom_quickreduce.cu" + "csrc/quickreduce/quick_reduce.cu" "csrc/torch_bindings.cpp") if(VLLM_GPU_LANG STREQUAL "CUDA") diff --git a/csrc/custom_quickreduce.cu b/csrc/custom_quickreduce.cu new file mode 100644 index 000000000000..8d1c437a3652 --- /dev/null +++ b/csrc/custom_quickreduce.cu @@ -0,0 +1,73 @@ +#include +#include +#include +#include + +#ifdef USE_ROCM + + #include "quickreduce/quick_reduce.h" + +fptr_t init_custom_qr(int64_t rank, int64_t world_size) { + if (world_size > 8) + throw std::invalid_argument("world size > 8 is not supported"); + if (world_size % 2 != 0) + throw std::invalid_argument("Odd num gpus is not supported for now"); + if (rank < 0 || rank >= world_size) + throw 
std::invalid_argument("invalid rank passed in"); + DeviceComms* fptr = new DeviceComms(); + fptr->init(world_size, rank); + return (fptr_t)fptr; +} + +void qr_destroy(fptr_t _fa) { + auto fa = reinterpret_cast(_fa); + fa->destroy(); + delete fa; +} + +torch::Tensor qr_get_handle(fptr_t _fa) { + auto fa = reinterpret_cast(_fa); + hipIpcMemHandle_t handle = fa->get_handle(); + auto device_index = c10::cuda::current_device(); + auto options = + torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); + auto data_handle = + torch::empty({static_cast(sizeof(cudaIpcMemHandle_t))}, options); + std::memcpy(data_handle.data_ptr(), &handle, sizeof(hipIpcMemHandle_t)); + return data_handle; +} + +void qr_open_handles(fptr_t _fa, const std::vector& handles) { + auto fa = reinterpret_cast(_fa); + std::vector ipc_handles; + ipc_handles.reserve(handles.size()); + for (auto& handle : handles) { + // Ensure the tensor is on the same device as the current device. + hipIpcMemHandle_t ipc_handle; + std::memcpy(&ipc_handle, handle.data_ptr(), sizeof(hipIpcMemHandle_t)); + ipc_handles.push_back(ipc_handle); + } + fa->open_ipc_handles(ipc_handles); +} + +void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, + int64_t algo_int) { + auto fa = reinterpret_cast(_fa); + const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); + auto stream = at::cuda::getCurrentHIPStreamMasqueradingAsCUDA(); + + TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); + TORCH_CHECK_EQ(inp.scalar_type(), at::ScalarType::Half) + << "QR only supports half precision for now."; + TORCH_CHECK_EQ(inp.numel(), out.numel()); + + auto algo = static_cast(algo_int); + fa->allreduce(algo_int, stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), out.numel()); +} + +int64_t qr_max_size() { + return static_cast(DeviceComms::kMaxProblemSize); +} + +#endif // USE_ROCM \ No newline at end of file diff --git a/csrc/ops.h b/csrc/ops.h index f02f5083ac19..3f540b88cb9a 100644 
--- a/csrc/ops.h +++ b/csrc/ops.h @@ -360,3 +360,13 @@ std::tuple allocate_shared_buffer_and_handle( int64_t size); int64_t open_mem_handle(torch::Tensor& mem_handle); void free_shared_buffer(int64_t buffer); + +#ifdef USE_ROCM +fptr_t qr_init_device_collectives(int64_t rank, int64_t world_size); +void qr_destroy(fptr_t _fa); +torch::Tensor qr_get_handle(fptr_t _fa); +void qr_open_handles(fptr_t _fa, const std::vector& handles); +void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, + int64_t algo_int); +int64_t qr_max_size(); +#endif \ No newline at end of file diff --git a/csrc/quickreduce/base.h b/csrc/quickreduce/base.h new file mode 100644 index 000000000000..df6a10b0a9e6 --- /dev/null +++ b/csrc/quickreduce/base.h @@ -0,0 +1,92 @@ +#pragma once + +#include +#include +#include + +#define __device_inline__ __device__ __forceinline__ +#define __quickreduce_launch_bounds__ __launch_bounds__(256, 4) + +// Setup acquire-release semantics for vector memory reads (mubuf instruction) +// as per architecture. 
+#if defined(__gfx942__) +// CDNA3: Scope bits sc0, sc1 + #define MUBUF_ACQUIRE 16 + #define MUBUF_RELEASE 16 +#elif (defined(__gfx908__) || defined(__gfx90a__)) +// CDNA1 and CDNA2 - glc bit + #define MUBUF_ACQUIRE 1 + #define MUBUF_RELEASE 0 +#endif + +// Vector types +using int8x8_t = __attribute__((__vector_size__(8 * sizeof(int8_t)))) int8_t; + +using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; +using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; +using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int; +using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int; + +using fp8_t = uint8_t; +using fp8x8_t = __attribute__((__vector_size__(8 * sizeof(uint8_t)))) uint8_t; + +using fp16x4_t = __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16; +using fp16x8_t = __attribute__((__vector_size__(8 * sizeof(__fp16)))) __fp16; +using fp16x16_t = __attribute__((__vector_size__(16 * sizeof(__fp16)))) __fp16; + +using fp32x2_t = __attribute__((__vector_size__(2 * sizeof(float)))) float; +using fp32x4_t = __attribute__((__vector_size__(4 * sizeof(float)))) float; +using fp32x8_t = __attribute__((__vector_size__(8 * sizeof(float)))) float; +using fp32x16_t = __attribute__((__vector_size__(16 * sizeof(float)))) float; + +// Standard CDNA wavefront size. +static int constexpr kWavefront = 64; + +// 256 thread, 4 wavefronts. 
+static dim3 constexpr kBlock = {64, 4, 1}; + +// Methods +__device_inline__ __host__ unsigned long divceil(unsigned long x, + unsigned long y) { + return ((x + y - 1) / y); +} + +union BufferResource { + __device_inline__ constexpr BufferResource() : config(0x00020000U) {} + + __device_inline__ constexpr BufferResource(void* buffer_address, + uint32_t buffer_size) + : address(buffer_address), range(buffer_size), config(0x00020000U) {} + + int32x4_t descriptor; + struct { + void* address; // 8B, out of which first 48b is address, and 16b is stride + // (unused) + uint32_t range; // Byte range for the buffer resource + uint32_t config; // Constant, DFMT=32b + }; +}; + +__device_inline__ static int32x4_t buffer_load_dwordx4( + int32x4_t srsrc, int32_t voffset, int32_t soffset, + int32_t aux) __asm("llvm.amdgcn.raw.buffer.load.v4i32"); + +__device_inline__ static void buffer_store_dwordx4( + int32x4_t data, int32x4_t srsrc, int32_t voffset, int32_t soffset, + int32_t aux) __asm("llvm.amdgcn.raw.buffer.store.v4i32"); + +__device_inline__ static void set_fp16_ovfl(bool const value) { + // short size = 0b00001; // Specifies the bit size to modify + // const short offset = 0b10111; // Corrected offset to 23, which is the bit + // position of FP16_OVFL const short hwRegId = 0b000001; // HW register ID for + // MODE const short simm16 = (size << 11) | (offset << 6) | hwRegId; simm16 = + // 0xdc1 + +#if defined(__gfx942__) + if (value) { + asm volatile("s_setreg_imm32_b32 0xdc1, 1;" ::); + } else { + asm volatile("s_setreg_imm32_b32 0xdc1, 0;" ::); + } +#endif +} diff --git a/csrc/quickreduce/quick_reduce.cu b/csrc/quickreduce/quick_reduce.cu new file mode 100644 index 000000000000..610f8dc1056c --- /dev/null +++ b/csrc/quickreduce/quick_reduce.cu @@ -0,0 +1,167 @@ +#ifdef USE_ROCM + + #include + + #include "quick_reduce_impl.cuh" + #include "quick_reduce.h" + +void DeviceComms::init(int world_size, int rank) { + destroy(); + this->world_size = world_size; + this->rank = 
rank; + + // Allocate buffer size for worst case: Twoshot FP16 2-stage buffer. + long flags_buffer_size = 2 * world_size * kMaxTiles * sizeof(int); + long data_buffer_size = 2 * kMaxProblemSize; + long total_buffer_size = flags_buffer_size + data_buffer_size; + data_offset = flags_buffer_size; + HIP_CHECK(hipExtMallocWithFlags((void**)&dbuffer, total_buffer_size, + hipDeviceMallocUncached)); + + // Clear the flags buffer. + HIP_CHECK(hipMemset(dbuffer, 0, flags_buffer_size)); + + // Device-side list of IPC buffers. + buffer_list.resize(world_size); + HIP_CHECK(hipMalloc(&dbuffer_list, world_size * sizeof(uint8_t*))); + + // Create IPC handles for rank's communication buffer. + all_buffer_ipc_handles.resize(world_size); + HIP_CHECK(hipIpcGetMemHandle(&buffer_ipc_handle, dbuffer)); + + initialized = true; +} + +void DeviceComms::destroy() { + if (initialized) { + for (int i = 0; i < world_size; i++) { + if (i != rank) { + HIP_CHECK(hipIpcCloseMemHandle(dbuffer_list[i])); + } + } + + HIP_CHECK(hipFree(dbuffer)); + HIP_CHECK(hipFree(dbuffer_list)); + + initialized = false; + } +} + +void DeviceComms::open_ipc_handles( + std::vector const& ipc_handles) { + assert(ipc_handles.size() == all_buffer_ipc_handles.size()); + for (int i = 0; i < world_size; i++) { + all_buffer_ipc_handles[i] = ipc_handles[i]; + } + + // Open device memory access to the IPC communication buffers. + // Note: For our own rank, we do not need to open a handle. 
+ for (int i = 0; i < world_size; i++) { + if (i != rank) { + HIP_CHECK(hipIpcOpenMemHandle((void**)&buffer_list[i], + all_buffer_ipc_handles[i], + hipIpcMemLazyEnablePeerAccess)); + } else { + buffer_list[i] = dbuffer; + } + } + + HIP_CHECK(hipMemcpy(dbuffer_list, buffer_list.data(), + world_size * sizeof(uint8_t*), hipMemcpyHostToDevice)); +} + +// ============================================================ +// KERNEL +// ============================================================ +template +__global__ __quickreduce_launch_bounds__ static void allreduce_prototype( + half const* A, half* B, int N, int num_blocks, int world_size, int rank, + uint8_t** dbuffer_list, long data_offset, int flag_color) { + int block = blockIdx.x; + int grid = gridDim.x; + + while (block < num_blocks) { + AllReduceKenel::run(A, B, N, block, num_blocks, world_size, rank, + dbuffer_list, data_offset, flag_color); + block += grid; + } +} + + // ============================================================ + // DISPATCH + // ============================================================ + #define TWOSHOT_DISPATCH(__codec) \ + if (world_size == 2) { \ + using LineCodec = __codec<2>; \ + using AllReduceKernel = AllReduceTwoshot; \ + hipLaunchKernelGGL((allreduce_prototype), dim3(grid), \ + dim3(kBlock), 0, stream, A, B, N, num_blocks, \ + world_size, rank, dbuffer_list, data_offset, \ + flag_color); \ + } else if (world_size == 4) { \ + using LineCodec = __codec<4>; \ + using AllReduceKernel = AllReduceTwoshot; \ + hipLaunchKernelGGL((allreduce_prototype), dim3(grid), \ + dim3(kBlock), 0, stream, A, B, N, num_blocks, \ + world_size, rank, dbuffer_list, data_offset, \ + flag_color); \ + } else if (world_size == 8) { \ + using LineCodec = __codec<8>; \ + using AllReduceKernel = AllReduceTwoshot; \ + hipLaunchKernelGGL((allreduce_prototype), dim3(grid), \ + dim3(kBlock), 0, stream, A, B, N, num_blocks, \ + world_size, rank, dbuffer_list, data_offset, \ + flag_color); \ + } + +void 
DeviceComms::allreduce(int profile, hipStream_t stream, half const* A, + half* B, int N) { + if (world_size != 2 && world_size != 4 && world_size != 8) { + throw std::runtime_error("All Reduce not supported for world_size = " + + std::to_string(world_size)); + } + + // Configuration. + long msg_size = N * sizeof(half); + unsigned long num_blocks = divceil(msg_size, kTileSize); + unsigned long grid = min(304 * 4, num_blocks); + // ------------------------------------------------- + // All reduce dispatch. + QuickReduceProfile dprofile = static_cast(profile); + + switch (dprofile) { + case QuickReduceProfile::ONESHOT_FP16: + using AllReduceKernel = AllReduceOneshot; + hipLaunchKernelGGL((allreduce_prototype), dim3(grid), + dim3(kBlock), 0, stream, A, B, N, num_blocks, + world_size, rank, dbuffer_list, data_offset, + flag_color); + break; + case QuickReduceProfile::TWOSHOT_FP8: + throw std::runtime_error("FP8 is not supported"); + // TWOSHOT_DISPATCH(TwoshotFP8LineCodec) + break; + case QuickReduceProfile::TWOSHOT_Q8: + TWOSHOT_DISPATCH(TwoshotQ8LineCodec) + break; + case QuickReduceProfile::TWOSHOT_MAX_MIN_Q8: + TWOSHOT_DISPATCH(TwoshotMaxMinQ8LineCodec) + break; + case QuickReduceProfile::TWOSHOT_Q6: + TWOSHOT_DISPATCH(TwoshotQ6LineCodec) + break; + case QuickReduceProfile::TWOSHOT_Q4: + TWOSHOT_DISPATCH(TwoshotQ4LineCodec) + break; + default: + TWOSHOT_DISPATCH(TwoshotFP16LineCodec) + break; + } + HIP_CHECK(cudaGetLastError()); + + // ------------------------------------------------- + // Rotate the flag color. + flag_color++; +} + +#endif // USE_ROCM \ No newline at end of file diff --git a/csrc/quickreduce/quick_reduce.h b/csrc/quickreduce/quick_reduce.h new file mode 100644 index 000000000000..9636533e0abc --- /dev/null +++ b/csrc/quickreduce/quick_reduce.h @@ -0,0 +1,72 @@ +#pragma once + +#include +#include +#include + +#define HIP_CHECK(err) \ + do { \ + hipError_t err_ = (err); \ + if (err_ != hipSuccess) { \ + std::printf("HIP error %d at %s:%d. 
%s\n", err_, __FILE__, __LINE__, \ + hipGetErrorString(err_)); \ + throw std::runtime_error("HIP error"); \ + } \ + } while (0) + +using fptr_t = int64_t; +static_assert(sizeof(void*) == sizeof(fptr_t)); + +enum QuickReduceProfile { + ONESHOT_FP16 = 0, + TWOSHOT_FP16 = 1, + TWOSHOT_FP8 = 2, + TWOSHOT_Q8 = 3, + TWOSHOT_Q6 = 4, + TWOSHOT_Q4 = 5, + TWOSHOT_MAX_MIN_Q8 = 6, +}; + +/* +=============================================================== +Desc: + Device Comms Handle +*/ +struct DeviceComms { + // Workgroup scope = Tile = (256 threads x 16B x 8 atoms) + static long constexpr kTileSize = 256 * 16 * 8; + + // Max problem size is 8GB (in bytes) + static long long constexpr kMaxProblemSize = + static_cast(536870912) * 16; + static long constexpr kMaxTiles = kMaxProblemSize / kTileSize; + + // Max TP-8 + static int constexpr kMaxWorldSize = 8; + + bool initialized = false; + int flag_color = 1; + int world_size; + int rank; + + uint8_t* dbuffer; + uint8_t** dbuffer_list; + hipIpcMemHandle_t buffer_ipc_handle; + std::vector all_buffer_ipc_handles; + std::vector buffer_list; + long data_offset; + + DeviceComms() : initialized(false), world_size(1), rank(0) {} + ~DeviceComms() { destroy(); } + + void init(int world_size, int rank); + int get_world_size() { return world_size; } + int get_rank() { return rank; } + bool status() { return initialized; } + void destroy(); + + hipIpcMemHandle_t const get_handle() { return buffer_ipc_handle; } + void open_ipc_handles(std::vector const& ipc_handles); + void allreduce(int profile, hipStream_t stream, half const* A, half* B, + int N); +}; diff --git a/csrc/quickreduce/quick_reduce_impl.cuh b/csrc/quickreduce/quick_reduce_impl.cuh new file mode 100644 index 000000000000..f29082d1e9e7 --- /dev/null +++ b/csrc/quickreduce/quick_reduce_impl.cuh @@ -0,0 +1,1395 @@ +#pragma once + +#include +#include "base.h" + +// ============================================================ +// Oneshot +// 
============================================================ +// MARK: Oneshot All Reduce +struct AllReduceOneshot { + // Fixed magic implementation. + // We will use a workgroup of 256 threads (standard kBlock) across 8 atoms of + // work. + static int constexpr kAtoms = 8; + + // Size and atom stride of data that the workgroup will process. + static int constexpr kTileSize = 256 * kAtoms * sizeof(int32x4_t); + static int constexpr kAtomStride = 256; + + __device__ static void run( + half const* __restrict__ A, // input + half* __restrict__ B, // output + int const N, // number of elements + int const block, // this block's index + int const num_blocks, // total number of blocks + int const world_size, // total number of ranks + int const rank, // this rank's index + uint8_t** __restrict__ buffer_list, // communication buffers + long const data_offset, // offset to start of the data buffer + int flag_color // Flag color for the network barrier + ) { + // Topology + int thread = threadIdx.x + threadIdx.y * kWavefront; + + long data_stride = num_blocks * kTileSize; + long flags_stride = num_blocks * sizeof(int); + + uint8_t* rank_buffer = buffer_list[rank]; + + // -------------------------------------------------------- + // Read A into registers + int32x4_t tA[kAtoms]; + + BufferResource src_buffer(const_cast(A), N * sizeof(half)); + int src_offset = block * kTileSize + thread * sizeof(int32x4_t); + + for (int i = 0; i < kAtoms; i++) { + tA[i] = buffer_load_dwordx4(src_buffer.descriptor, src_offset, 0, 0); + src_offset += kAtomStride * sizeof(int32x4_t); + } + + // -------------------------------------------------------- + // Write rank data into this rank segment of every rank's communication + // buffer. 
+ long comm_data_offset = + data_offset + rank * data_stride + block * kTileSize; + long comm_flags_offset = rank * flags_stride + block * sizeof(int); + + if (thread < world_size) { + int r = thread; + int* flag_ptr = + reinterpret_cast(buffer_list[r] + comm_flags_offset); + while (__atomic_load_n(flag_ptr, __ATOMIC_RELAXED) != flag_color - 1) { + } + } + __syncthreads(); + + for (int r = 0; r < world_size; r++) { + int32x4_t* send_buffer = + reinterpret_cast(buffer_list[r] + comm_data_offset); + for (int i = 0; i < kAtoms; i++) { + __builtin_nontemporal_store(tA[i], send_buffer + thread); + send_buffer += kAtomStride; + } + } + + // Inform the other ranks that th data has been posted. + __syncthreads(); + if (thread < world_size) { + int r = thread; + int* flag_ptr = + reinterpret_cast(buffer_list[r] + comm_flags_offset); + __atomic_store_n(flag_ptr, flag_color, __ATOMIC_RELEASE); + } + + // -------------------------------------------------------- + // Read and reduce the data from this rank's communication buffer. + int32x4_t tB[kAtoms]; + + { + int r = 0; + + // Wait for the flags to be set. + int* flag_ptr = reinterpret_cast(rank_buffer + r * flags_stride + + block * sizeof(int)); + if (thread == 0) { + while (__atomic_load_n(flag_ptr, __ATOMIC_RELAXED) != flag_color) { + } + } + __syncthreads(); + + // Read posted data from the rank's communication buffer. + int32x4_t* recv_buffer = reinterpret_cast( + rank_buffer + data_offset + r * data_stride + block * kTileSize); + + for (int i = 0; i < kAtoms; i++) { + tB[i] = __builtin_nontemporal_load(recv_buffer + thread); + recv_buffer += kAtomStride; + } + } + + for (int r = 1; r < world_size; r++) { + // Wait for the flags to be set. + int* flag_ptr = reinterpret_cast(rank_buffer + r * flags_stride + + block * sizeof(int)); + if (thread == 0) { + while (__atomic_load_n(flag_ptr, __ATOMIC_RELAXED) != flag_color) { + } + } + __syncthreads(); + + // Read posted data from the rank's communication buffer. 
+ int32x4_t* recv_buffer = reinterpret_cast( + rank_buffer + data_offset + r * data_stride + block * kTileSize); + + for (int i = 0; i < kAtoms; i++) { + tA[i] = __builtin_nontemporal_load(recv_buffer + thread); + recv_buffer += kAtomStride; + } + + // Reduce. + for (int i = 0; i < kAtoms; i++) { + int32x4_t& tA_fragment = tA[i]; + int32x4_t& tB_fragment = tB[i]; + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(tB_fragment[0]) + : "v"(tB_fragment[0]), "v"(tA_fragment[0])); + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(tB_fragment[1]) + : "v"(tB_fragment[1]), "v"(tA_fragment[1])); + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(tB_fragment[2]) + : "v"(tB_fragment[2]), "v"(tA_fragment[2])); + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(tB_fragment[3]) + : "v"(tB_fragment[3]), "v"(tA_fragment[3])); + } + } + + __syncthreads(); + if (thread < world_size) { + int r = thread; + int* flag_ptr = reinterpret_cast(rank_buffer + r * flags_stride + + block * sizeof(int)); + __atomic_store_n(flag_ptr, flag_color, __ATOMIC_RELAXED); + } + + // -------------------------------------------------------- + // Write the result to B. + BufferResource dst_buffer(B, N * sizeof(half)); + int dst_offset = block * kTileSize + thread * sizeof(int32x4_t); + + for (int i = 0; i < kAtoms; i++) { + buffer_store_dwordx4(tB[i], dst_buffer.descriptor, dst_offset, 0, 0); + dst_offset += kAtomStride * sizeof(int32x4_t); + } + } +}; + +// ============================================================ +// Twoshot +// ============================================================ +// MARK: FP16 Line Codec +template +struct TwoshotFP16LineCodec { + /* + Default FP16 line codec for Twoshot collectives. + No actual compression is involved. + */ + + static int constexpr kAtoms = 8; + static int constexpr kAtomStride = 256; + static int constexpr kWorldSize = world_size; + + // Codec tile size process by this workgroup. + // Each thread processes atoms of fp16x8_t (16B). 
+ static int constexpr kRankAtoms = kAtoms / kWorldSize; + static int constexpr kRankTileSize = 256 * kRankAtoms * sizeof(int32x4_t); + + // Total tile size for the collective communication. + static int constexpr kTileSize = kRankTileSize * kWorldSize; + + int const thread; + int const rank; + + __device_inline__ TwoshotFP16LineCodec(int thread, int rank) + : thread(thread), rank(rank) { + static_assert(kRankTileSize % 16 == 0, + "kRankTileSize must be 16B aligned."); + } + + __device_inline__ void send(int32x4_t* __restrict__ send_buffer, + int32x4_t const* __restrict__ data) { + for (int i = 0; i < kRankAtoms; i++) { + __builtin_nontemporal_store(data[i], send_buffer + thread); + send_buffer += kAtomStride; + } + } + + __device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, + int32x4_t* __restrict__ data) { + for (int i = 0; i < kRankAtoms; i++) { + data[i] = __builtin_nontemporal_load(*recv_buffer + thread); + *recv_buffer += kAtomStride; + } + } +}; + +// MARK: FP8 Line Codec +// template +// struct TwoshotFP8LineCodec { +// /* +// FP8 Line codec for Twoshot collectives. +// We quantize the FP16 data to block-scaled FP8 in blocks of 32. +// */ + +// static int constexpr kAtoms = 8; +// static int constexpr kAtomStride = 256; +// static int constexpr kWorldSize = world_size; + +// // Codec tile size process by this workgroup. +// // Each threads processes a fragment of fp16x8_t (16B), +// // into a fp8x8_t (8B) and a fp16 scale shared among 32 values. +// static int constexpr kRankAtoms = kAtoms / kWorldSize; +// static int constexpr kRankTileStride = 2176; +// static int constexpr kRankTileScaleOffset = 2048; +// static int constexpr kRankTileSize = kRankTileStride * kRankAtoms; + +// static int constexpr kRankBufferTileStride = +// kRankTileStride / sizeof(int32x4_t); + +// // Total tile size for the collective communication. 
+// static int constexpr kTileSize = kRankTileSize * kWorldSize; + +// // FP8 Maximum value (on AMD Instinct MI300X - float8_e4m3fnuz) +// static float constexpr kFP8Max = 240.0f; +// static int constexpr kScaleFactor = 0x1C441C44; // {1/240.0h, 1/240.0h} +// static int constexpr kScaleEpsilon = 0x00010001; // {1e-7, 1e-7} + +// int const thread; +// int const rank; +// int const group_leader; + +// __device_inline__ TwoshotFP8LineCodec(int thread, int rank) +// : thread(thread), rank(rank), group_leader((threadIdx.x / 8) * 8) { +// static_assert(kRankTileSize % 16 == 0, +// "kRankTileSize must be 16B aligned."); +// set_fp16_ovfl(true); +// } + +// __device_inline__ void send(int32x4_t* __restrict__ send_buffer, +// int32x4_t const* __restrict__ data) { +// for (int k = 0; k < kRankAtoms; k++) { +// int32x4_t const atom = data[k]; + +// // abs(w) +// int32x4_t w; +// { +// half const* x = reinterpret_cast(&atom); +// half* y = reinterpret_cast(&w); +// for (int i = 0; i < 8; i++) { +// y[i] = __habs(x[i]); +// } +// } + +// // max(w) +// int wmax; +// { +// int a, b; +// int* dw = reinterpret_cast(&w); +// asm volatile("v_pk_max_f16 %0, %1, %2" +// : "=v"(a) +// : "v"(dw[0]), "v"(dw[1])); +// asm volatile("v_pk_max_f16 %0, %1, %2" +// : "=v"(b) +// : "v"(dw[2]), "v"(dw[3])); +// asm volatile("v_pk_max_f16 %0, %1, %2" : "=v"(wmax) : "v"(a), +// "v"(b)); + +// // Reduce the max among a group of 8 threads +// // Note: This is basically 2 blocks of 32 values setup as the +// // upper/lower halves of the fp16x2_t +// for (int i = 1; i < 8; i <<= 1) { +// int x = __shfl_down(wmax, i); +// asm volatile("v_pk_max_f16 %0, %1, %2" +// : "=v"(wmax) +// : "v"(wmax), "v"(x)); +// } + +// // Share with the cohort +// wmax = __shfl(wmax, group_leader); +// } + +// // Derive scales +// int decoding_scale; +// int encoding_scale; +// asm volatile("v_pk_mul_f16 %0, %1, %2" +// : "=v"(decoding_scale) +// : "v"(wmax), "v"(kScaleFactor)); +// asm volatile("v_pk_add_f16 %0, %1, %2" +// 
: "=v"(encoding_scale) +// : "v"(decoding_scale), "v"(kScaleEpsilon)); +// encoding_scale = __builtin_bit_cast( +// int, h2rcp(__builtin_bit_cast(half2, encoding_scale))); + +// // Apply scales to get quantized values +// for (int i = 0; i < 4; i++) { +// asm volatile("v_pk_mul_f16 %0, %1, %2" +// : "=v"(w[i]) +// : "v"(atom[i]), "v"(encoding_scale)); +// } + +// // Convert to packed FP8 +// fp32x8_t wf; +// { +// half2 const* x = reinterpret_cast(&w); +// float2* y = reinterpret_cast(&wf); +// for (int i = 0; i < 4; i++) { +// y[i] = __half22float2(x[i]); +// } +// } + +// int32x2_t qw; +// qw[0] = __builtin_amdgcn_cvt_pk_fp8_f32(wf[0], wf[1], qw[0], 0); +// qw[0] = __builtin_amdgcn_cvt_pk_fp8_f32(wf[2], wf[3], qw[0], 1); +// qw[1] = __builtin_amdgcn_cvt_pk_fp8_f32(wf[4], wf[5], qw[1], 0); +// qw[1] = __builtin_amdgcn_cvt_pk_fp8_f32(wf[6], wf[7], qw[1], 1); + +// // Write quantized atom to send_buffer +// // note: only the group leader stores the scale +// uint8_t* atom_ptr = +// reinterpret_cast(send_buffer + k * +// kRankBufferTileStride); +// int32x2_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; +// int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + +// (thread / 8); + +// __builtin_nontemporal_store(qw, qw_ptr); +// if (threadIdx.x == group_leader) { +// __builtin_nontemporal_store(decoding_scale, qs_ptr); +// } +// } +// } + +// __device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, +// int32x4_t* __restrict__ data) { +// for (int k = 0; k < kRankAtoms; k++) { +// // Directly read quantized atom from recv_buffer +// uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); +// int32x2_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; +// int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + +// (thread / 8); + +// int32x2_t qw = __builtin_nontemporal_load(qw_ptr); +// int qs = __builtin_nontemporal_load(qs_ptr); + +// *recv_buffer += kRankBufferTileStride; + +// // Unpack FP8 +// int32x4_t w; +// { +// for (int i = 0; i < 2; 
i++) { +// fp32x2_t wf0 = __builtin_amdgcn_cvt_pk_f32_fp8(qw[i], 0); +// fp32x2_t wf1 = __builtin_amdgcn_cvt_pk_f32_fp8(qw[i], 1); + +// asm volatile("v_cvt_pkrtz_f16_f32 %0, %1, %2" +// : "=v"(w[i * 2 + 0]) +// : "v"(wf0[0]), "v"(wf0[1])); +// asm volatile("v_cvt_pkrtz_f16_f32 %0, %1, %2" +// : "=v"(w[i * 2 + 1]) +// : "v"(wf1[0]), "v"(wf1[1])); +// } +// } + +// // Apply decoding scales +// for (int i = 0; i < 4; i++) { +// asm volatile("v_pk_mul_f16 %0, %1, %2" +// : "=v"(w[i]) +// : "v"(w[i]), "v"(qs)); +// } + +// // That's pretty much it... +// data[k] = w; +// } +// } +// }; + +// MARK: Q4 Line Codec +template +struct TwoshotQ4LineCodec { + /* + Int4-blocking Line codec for Twoshot collectives. + We quantize the FP16 data to block-scaled Int4 in blocks of 32. + */ + + static int constexpr kAtoms = 8; + static int constexpr kAtomStride = 256; + static int constexpr kWorldSize = world_size; + + // Codec tile size process by this workgroup. + // Each threads processes a fragment of fp16x8_t (16B), + // into a int4x8_t (4B) and a fp16 scale shared among 32 values. + static int constexpr kRankAtoms = kAtoms / kWorldSize; + static int constexpr kRankTileStride = 1152; + static int constexpr kRankTileScaleOffset = 1024; + static int constexpr kRankTileSize = kRankTileStride * kRankAtoms; + + static int constexpr kRankBufferTileStride = + kRankTileStride / sizeof(int32x4_t); + + // Total tile size for the collective communication. 
+ static int constexpr kTileSize = kRankTileSize * kWorldSize; + + // Q4 configuration + static int constexpr kScaleFactor = + 0xB000B000; // {-1/8.0h, -1/8.0h}, fp16x2_t + static int constexpr kScaleEpsilon = 0x00010001; // {1e-7, 1e-7}, fp16x2_t + static int constexpr kRangeMin = 0xC800C800; // {-8, -8}, fp16x2_t + static int constexpr kRangeMax = 0x47004700; // {+7, +7}, fp16x2_t + static int constexpr kRangeBias = 0x00080008; // {+8, +8}, int16x2_t + + int const thread; + int const rank; + int const group_leader; + + __device_inline__ TwoshotQ4LineCodec(int thread, int rank) + : thread(thread), rank(rank), group_leader((threadIdx.x / 8) * 8) { + static_assert(kRankTileSize % 16 == 0, + "kRankTileSize must be 16B aligned."); + set_fp16_ovfl(true); + } + + __device_inline__ void send(int32x4_t* __restrict__ send_buffer, + int32x4_t const* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + int32x4_t const atom = data[k]; + + // max(w), min(w) + int wmax, wmin, wblockmax; + { + int a, b; + asm volatile("v_pk_max_f16 %0, %1, %2" + : "=v"(a) + : "v"(atom[0]), "v"(atom[1])); + asm volatile("v_pk_max_f16 %0, %1, %2" + : "=v"(b) + : "v"(atom[2]), "v"(atom[3])); + asm volatile("v_pk_max_f16 %0, %1, %2" : "=v"(wmax) : "v"(a), "v"(b)); + + asm volatile("v_pk_min_f16 %0, %1, %2" + : "=v"(a) + : "v"(atom[0]), "v"(atom[1])); + asm volatile("v_pk_min_f16 %0, %1, %2" + : "=v"(b) + : "v"(atom[2]), "v"(atom[3])); + asm volatile("v_pk_min_f16 %0, %1, %2" : "=v"(wmin) : "v"(a), "v"(b)); + + // Reduce the max among a group of 8 threads + // Note: This is basically 2 blocks of 32 values setup as the + // upper/lower halves of the fp16x2_t + for (int i = 1; i < 8; i <<= 1) { + int x = __shfl_down(wmax, i); + asm volatile("v_pk_max_f16 %0, %1, %2" + : "=v"(wmax) + : "v"(wmax), "v"(x)); + + int y = __shfl_down(wmin, i); + asm volatile("v_pk_min_f16 %0, %1, %2" + : "=v"(wmin) + : "v"(wmin), "v"(y)); + } + + half2 wmaxh2 = __builtin_bit_cast(half2, wmax); + half2 wminh2 = 
__builtin_bit_cast(half2, wmin); + half2 wblockmaxh2; + + wblockmaxh2.x = + __half2float(__habs(wmaxh2.x)) > __half2float(__habs(wminh2.x)) + ? wmaxh2.x + : wminh2.x; + wblockmaxh2.y = + __half2float(__habs(wmaxh2.y)) > __half2float(__habs(wminh2.y)) + ? wmaxh2.y + : wminh2.y; + wblockmax = __builtin_bit_cast(int, wblockmaxh2); + + // Share with the cohort + wblockmax = __shfl(wblockmax, group_leader); + } + + // Derive scales + int decoding_scale; + int encoding_scale; + asm volatile("v_pk_mul_f16 %0, %1, %2" + : "=v"(decoding_scale) + : "v"(wblockmax), "v"(kScaleFactor)); + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(encoding_scale) + : "v"(decoding_scale), "v"(kScaleEpsilon)); + encoding_scale = __builtin_bit_cast( + int, h2rcp(__builtin_bit_cast(half2, encoding_scale))); + + // Apply scales to get quantized values + int32x4_t w; + for (int i = 0; i < 4; i++) { + asm volatile("v_pk_mul_f16 %0, %1, %2" + : "=v"(w[i]) + : "v"(atom[i]), "v"(encoding_scale)); + asm volatile("v_pk_max_f16 %0, %1, %2" + : "=v"(w[i]) + : "v"(w[i]), "v"(kRangeMin)); + asm volatile("v_pk_min_f16 %0, %1, %2" + : "=v"(w[i]) + : "v"(w[i]), "v"(kRangeMax)); + } + + // Convert from fp16x2_t to uint16x2_t + int32x4_t q; + { + int16_t* qi = reinterpret_cast(&q); + half* wh = reinterpret_cast(&w); + for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(__half2float(wh[i])); + + for (int i = 0; i < 4; i++) { + asm volatile("v_pk_add_i16 %0, %1, %2" + : "=v"(q[i]) + : "v"(q[i]), "v"(kRangeBias)); + } + } + + // Pack 8 x q4 into int32_t + int qw = q[0] | (q[1] << 4) | (q[2] << 8) | (q[3] << 12); + + // Write quantized atom to send_buffer + // note: only the group leader stores the scale + uint8_t* atom_ptr = + reinterpret_cast(send_buffer + k * kRankBufferTileStride); + int32_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + + __builtin_nontemporal_store(qw, qw_ptr); + if (threadIdx.x == group_leader) { + 
__builtin_nontemporal_store(decoding_scale, qs_ptr);
+      }
+    }
+  }
+
+  __device_inline__ void recv(int32x4_t** __restrict__ recv_buffer,
+                              int32x4_t* __restrict__ data) {
+    for (int k = 0; k < kRankAtoms; k++) {
+      // Directly read quantized atom from recv_buffer
+      uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(*recv_buffer);
+      int32_t* qw_ptr = reinterpret_cast<int32_t*>(atom_ptr) + thread;
+      int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) +
+                    (thread / 8);
+
+      int32_t qw = __builtin_nontemporal_load(qw_ptr);
+      int qs = __builtin_nontemporal_load(qs_ptr);
+
+      *recv_buffer += kRankBufferTileStride;
+
+      // Unpack q4 into fp16x8_t
+      int32x4_t w;
+      {
+        static uint constexpr kMask000F = 0x000F000F;
+        static uint constexpr kHalf2_1024 =
+            0x64006400;  // {1024.0, 1024.0}, fp16x2_t
+        static uint constexpr kHalf2_1032 =
+            0xE408E408;  // {-1032.0, -1032.0}, fp16x2_t
+
+        for (int i = 0; i < 4; i++) {
+          int32_t q4 = ((qw >> (i * 4)) & kMask000F) | kHalf2_1024;
+          asm volatile("v_pk_add_f16 %0, %1, %2"
+                       : "=v"(w[i])
+                       : "v"(q4), "v"(kHalf2_1032));
+        }
+      }
+
+      // Apply decoding scales
+      for (int i = 0; i < 4; i++) {
+        asm volatile("v_pk_mul_f16 %0, %1, %2"
+                     : "=v"(w[i])
+                     : "v"(w[i]), "v"(qs));
+      }
+
+      // That's pretty much it...
+      data[k] = w;
+    }
+  }
+};
+
+// MARK: Q6 Line Codec
+template <int world_size>
+struct TwoshotQ6LineCodec {
+  /*
+  Int6-blocking Line codec for Twoshot collectives.
+  We quantize the FP16 data to block-scaled Int6 in blocks of 32.
+  */
+
+  static int constexpr kAtoms = 8;
+  static int constexpr kAtomStride = 256;
+  static int constexpr kWorldSize = world_size;
+
+  // Codec tile size processed by this workgroup.
+  // Each thread processes a fragment of fp16x8_t (16B),
+  // into a int6x8_t (4B + 2B) and a fp16 scale shared among 32 values.
+ static int constexpr kRankAtoms = kAtoms / kWorldSize; + static int constexpr kRankTileStride = 1664; + static int constexpr kRankTileQ2Offset = 1024; + static int constexpr kRankTileScaleOffset = 1536; + static int constexpr kRankTileSize = kRankTileStride * kRankAtoms; + + static int constexpr kRankBufferTileStride = + kRankTileStride / sizeof(int32x4_t); + + // Total tile size for the collective communication. + static int constexpr kTileSize = kRankTileSize * kWorldSize; + + // Q4 configuration + static int constexpr kScaleFactor = + 0xA800A800; // {-1/32.0h, -1/32.0h}, fp16x2_t + static int constexpr kScaleEpsilon = 0x00010001; // {1e-7, 1e-7}, fp16x2_t + static int constexpr kRangeMin = 0xD000D000; // {-32, -32}, fp16x2_t + static int constexpr kRangeMax = 0x4FC04FC0; // {+31, +31}, fp16x2_t + static int constexpr kRangeBias = 0x00200020; // {+32, +32}, int16x2_t + + int const thread; + int const rank; + int const group_leader; + + __device_inline__ TwoshotQ6LineCodec(int thread, int rank) + : thread(thread), rank(rank), group_leader((threadIdx.x / 8) * 8) { + static_assert(kRankTileSize % 16 == 0, + "kRankTileSize must be 16B aligned."); + set_fp16_ovfl(true); + } + + __device_inline__ void send(int32x4_t* __restrict__ send_buffer, + int32x4_t const* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + int32x4_t const atom = data[k]; + + // max(w), min(w) + int wmax, wmin, wblockmax; + { + int a, b; + asm volatile("v_pk_max_f16 %0, %1, %2" + : "=v"(a) + : "v"(atom[0]), "v"(atom[1])); + asm volatile("v_pk_max_f16 %0, %1, %2" + : "=v"(b) + : "v"(atom[2]), "v"(atom[3])); + asm volatile("v_pk_max_f16 %0, %1, %2" : "=v"(wmax) : "v"(a), "v"(b)); + + asm volatile("v_pk_min_f16 %0, %1, %2" + : "=v"(a) + : "v"(atom[0]), "v"(atom[1])); + asm volatile("v_pk_min_f16 %0, %1, %2" + : "=v"(b) + : "v"(atom[2]), "v"(atom[3])); + asm volatile("v_pk_min_f16 %0, %1, %2" : "=v"(wmin) : "v"(a), "v"(b)); + + // Reduce the max among a group of 8 threads + // Note: This 
is basically 2 blocks of 32 values setup as the + // upper/lower halves of the fp16x2_t + for (int i = 1; i < 8; i <<= 1) { + int x = __shfl_down(wmax, i); + asm volatile("v_pk_max_f16 %0, %1, %2" + : "=v"(wmax) + : "v"(wmax), "v"(x)); + + int y = __shfl_down(wmin, i); + asm volatile("v_pk_min_f16 %0, %1, %2" + : "=v"(wmin) + : "v"(wmin), "v"(y)); + } + + half2 wmaxh2 = __builtin_bit_cast(half2, wmax); + half2 wminh2 = __builtin_bit_cast(half2, wmin); + half2 wblockmaxh2; + + wblockmaxh2.x = + __half2float(__habs(wmaxh2.x)) > __half2float(__habs(wminh2.x)) + ? wmaxh2.x + : wminh2.x; + wblockmaxh2.y = + __half2float(__habs(wmaxh2.y)) > __half2float(__habs(wminh2.y)) + ? wmaxh2.y + : wminh2.y; + wblockmax = __builtin_bit_cast(int, wblockmaxh2); + + // Share with the cohort + wblockmax = __shfl(wblockmax, group_leader); + } + + // Derive scales + int decoding_scale; + int encoding_scale; + asm volatile("v_pk_mul_f16 %0, %1, %2" + : "=v"(decoding_scale) + : "v"(wblockmax), "v"(kScaleFactor)); + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(encoding_scale) + : "v"(decoding_scale), "v"(kScaleEpsilon)); + encoding_scale = __builtin_bit_cast( + int, h2rcp(__builtin_bit_cast(half2, encoding_scale))); + + // Apply scales to get quantized values + int32x4_t w; + for (int i = 0; i < 4; i++) { + asm volatile("v_pk_mul_f16 %0, %1, %2" + : "=v"(w[i]) + : "v"(atom[i]), "v"(encoding_scale)); + asm volatile("v_pk_max_f16 %0, %1, %2" + : "=v"(w[i]) + : "v"(w[i]), "v"(kRangeMin)); + asm volatile("v_pk_min_f16 %0, %1, %2" + : "=v"(w[i]) + : "v"(w[i]), "v"(kRangeMax)); + } + + // Convert from fp16x2_t to uint16x2_t + int32x4_t q; + { + int16_t* qi = reinterpret_cast(&q); + half* wh = reinterpret_cast(&w); + for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(__half2float(wh[i])); + + for (int i = 0; i < 4; i++) { + asm volatile("v_pk_add_i16 %0, %1, %2" + : "=v"(q[i]) + : "v"(q[i]), "v"(kRangeBias)); + } + } + + // Pack 8 x q6 into int32_t + int16_t + uint32_t q4w; + uint16_t q2w = 0; 
+ q4w = (q[0] & 0x000F000F) | ((q[1] & 0x000F000F) << 4) | + ((q[2] & 0x000F000F) << 8) | ((q[3] & 0x000F000F) << 12); + { + int16_t* tw = reinterpret_cast(&q); +#pragma unroll + for (int i = 0; i < 8; i++) { + q2w |= (tw[i] >> 4) << (i * 2); + } + } + + // Write quantized atom to send_buffer + // note: only the group leader stores the scale + uint8_t* atom_ptr = + reinterpret_cast(send_buffer + k * kRankBufferTileStride); + uint32_t* q4w_ptr = reinterpret_cast(atom_ptr) + thread; + uint16_t* q2w_ptr = + reinterpret_cast(atom_ptr + kRankTileQ2Offset) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + + __builtin_nontemporal_store(q4w, q4w_ptr); + __builtin_nontemporal_store(q2w, q2w_ptr); + if (threadIdx.x == group_leader) { + __builtin_nontemporal_store(decoding_scale, qs_ptr); + } + } + } + + __device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, + int32x4_t* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + // Directly read quantized atom from recv_buffer + uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); + uint32_t* q4w_ptr = reinterpret_cast(atom_ptr) + thread; + uint16_t* q2w_ptr = + reinterpret_cast(atom_ptr + kRankTileQ2Offset) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + + uint32_t q4w = __builtin_nontemporal_load(q4w_ptr); + uint16_t q2w = __builtin_nontemporal_load(q2w_ptr); + int qs = __builtin_nontemporal_load(qs_ptr); + + *recv_buffer += kRankBufferTileStride; + + // Unpack q6 into fp16x8_t + int32x4_t w; + { + static uint constexpr kMask000F = 0x000F000F; + static uint constexpr kHalf2_1024 = + 0x64006400; // {1024.0, 1024.0}, fp16x2_t + static uint constexpr kHalf2_1056 = + 0xE420E420; // {-1056.0, -1056.0}, fp16x2_t + +#pragma unroll + for (int i = 0; i < 4; i++) { + int32_t q4 = q4w & kMask000F; + int32_t q2 = (q2w & 0x3) | ((q2w & 0xC) << 14); + q4w >>= 4; + q2w >>= 4; + + int32_t q6 = q4 | (q2 << 4) | kHalf2_1024; + asm 
volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(w[i]) + : "v"(q6), "v"(kHalf2_1056)); + } + } + + // Apply decoding scales + for (int i = 0; i < 4; i++) { + asm volatile("v_pk_mul_f16 %0, %1, %2" + : "=v"(w[i]) + : "v"(w[i]), "v"(qs)); + } + + // That's pretty much it... + data[k] = w; + } + } +}; + +// MARK: Q8 Line Codec +template +struct TwoshotQ8LineCodec { + /* + Int8-blocking Line codec for Twoshot collectives. + We quantize the FP16 data to block-scaled Int8 in blocks of 32. + */ + + static int constexpr kAtoms = 8; + static int constexpr kAtomStride = 256; + static int constexpr kWorldSize = world_size; + + // Codec tile size process by this workgroup. + // Each threads processes a fragment of fp16x8_t (16B), + // into a int8x8_t (8B) and a fp16 scale shared among 32 values. + static int constexpr kRankAtoms = kAtoms / kWorldSize; + static int constexpr kRankTileStride = 2176; + static int constexpr kRankTileScaleOffset = 2048; + static int constexpr kRankTileSize = kRankTileStride * kRankAtoms; + + static int constexpr kRankBufferTileStride = + kRankTileStride / sizeof(int32x4_t); + + // Total tile size for the collective communication. 
+ static int constexpr kTileSize = kRankTileSize * kWorldSize; + + // Q4 configuration + static int constexpr kScaleFactor = + 0xA000A000; // {-1/128.0h, -1/128.0h}, fp16x2_t + static int constexpr kScaleEpsilon = 0x00010001; // {1e-7, 1e-7}, fp16x2_t + static int constexpr kRangeMin = 0xD800D800; // {-128, -128}, fp16x2_t + static int constexpr kRangeMax = 0x57F057F0; // {+127, +127}, fp16x2_t + static int constexpr kRangeBias = 0x00800080; // {+128, +128}, int16x2_t + + int const thread; + int const rank; + int const group_leader; + + __device_inline__ TwoshotQ8LineCodec(int thread, int rank) + : thread(thread), rank(rank), group_leader((threadIdx.x / 8) * 8) { + static_assert(kRankTileSize % 16 == 0, + "kRankTileSize must be 16B aligned."); + set_fp16_ovfl(true); + } + + __device_inline__ void send(int32x4_t* __restrict__ send_buffer, + int32x4_t const* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + int32x4_t const atom = data[k]; + + // max(w), min(w) + int wmax, wmin, wblockmax; + { + int a, b; + asm volatile("v_pk_max_f16 %0, %1, %2" + : "=v"(a) + : "v"(atom[0]), "v"(atom[1])); + asm volatile("v_pk_max_f16 %0, %1, %2" + : "=v"(b) + : "v"(atom[2]), "v"(atom[3])); + asm volatile("v_pk_max_f16 %0, %1, %2" : "=v"(wmax) : "v"(a), "v"(b)); + + asm volatile("v_pk_min_f16 %0, %1, %2" + : "=v"(a) + : "v"(atom[0]), "v"(atom[1])); + asm volatile("v_pk_min_f16 %0, %1, %2" + : "=v"(b) + : "v"(atom[2]), "v"(atom[3])); + asm volatile("v_pk_min_f16 %0, %1, %2" : "=v"(wmin) : "v"(a), "v"(b)); + + // Reduce the max among a group of 8 threads + // Note: This is basically 2 blocks of 32 values setup as the + // upper/lower halves of the fp16x2_t + for (int i = 1; i < 8; i <<= 1) { + int x = __shfl_down(wmax, i); + asm volatile("v_pk_max_f16 %0, %1, %2" + : "=v"(wmax) + : "v"(wmax), "v"(x)); + + int y = __shfl_down(wmin, i); + asm volatile("v_pk_min_f16 %0, %1, %2" + : "=v"(wmin) + : "v"(wmin), "v"(y)); + } + + half2 wmaxh2 = __builtin_bit_cast(half2, wmax); + 
half2 wminh2 = __builtin_bit_cast(half2, wmin); + half2 wblockmaxh2; + + wblockmaxh2.x = + __half2float(__habs(wmaxh2.x)) > __half2float(__habs(wminh2.x)) + ? wmaxh2.x + : wminh2.x; + wblockmaxh2.y = + __half2float(__habs(wmaxh2.y)) > __half2float(__habs(wminh2.y)) + ? wmaxh2.y + : wminh2.y; + wblockmax = __builtin_bit_cast(int, wblockmaxh2); + + // Share with the cohort + wblockmax = __shfl(wblockmax, group_leader); + } + + // Derive scales + int decoding_scale; + int encoding_scale; + asm volatile("v_pk_mul_f16 %0, %1, %2" + : "=v"(decoding_scale) + : "v"(wblockmax), "v"(kScaleFactor)); + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(encoding_scale) + : "v"(decoding_scale), "v"(kScaleEpsilon)); + encoding_scale = __builtin_bit_cast( + int, h2rcp(__builtin_bit_cast(half2, encoding_scale))); + + // Apply scales to get quantized values + int32x4_t w; + for (int i = 0; i < 4; i++) { + asm volatile("v_pk_mul_f16 %0, %1, %2" + : "=v"(w[i]) + : "v"(atom[i]), "v"(encoding_scale)); + asm volatile("v_pk_max_f16 %0, %1, %2" + : "=v"(w[i]) + : "v"(w[i]), "v"(kRangeMin)); + asm volatile("v_pk_min_f16 %0, %1, %2" + : "=v"(w[i]) + : "v"(w[i]), "v"(kRangeMax)); + } + + // Convert from fp16x2_t to uint16x2_t + int32x4_t q; + { + int16_t* qi = reinterpret_cast(&q); + half* wh = reinterpret_cast(&w); + for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(__half2float(wh[i])); + + for (int i = 0; i < 4; i++) { + asm volatile("v_pk_add_i16 %0, %1, %2" + : "=v"(q[i]) + : "v"(q[i]), "v"(kRangeBias)); + } + } + + // Pack 8 x q8 into int32x2_t + int32x2_t qw; + qw[0] = q[0] | (q[1] << 8); + qw[1] = q[2] | (q[3] << 8); + + // Write quantized atom to send_buffer + // note: only the group leader stores the scale + uint8_t* atom_ptr = + reinterpret_cast(send_buffer + k * kRankBufferTileStride); + int32x2_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + + __builtin_nontemporal_store(qw, qw_ptr); + if 
(threadIdx.x == group_leader) { + __builtin_nontemporal_store(decoding_scale, qs_ptr); + } + } + } + + __device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, + int32x4_t* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + // Directly read quantized atom from recv_buffer + uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); + int32x2_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + + int32x2_t qw = __builtin_nontemporal_load(qw_ptr); + int qs = __builtin_nontemporal_load(qs_ptr); + + *recv_buffer += kRankBufferTileStride; + + // Unpack q8 into fp16x8_t + int32x4_t w; + { + static uint constexpr kMask00FF = 0x00FF00FF; + static uint constexpr kHalf2_1024 = + 0x64006400; // {1024.0, 1024.0}, fp16x2_t + static uint constexpr kHalf2_1152 = + 0xE480E480; // {-1152.0, -1152.0}, fp16x2_t + +#pragma unroll + for (int i = 0; i < 4; i++) { + int32_t q8 = ((qw[i / 2] >> ((i % 2) * 8)) & kMask00FF) | kHalf2_1024; + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(w[i]) + : "v"(q8), "v"(kHalf2_1152)); + } + } + + // Apply decoding scales + for (int i = 0; i < 4; i++) { + asm volatile("v_pk_mul_f16 %0, %1, %2" + : "=v"(w[i]) + : "v"(w[i]), "v"(qs)); + } + + // That's pretty much it... + data[k] = w; + } + } +}; + +template +struct TwoshotMaxMinQ8LineCodec { + /* + Int8-blocking Line codec for Twoshot collectives. + We quantize the FP16 data to block-scaled Int8 in blocks of 32. + */ + + static int constexpr kAtoms = 8; + static int constexpr kAtomStride = 256; + static int constexpr kWorldSize = world_size; + + // Codec tile size process by this workgroup. + // Each thread processes a fragment of fp16x8_t (16B), + // into a int8x8_t (8B) and a fp16 zero and a fp16 scale shared among 32 + // values. 
+ static int constexpr kRankAtoms = kAtoms / kWorldSize; + // 2048 + 128 + 128 + static int constexpr kRankTileStride = 2304; + static int constexpr kRankTileScaleOffset = 2048; + static int constexpr kRankTileZeroOffset = 2176; + static int constexpr kRankTileSize = kRankTileStride * kRankAtoms; + + static int constexpr kRankBufferTileStride = + kRankTileStride / sizeof(int32x4_t); + + // Total tile size for the collective communication. + static int constexpr kTileSize = kRankTileSize * kWorldSize; + + // static int constexpr kScaleFactor = 0x1C001C00; // {1/256.0h, 1/256.0h}, + // fp16x2_t + static int constexpr kScaleFactor = + 0x1C041C04; // {1/255.0h, 1/255.0h}, fp16x2_t + static int constexpr kScaleEpsilon = 0x00010001; // {1e-7, 1e-7}, fp16x2_t + static int constexpr kRangeMin = 0x00000000; // {0, 0}, fp16x2_t + static int constexpr kRangeMax = 0x5BF85BF8; // {+255, +255}, fp16x2_t + static int constexpr kRangeBias = 0x00800080; // {+128, +128}, int16x2_t + + int const thread; + int const rank; + int const group_leader; + + __device_inline__ TwoshotMaxMinQ8LineCodec(int thread, int rank) + : thread(thread), rank(rank), group_leader((threadIdx.x / 8) * 8) { + static_assert(kRankTileSize % 16 == 0, + "kRankTileSize must be 16B aligned."); + set_fp16_ovfl(true); + } + + __device_inline__ void send(int32x4_t* __restrict__ send_buffer, + int32x4_t const* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + int32x4_t const atom = data[k]; + // max(w), min(w) + int wmax, wmin, wblockmax, wblockmin; + { + int a, b; + asm volatile("v_pk_max_f16 %0, %1, %2" + : "=v"(a) + : "v"(atom[0]), "v"(atom[1])); + asm volatile("v_pk_max_f16 %0, %1, %2" + : "=v"(b) + : "v"(atom[2]), "v"(atom[3])); + asm volatile("v_pk_max_f16 %0, %1, %2" : "=v"(wmax) : "v"(a), "v"(b)); + + asm volatile("v_pk_min_f16 %0, %1, %2" + : "=v"(a) + : "v"(atom[0]), "v"(atom[1])); + asm volatile("v_pk_min_f16 %0, %1, %2" + : "=v"(b) + : "v"(atom[2]), "v"(atom[3])); + asm volatile("v_pk_min_f16 
%0, %1, %2" : "=v"(wmin) : "v"(a), "v"(b));
+
+        // Reduce the max among a group of 8 threads
+        // Note: This is basically 2 blocks of 32 values setup as the
+        // upper/lower halves of the fp16x2_t
+        for (int i = 1; i < 8; i <<= 1) {
+          int x = __shfl_down(wmax, i);
+          asm volatile("v_pk_max_f16 %0, %1, %2"
+                       : "=v"(wmax)
+                       : "v"(wmax), "v"(x));
+
+          int y = __shfl_down(wmin, i);
+          asm volatile("v_pk_min_f16 %0, %1, %2"
+                       : "=v"(wmin)
+                       : "v"(wmin), "v"(y));
+        }
+
+        // Share with the cohort
+        wblockmax = __shfl(wmax, group_leader);
+        wblockmin = __shfl(wmin, group_leader);
+      }
+
+      // Derive zeros and scales
+      int decoding_zero = wblockmin;
+      int decoding_scale;
+      int encoding_scale;
+      static int constexpr kNegOne = 0xBC00BC00;  // {-1, -1}, fp16x2_t
+
+      // MI300 lacks packed fp16 sub instruction. So we do -1 * min + max
+      // note: VOP3P operands are comma-separated; the original string was
+      // missing the comma before the third source operand.
+      asm volatile("v_pk_fma_f16 %0, %1, %2, %3"
+                   : "=v"(decoding_scale)
+                   : "v"(kNegOne), "v"(decoding_zero), "v"(wblockmax));
+      asm volatile("v_pk_mul_f16 %0, %1, %2"
+                   : "=v"(decoding_scale)
+                   : "v"(decoding_scale), "v"(kScaleFactor));
+      asm volatile("v_pk_add_f16 %0, %1, %2"
+                   : "=v"(encoding_scale)
+                   : "v"(decoding_scale), "v"(kScaleEpsilon));
+      encoding_scale = __builtin_bit_cast(
+          int, h2rcp(__builtin_bit_cast(half2, encoding_scale)));
+
+      // Apply scales to get quantized values
+      int32x4_t w;
+      for (int i = 0; i < 4; i++) {
+        asm volatile("v_pk_fma_f16 %0, %1, %2, %3"
+                     : "=v"(w[i])
+                     : "v"(kNegOne), "v"(decoding_zero), "v"(atom[i]));
+        asm volatile("v_pk_mul_f16 %0, %1, %2"
+                     : "=v"(w[i])
+                     : "v"(w[i]), "v"(encoding_scale));
+        asm volatile("v_pk_max_f16 %0, %1, %2"
+                     : "=v"(w[i])
+                     : "v"(w[i]), "v"(kRangeMin));
+        asm volatile("v_pk_min_f16 %0, %1, %2"
+                     : "=v"(w[i])
+                     : "v"(w[i]), "v"(kRangeMax));
+      }
+
+      // Convert from fp16x8_t to uint8x8_t and pack into int32x2_t
+      int32x2_t qw;
+      {
+        unsigned char* qi = reinterpret_cast<unsigned char*>(&qw);
+        half* wh = reinterpret_cast<half*>(&w);
+        for (int i = 0; i < 8; i++) qi[i] = (unsigned char)__half2float(wh[i]);
+      }
+
// Write quantized atom to send_buffer + // note: only the group leader stores the scale + uint8_t* atom_ptr = + reinterpret_cast(send_buffer + k * kRankBufferTileStride); + int32x2_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + int* qz_ptr = + reinterpret_cast(atom_ptr + kRankTileZeroOffset) + (thread / 8); + + __builtin_nontemporal_store(qw, qw_ptr); + if (threadIdx.x == group_leader) { + __builtin_nontemporal_store(decoding_scale, qs_ptr); + __builtin_nontemporal_store(decoding_zero, qz_ptr); + } + } + } + + __device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, + int32x4_t* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + // Directly read quantized atom from recv_buffer + uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); + int32x2_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + int* qz_ptr = + reinterpret_cast(atom_ptr + kRankTileZeroOffset) + (thread / 8); + + int32x2_t qw = __builtin_nontemporal_load(qw_ptr); + int qs = __builtin_nontemporal_load(qs_ptr); + int qz = __builtin_nontemporal_load(qz_ptr); + + *recv_buffer += kRankBufferTileStride; + + // Unpack uint8x8_t into fp16x8_t + int32x4_t w; + { + half* wh = reinterpret_cast(&w); + unsigned char* qi = reinterpret_cast(&qw); +#pragma unroll + for (int i = 0; i < 8; i++) { + wh[i] = __float2half((float)qi[i]); + } + } + + // Apply decoding scales and zeros + for (int i = 0; i < 4; i++) { + asm volatile("v_pk_mul_f16 %0, %1, %2" + : "=v"(w[i]) + : "v"(w[i]), "v"(qs)); + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(w[i]) + : "v"(w[i]), "v"(qz)); + } + + data[k] = w; + } + } +}; + +// MARK: Twoshot All Reduce +template +struct AllReduceTwoshot { + // Fixed magic implementation. + // We will use a workgroup of 256 threads (standard kBlock) across 8 atoms of + // work. 
+ static int constexpr kAtoms = 8; + + // Size and atom stride of source/destination data that the workgroup will + // process. + static int constexpr kTileSize = 256 * kAtoms * sizeof(int32x4_t); + static int constexpr kAtomStride = 256; + + static int constexpr kWorldSize = LineCodec::kWorldSize; + + __device__ static void run( + half const* __restrict__ A, // input + half* __restrict__ B, // output + int const N, // number of elements + int const block, // block index + int const num_blocks, // number of blocks + int const world_size, // unused - only kept around for API consistency + int const rank, // rank index + uint8_t** __restrict__ buffer_list, // communication buffers + long const data_offset, // offset to start of the data buffer + int flag_color) { + // Topology + int thread = threadIdx.x + threadIdx.y * kWavefront; + uint8_t* rank_buffer = buffer_list[rank]; + LineCodec codec(thread, rank); + + // -------------------------------------------------------- + // Read A into registers + int32x4_t tA[kAtoms]; + + BufferResource src_buffer(const_cast(A), N * sizeof(half)); + int src_offset = block * kTileSize + thread * sizeof(int32x4_t); + int32x4_t* src = reinterpret_cast(const_cast(A)); + + for (int i = 0; i < kAtoms; i++) { + tA[i] = buffer_load_dwordx4(src_buffer.descriptor, src_offset, 0, 0); + src_offset += kAtomStride * sizeof(int32x4_t); + } + + // -------------------------------------------------------- + // Phase-1A: Write segment data into the communication buffer of the target + // rank responsible for this segment. 
+ long comm_data0_offset = data_offset + block * LineCodec::kTileSize; + long comm_data1_offset = + num_blocks * LineCodec::kTileSize + comm_data0_offset; + + long comm_flags0_offset = block * (kWorldSize * sizeof(int)); + long comm_flags1_offset = + num_blocks * (kWorldSize * sizeof(int)) + comm_flags0_offset; + + for (int r = 0; r < kWorldSize; r++) { + int32x4_t* send_buffer = reinterpret_cast( + buffer_list[r] + comm_data0_offset + rank * LineCodec::kRankTileSize); + codec.send(send_buffer, &tA[r * LineCodec::kRankAtoms]); + } + + __syncthreads(); + if (thread < kWorldSize) { + int r = thread; + int* flag_ptr = reinterpret_cast( + buffer_list[r] + comm_flags0_offset + rank * sizeof(int)); + __atomic_store_n(flag_ptr, flag_color, __ATOMIC_RELEASE); + } + // -------------------------------------------------------- + // Phase-1B: Reduce the segment data from the communication buffers. + int32x4_t tR[LineCodec::kRankAtoms] = {}; + { + // Read the data from the communication buffer. + int32x4_t* recv_buffer = + reinterpret_cast(rank_buffer + comm_data0_offset); + int* flag_ptr = reinterpret_cast(rank_buffer + comm_flags0_offset); + + for (int r = 0; r < kWorldSize; r++) { + // Wait for the flags to be set. 
+ if (thread == 0) { + while (__atomic_load_n(&flag_ptr[r], __ATOMIC_RELAXED) != + flag_color) { + } + } + __syncthreads(); + + // note: we reuse tA as temp buffer here + codec.recv(&recv_buffer, tA); + + for (int i = 0; i < LineCodec::kRankAtoms; i++) { + int32x4_t& tA_fragment = tA[i]; + int32x4_t& tR_fragment = tR[i]; + + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(tR_fragment[0]) + : "v"(tR_fragment[0]), "v"(tA_fragment[0])); + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(tR_fragment[1]) + : "v"(tR_fragment[1]), "v"(tA_fragment[1])); + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(tR_fragment[2]) + : "v"(tR_fragment[2]), "v"(tA_fragment[2])); + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(tR_fragment[3]) + : "v"(tR_fragment[3]), "v"(tA_fragment[3])); + } + } + } + + // -------------------------------------------------------- + // Phase-2: Write the reduced segment to every other rank + // This is basically an all-gather. + for (int r = 0; r < kWorldSize; r++) { + int32x4_t* send_buffer = reinterpret_cast( + buffer_list[r] + comm_data1_offset + rank * LineCodec::kRankTileSize); + codec.send(send_buffer, tR); + } + + __syncthreads(); + if (thread < kWorldSize) { + int r = thread; + int* flag_ptr = reinterpret_cast( + buffer_list[r] + comm_flags1_offset + rank * sizeof(int)); + __atomic_store_n(flag_ptr, flag_color, __ATOMIC_RELEASE); + } + + // -------------------------------------------------------- + // Phase-2: Read the gather segments from the rank's communication buffer. + { + // Read the data from the communication buffer. + int32x4_t* recv_buffer = + reinterpret_cast(rank_buffer + comm_data1_offset); + int* flag_ptr = reinterpret_cast(rank_buffer + comm_flags1_offset); + + for (int r = 0; r < kWorldSize; r++) { + // Wait for the flags to be set. + if (thread == 0) { + while (__atomic_load_n(&flag_ptr[r], __ATOMIC_RELAXED) != + flag_color) { + } + } + __syncthreads(); + + // Gather all reduced and final rank segments into tA. 
+ codec.recv(&recv_buffer, &tA[r * LineCodec::kRankAtoms]); + } + } + + // -------------------------------------------------------- + // Write the result to B. + BufferResource dst_buffer(B, N * sizeof(half)); + int dst_offset = block * kTileSize + thread * sizeof(int32x4_t); + int32x4_t* dst = reinterpret_cast(B); + + for (int i = 0; i < kAtoms; i++) { + buffer_store_dwordx4(tA[i], dst_buffer.descriptor, dst_offset, 0, 0); + dst_offset += kAtomStride * sizeof(int32x4_t); + } + } +}; diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 1a1896b4c1ee..4683804b7de1 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -725,6 +725,22 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { custom_ar.impl("open_mem_handle", torch::kCPU, &open_mem_handle); custom_ar.def("free_shared_buffer", &free_shared_buffer); +#ifdef USE_ROCM + custom_ar.def( + "qr_all_reduce(int fa, Tensor inp, Tensor out, int algo_int) -> ()"); + custom_ar.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce); + + custom_ar.def("init_custom_qr(int rank, int world_size) -> int"); + custom_ar.def("qr_destroy", &qr_destroy); + + custom_ar.def("qr_get_handle(int fa) -> Tensor"); + custom_ar.impl("qr_get_handle", torch::kCPU, &qr_get_handle); + + custom_ar.def("qr_open_handles(int _fa, Tensor[](b!) 
handles) -> ()"); + custom_ar.impl("qr_open_handles", torch::kCPU, &qr_open_handles); + + custom_ar.def("qr_max_size", &qr_max_size); +#endif } REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index ff992c33b309..6ef4624b2dcd 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1758,6 +1758,32 @@ def free_shared_buffer(ptr: int) -> None: torch.ops._C_custom_ar.free_shared_buffer(ptr) +# quick all reduce +def init_custom_qr(rank: int, world_size: int) -> int: + return torch.ops._C_custom_ar.init_custom_qr(rank, world_size) + + +def qr_destroy(fa: int) -> None: + torch.ops._C_custom_ar.qr_destroy(fa) + + +def qr_all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor, + algo_int: int) -> None: + torch.ops._C_custom_ar.qr_all_reduce(fa, inp, out, algo_int) + + +def qr_get_handle(fa: int) -> torch.Tensor: + return torch.ops._C_custom_ar.qr_get_handle(fa) + + +def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None: + return torch.ops._C_custom_ar.qr_open_handles(fa, handles) + + +def qr_max_size() -> int: + return torch.ops._C_custom_ar.qr_max_size() + + def get_flash_mla_metadata( cache_seqlens: torch.Tensor, num_heads_per_head_k: int, diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 055d91690e67..013803e46a7b 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os from typing import Optional import torch @@ -41,6 +42,8 @@ def __init__(self, CustomAllreduce) from vllm.distributed.device_communicators.pynccl import ( PyNcclCommunicator) + from vllm.distributed.device_communicators.quick_all_reduce import ( + QuickAllReduce, QuickReduceAlgo) self.pynccl_comm: Optional[PyNcclCommunicator] = None if 
use_pynccl and self.world_size > 1:
@@ -56,6 +59,16 @@ def __init__(self,
                 group=self.cpu_group,
                 device=self.device,
             )
+        self.use_quick_allreduce = os.environ.get("VLLM_USE_QUICK_ALLREDUCE",
+                                                  "0") == "1"
+        if self.use_quick_allreduce and self.world_size > 1:
+            # Initialize a custom fast all-reduce implementation.
+            qr_comm_algo = os.environ.get("VLLM_QUICK_ALLREDUCE_ALGO",
+                                          "TwoShot")
+            self.qr_comm_algo = QuickReduceAlgo[qr_comm_algo]
+            self.qr_comm = QuickAllReduce(group=self.cpu_group,
+                                          device=self.device,
+                                          algo=self.qr_comm_algo)
@@ -81,6 +94,13 @@ def __init__(self,
     def all_reduce(self, input_):
         # always try custom allreduce first,
         # and then pynccl.
+        qr_comm = self.qr_comm
+        if qr_comm is not None and not qr_comm.disabled and \
+                qr_comm.should_quick_allreduce(input_):
+            out = qr_comm.all_reduce(input_)
+            assert out is not None
+            return out
+
         ca_comm = self.ca_comm
         if ca_comm is not None and not ca_comm.disabled and \
                 ca_comm.should_custom_ar(input_):
diff --git a/vllm/distributed/device_communicators/quick_all_reduce.py b/vllm/distributed/device_communicators/quick_all_reduce.py
new file mode 100644
index 000000000000..576e9d202f75
--- /dev/null
+++ b/vllm/distributed/device_communicators/quick_all_reduce.py
@@ -0,0 +1,128 @@
+# SPDX-License-Identifier: Apache-2.0
+import logging
+# Enum lives in the `enum` module, not `typing`; the original
+# `from typing import Enum, Union` raises ImportError at module load.
+from enum import Enum
+from typing import Union
+
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from vllm import _custom_ops as ops
+
+logger = logging.getLogger(__name__)
+
+try:
+    ops.meta_size()
+    ops_available = True
+except Exception:
+    # For CPUs
+    ops_available = False
+
+
+class QuickReduceAlgo(Enum):
+    OneShot = 0
+    TwoShot = 1
+    TwoShot_FP8 = 2
+    TwoShot_Q8 = 3
+    TwoShot_Q6 = 4
+    TwoShot_Q4 = 5
+    TwoShot_MAX_MIN_Q8 = 6
+
+
+class QuickAllReduce:
+    _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
+
+    def __init__(self,
+                 group: ProcessGroup,
+                 device: Union[int, str, torch.device],
algo: QuickReduceAlgo = QuickReduceAlgo.TwoShot) -> None: + self.disabled = True + if not ops_available: + # disable because of missing custom allreduce library + # e.g. in a non-cuda environment + return + self.max_size = ops.qr_max_size() + self.group = group + self.algo = algo + + assert dist.get_backend(group) != dist.Backend.NCCL, ( + "QuickReduce should be attached to a non-NCCL group.") + + rank = dist.get_rank(group=self.group) + world_size = dist.get_world_size(group=self.group) + if world_size == 1: + # No need to initialize custom allreduce for single GPU case. + return + + if world_size not in QuickAllReduce._SUPPORTED_WORLD_SIZES: + logger.warning( + "QuickReduce allreduce is disabled due to an unsupported world" + " size: %d. Supported world sizes: %s. To silence this " + "warning, specify disable_custom_all_reduce=True explicitly.", + world_size, str(QuickAllReduce._SUPPORTED_WORLD_SIZES)) + return + + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + # now `device` is a `torch.device` object + assert isinstance(device, torch.device) + self.device = device + torch.cuda.set_device(self.device) + + self.disabled = False + self._ptr = ops.init_custom_qr(rank, world_size) + self.create_shared_buffer() + + def create_shared_buffer(self): + """ + Creates a shared buffer for quickreduce. + Has to be called after qr_init_device_collectives + """ + handle = ops.qr_get_handle(self._ptr) + world_size = dist.get_world_size(group=self.group) + handles = [None] * world_size + dist.all_gather_object(handles, handle, group=self.group) + ops.qr_open_handles(self._ptr, handles) + + def all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None): + """Performs an out-of-place all reduce. + + If registered is True, this assumes inp's pointer is already + IPC-registered. Otherwise, inp is first copied into a pre-registered + buffer. 
+ """ + inp_size = inp.numel() * inp.element_size() + if inp_size >= self.max_size: + return None + + if out is None: + out = torch.empty_like(inp) + + ops.qr_all_reduce(self._ptr, inp, out, self.algo.value) + return out + + def is_enabled(self): + return not self.disabled + + @staticmethod + def is_available(): + return ops_available + + def close(self): + if not self.disabled and getattr(self, "_ptr", None): + ops.qr_destroy(self._ptr) + self._ptr = 0 + + def __del__(self): + self.close() + + def should_quick_allreduce(self, inp: torch.Tensor): + if self.disabled: + return False + inp_size = inp.numel() * inp.element_size() + # custom allreduce requires input byte size to be multiples of 16 + if inp_size % 16 != 0: + return False + return inp.dtype == torch.float16 and inp_size < self.max_size From 8660eea8901ad39099ddebf6c94051fada20615c Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Mon, 5 May 2025 14:15:01 +0000 Subject: [PATCH 02/28] WIP Signed-off-by: ilmarkov --- csrc/ops.h | 2 +- csrc/torch_bindings.cpp | 5 ++--- vllm/distributed/device_communicators/cuda_communicator.py | 1 + vllm/distributed/device_communicators/quick_all_reduce.py | 3 ++- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index 3f540b88cb9a..2d92f9b2fcb6 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -362,7 +362,7 @@ int64_t open_mem_handle(torch::Tensor& mem_handle); void free_shared_buffer(int64_t buffer); #ifdef USE_ROCM -fptr_t qr_init_device_collectives(int64_t rank, int64_t world_size); +fptr_t init_custom_qr(int64_t rank, int64_t world_size); void qr_destroy(fptr_t _fa); torch::Tensor qr_get_handle(fptr_t _fa); void qr_open_handles(fptr_t _fa, const std::vector& handles); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 4683804b7de1..98f51c9cd88e 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -730,11 +730,10 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { "qr_all_reduce(int 
fa, Tensor inp, Tensor out, int algo_int) -> ()"); custom_ar.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce); - custom_ar.def("init_custom_qr(int rank, int world_size) -> int"); + custom_ar.def("init_custom_qr", &init_custom_qr); custom_ar.def("qr_destroy", &qr_destroy); - custom_ar.def("qr_get_handle(int fa) -> Tensor"); - custom_ar.impl("qr_get_handle", torch::kCPU, &qr_get_handle); + custom_ar.def("qr_get_handle", &qr_get_handle); custom_ar.def("qr_open_handles(int _fa, Tensor[](b!) handles) -> ()"); custom_ar.impl("qr_open_handles", torch::kCPU, &qr_open_handles); diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 013803e46a7b..26294187dbb8 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -61,6 +61,7 @@ def __init__(self, ) self.use_quick_allreduce = os.environ.get("VLLM_USE_QUICK_ALLREDUCE", "0") == "1" + self.qr_comm: Optional[QuickAllReduce] = None if self.use_quick_allreduce and self.world_size > 1: # Initialize a custom fast all-reduce implementation. 
qr_comm_algo = os.environ.get("VLLM_QUICK_ALLREDUCE_ALGO", diff --git a/vllm/distributed/device_communicators/quick_all_reduce.py b/vllm/distributed/device_communicators/quick_all_reduce.py index 576e9d202f75..40300c797d48 100644 --- a/vllm/distributed/device_communicators/quick_all_reduce.py +++ b/vllm/distributed/device_communicators/quick_all_reduce.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import logging -from typing import Enum, Union +from enum import Enum +from typing import Union import torch import torch.distributed as dist From 92ab95045add851810453dc957dc8257067ad9d1 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Tue, 13 May 2025 13:26:06 +0000 Subject: [PATCH 03/28] Add bf16 support Signed-off-by: ilmarkov --- CMakeLists.txt | 10 +- csrc/custom_quickreduce.cu | 47 +- csrc/quickreduce/base.h | 213 +++++ csrc/quickreduce/quick_reduce.cu | 167 ---- csrc/quickreduce/quick_reduce.h | 175 +++- csrc/quickreduce/quick_reduce_impl.cuh | 749 ++++++------------ .../device_communicators/cuda_communicator.py | 16 +- .../device_communicators/quick_all_reduce.py | 46 +- 8 files changed, 671 insertions(+), 752 deletions(-) delete mode 100644 csrc/quickreduce/quick_reduce.cu diff --git a/CMakeLists.txt b/CMakeLists.txt index a352de1d69bc..64092f004f78 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -253,8 +253,6 @@ set(VLLM_EXT_SRC "csrc/cuda_utils_kernels.cu" "csrc/prepare_inputs/advance_step.cu" "csrc/custom_all_reduce.cu" - "csrc/custom_quickreduce.cu" - "csrc/quickreduce/quick_reduce.cu" "csrc/torch_bindings.cpp") if(VLLM_GPU_LANG STREQUAL "CUDA") @@ -640,6 +638,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # if CUDA endif endif() +if (VLLM_GPU_LANG STREQUAL "HIP") + # Add QuickReduce kernels + list(APPEND VLLM_EXT_SRC + "csrc/custom_quickreduce.cu" + ) +# if ROCM endif +endif() + message(STATUS "Enabling C extension.") define_gpu_extension_target( _C diff --git a/csrc/custom_quickreduce.cu b/csrc/custom_quickreduce.cu index 8d1c437a3652..e3b7a6f69d7b 
100644 --- a/csrc/custom_quickreduce.cu +++ b/csrc/custom_quickreduce.cu @@ -7,38 +7,39 @@ #include "quickreduce/quick_reduce.h" -fptr_t init_custom_qr(int64_t rank, int64_t world_size) { +quickreduce::fptr_t init_custom_qr(int64_t rank, int64_t world_size) { if (world_size > 8) throw std::invalid_argument("world size > 8 is not supported"); if (world_size % 2 != 0) throw std::invalid_argument("Odd num gpus is not supported for now"); if (rank < 0 || rank >= world_size) throw std::invalid_argument("invalid rank passed in"); - DeviceComms* fptr = new DeviceComms(); + quickreduce::DeviceComms* fptr = new quickreduce::DeviceComms(); fptr->init(world_size, rank); - return (fptr_t)fptr; + return (quickreduce::fptr_t)fptr; } -void qr_destroy(fptr_t _fa) { - auto fa = reinterpret_cast(_fa); +void qr_destroy(quickreduce::fptr_t _fa) { + auto fa = reinterpret_cast(_fa); fa->destroy(); delete fa; } -torch::Tensor qr_get_handle(fptr_t _fa) { - auto fa = reinterpret_cast(_fa); +torch::Tensor qr_get_handle(quickreduce::fptr_t _fa) { + auto fa = reinterpret_cast(_fa); hipIpcMemHandle_t handle = fa->get_handle(); auto device_index = c10::cuda::current_device(); auto options = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); auto data_handle = - torch::empty({static_cast(sizeof(cudaIpcMemHandle_t))}, options); + torch::empty({static_cast(sizeof(hipIpcMemHandle_t))}, options); std::memcpy(data_handle.data_ptr(), &handle, sizeof(hipIpcMemHandle_t)); return data_handle; } -void qr_open_handles(fptr_t _fa, const std::vector& handles) { - auto fa = reinterpret_cast(_fa); +void qr_open_handles(quickreduce::fptr_t _fa, + const std::vector& handles) { + auto fa = reinterpret_cast(_fa); std::vector ipc_handles; ipc_handles.reserve(handles.size()); for (auto& handle : handles) { @@ -50,24 +51,32 @@ void qr_open_handles(fptr_t _fa, const std::vector& handles) { fa->open_ipc_handles(ipc_handles); } -void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, - 
int64_t algo_int) { - auto fa = reinterpret_cast(_fa); +void qr_all_reduce(quickreduce::fptr_t _fa, torch::Tensor& inp, + torch::Tensor& out, int64_t algo_int) { + auto fa = reinterpret_cast(_fa); const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); auto stream = at::cuda::getCurrentHIPStreamMasqueradingAsCUDA(); TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); - TORCH_CHECK_EQ(inp.scalar_type(), at::ScalarType::Half) - << "QR only supports half precision for now."; TORCH_CHECK_EQ(inp.numel(), out.numel()); - auto algo = static_cast(algo_int); - fa->allreduce(algo_int, stream, reinterpret_cast(inp.data_ptr()), - reinterpret_cast(out.data_ptr()), out.numel()); + auto algo = static_cast(algo_int); + if (out.scalar_type() == at::ScalarType::Half) { + fa->allreduce(algo_int, stream, + reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), out.numel()); + } else if (out.scalar_type() == at::ScalarType::BFloat16) { + fa->allreduce( + algo_int, stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), out.numel()); + } else { + throw std::runtime_error( + "quick allreduce only supports float16 and bfloat16"); + } } int64_t qr_max_size() { - return static_cast(DeviceComms::kMaxProblemSize); + return static_cast(quickreduce::DeviceComms::kMaxProblemSize); } #endif // USE_ROCM \ No newline at end of file diff --git a/csrc/quickreduce/base.h b/csrc/quickreduce/base.h index df6a10b0a9e6..3cfb7ae2cbb8 100644 --- a/csrc/quickreduce/base.h +++ b/csrc/quickreduce/base.h @@ -3,6 +3,11 @@ #include #include #include +#include + +typedef __hip_bfloat16 nv_bfloat16; +typedef __hip_bfloat162 nv_bfloat162; +namespace quickreduce { #define __device_inline__ __device__ __forceinline__ #define __quickreduce_launch_bounds__ __launch_bounds__(256, 4) @@ -39,6 +44,8 @@ using fp32x4_t = __attribute__((__vector_size__(4 * sizeof(float)))) float; using fp32x8_t = __attribute__((__vector_size__(8 * sizeof(float)))) float; using fp32x16_t = 
__attribute__((__vector_size__(16 * sizeof(float)))) float; +static int constexpr kNegOne = 0xBC00BC00; // {-1, -1}, fp16x2_t + // Standard CDNA wavefront size. static int constexpr kWavefront = 64; @@ -90,3 +97,209 @@ __device_inline__ static void set_fp16_ovfl(bool const value) { } #endif } + +template +__device_inline__ void packed_assign_add(int32x4_t* A, int32x4_t* B); + +template <> +__device_inline__ void packed_assign_add(int32x4_t* A, int32x4_t* B) { + int32x4_t& tR_fragment = A[0]; + int32x4_t& tA_fragment = B[0]; + + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(tR_fragment[0]) + : "v"(tR_fragment[0]), "v"(tA_fragment[0])); + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(tR_fragment[1]) + : "v"(tR_fragment[1]), "v"(tA_fragment[1])); + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(tR_fragment[2]) + : "v"(tR_fragment[2]), "v"(tA_fragment[2])); + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(tR_fragment[3]) + : "v"(tR_fragment[3]), "v"(tA_fragment[3])); +} + +template <> +__device_inline__ void packed_assign_add(int32x4_t* A, + int32x4_t* B) { + nv_bfloat162* tA = reinterpret_cast(A); + nv_bfloat162* tB = reinterpret_cast(B); +#pragma unroll + for (int i = 0; i < 4; i++) { + tA[i] = __hadd2(tA[i], tB[i]); + } +} + +template +__device_inline__ int packed_max(int a, int b); + +template <> +__device_inline__ int packed_max(int a, int b) { + int result; + asm volatile("v_pk_max_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b)); + return result; +} + +template <> +__device_inline__ int packed_max(int a, int b) { + nv_bfloat162* tA = reinterpret_cast(&a); + nv_bfloat162* tB = reinterpret_cast(&b); + nv_bfloat162 tR = __hmax2(*tA, *tB); + return *(reinterpret_cast(&tR)); +} + +template +__device_inline__ int packed_min(int a, int b); + +template <> +__device_inline__ int packed_min(int a, int b) { + int result; + asm volatile("v_pk_min_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b)); + return result; +} + +template <> +__device_inline__ int packed_min(int 
a, int b) { + nv_bfloat162* tA = reinterpret_cast(&a); + nv_bfloat162* tB = reinterpret_cast(&b); + nv_bfloat162 tR = __hmin2(*tA, *tB); + return *(reinterpret_cast(&tR)); +} + +template +__device_inline__ int packed_abs_max(int a, int b); + +template <> +__device_inline__ int packed_abs_max(int a, int b) { + half2 wmaxh2 = __builtin_bit_cast(half2, a); + half2 wminh2 = __builtin_bit_cast(half2, b); + half2 wblockmaxh2; + + wblockmaxh2.x = + __hgt(__habs(wmaxh2.x), __habs(wminh2.x)) ? wmaxh2.x : wminh2.x; + wblockmaxh2.y = + __hgt(__habs(wmaxh2.y), __habs(wminh2.y)) ? wmaxh2.y : wminh2.y; + return __builtin_bit_cast(int, wblockmaxh2); +} + +template <> +__device_inline__ int packed_abs_max(int a, int b) { + nv_bfloat162 wmaxh2 = *(reinterpret_cast(&a)); + nv_bfloat162 wminh2 = *(reinterpret_cast(&b)); + nv_bfloat162 wblockmaxh2; + wblockmaxh2.x = + __hgt(__habs(wmaxh2.x), __habs(wminh2.x)) ? wmaxh2.x : wminh2.x; + wblockmaxh2.y = + __hgt(__habs(wmaxh2.y), __habs(wminh2.y)) ? wmaxh2.y : wminh2.y; + + return *(reinterpret_cast(&wblockmaxh2)); +} + +template +__device_inline__ int packed_add(int a, int b); + +template <> +__device_inline__ int packed_add(int a, int b) { + int result; + asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b)); + return result; +} + +template <> +__device_inline__ int packed_add(int a, int b) { + nv_bfloat162* tA = reinterpret_cast(&a); + nv_bfloat162* tB = reinterpret_cast(&b); + nv_bfloat162 tR = __hadd2(*tA, *tB); + return *(reinterpret_cast(&tR)); +} + +template <> +__device_inline__ int packed_add(int a, int b) { + int result; + asm volatile("v_pk_add_i16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b)); + return result; +} + +template +__device_inline__ int packed_sub(int a, int b); + +template <> +__device_inline__ int packed_sub(int a, int b) { + int result; + + // MI300 lacks packed fp16 sub instruction. 
So we do -1 * min + max + asm volatile("v_pk_fma_f16 %0, %1, %2 %3" + : "=v"(result) + : "v"(kNegOne), "v"(b), "v"(a)); + return result; +} + +template <> +__device_inline__ int packed_sub(int a, int b) { + nv_bfloat162* tA = reinterpret_cast(&a); + nv_bfloat162* tB = reinterpret_cast(&b); + nv_bfloat162 tR = __hsub2(*tA, *tB); + return *(reinterpret_cast(&tR)); +} + +template +__device_inline__ int packed_mul(int a, int b); + +template <> +__device_inline__ int packed_mul(int a, int b) { + int result; + asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b)); + return result; +} + +template <> +__device_inline__ int packed_mul(int a, int b) { + nv_bfloat162* tA = reinterpret_cast(&a); + nv_bfloat162* tB = reinterpret_cast(&b); + nv_bfloat162 tR = __hmul2(*tA, *tB); + return *(reinterpret_cast(&tR)); +} + +template +__device_inline__ int packed_rcp(int a); + +template <> +__device_inline__ int packed_rcp(int a) { + return __builtin_bit_cast(int, h2rcp(__builtin_bit_cast(half2, a))); +} + +template <> +__device_inline__ int packed_rcp(int a) { + nv_bfloat162* tA = reinterpret_cast(&a); + nv_bfloat162 tR = h2rcp(*tA); + return *(reinterpret_cast(&tR)); +} + +template +__device_inline__ T float2T_cast(float a); + +template <> +__device_inline__ half float2T_cast(float a) { + return __float2half(a); +} + +template <> +__device_inline__ nv_bfloat16 float2T_cast(float a) { + return __float2bfloat16(a); +} + +template +__device_inline__ float T2float_cast(T a); + +template <> +__device_inline__ float T2float_cast(half a) { + return __half2float(a); +} + +template <> +__device_inline__ float T2float_cast(nv_bfloat16 a) { + return __bfloat162float(a); +} + +} // namespace quickreduce \ No newline at end of file diff --git a/csrc/quickreduce/quick_reduce.cu b/csrc/quickreduce/quick_reduce.cu deleted file mode 100644 index 610f8dc1056c..000000000000 --- a/csrc/quickreduce/quick_reduce.cu +++ /dev/null @@ -1,167 +0,0 @@ -#ifdef USE_ROCM - - #include - - 
#include "quick_reduce_impl.cuh" - #include "quick_reduce.h" - -void DeviceComms::init(int world_size, int rank) { - destroy(); - this->world_size = world_size; - this->rank = rank; - - // Allocate buffer size for worst case: Twoshot FP16 2-stage buffer. - long flags_buffer_size = 2 * world_size * kMaxTiles * sizeof(int); - long data_buffer_size = 2 * kMaxProblemSize; - long total_buffer_size = flags_buffer_size + data_buffer_size; - data_offset = flags_buffer_size; - HIP_CHECK(hipExtMallocWithFlags((void**)&dbuffer, total_buffer_size, - hipDeviceMallocUncached)); - - // Clear the flags buffer. - HIP_CHECK(hipMemset(dbuffer, 0, flags_buffer_size)); - - // Device-side list of IPC buffers. - buffer_list.resize(world_size); - HIP_CHECK(hipMalloc(&dbuffer_list, world_size * sizeof(uint8_t*))); - - // Create IPC handles for rank's communication buffer. - all_buffer_ipc_handles.resize(world_size); - HIP_CHECK(hipIpcGetMemHandle(&buffer_ipc_handle, dbuffer)); - - initialized = true; -} - -void DeviceComms::destroy() { - if (initialized) { - for (int i = 0; i < world_size; i++) { - if (i != rank) { - HIP_CHECK(hipIpcCloseMemHandle(dbuffer_list[i])); - } - } - - HIP_CHECK(hipFree(dbuffer)); - HIP_CHECK(hipFree(dbuffer_list)); - - initialized = false; - } -} - -void DeviceComms::open_ipc_handles( - std::vector const& ipc_handles) { - assert(ipc_handles.size() == all_buffer_ipc_handles.size()); - for (int i = 0; i < world_size; i++) { - all_buffer_ipc_handles[i] = ipc_handles[i]; - } - - // Open device memory access to the IPC communication buffers. - // Note: For our own rank, we do not need to open a handle. 
- for (int i = 0; i < world_size; i++) { - if (i != rank) { - HIP_CHECK(hipIpcOpenMemHandle((void**)&buffer_list[i], - all_buffer_ipc_handles[i], - hipIpcMemLazyEnablePeerAccess)); - } else { - buffer_list[i] = dbuffer; - } - } - - HIP_CHECK(hipMemcpy(dbuffer_list, buffer_list.data(), - world_size * sizeof(uint8_t*), hipMemcpyHostToDevice)); -} - -// ============================================================ -// KERNEL -// ============================================================ -template -__global__ __quickreduce_launch_bounds__ static void allreduce_prototype( - half const* A, half* B, int N, int num_blocks, int world_size, int rank, - uint8_t** dbuffer_list, long data_offset, int flag_color) { - int block = blockIdx.x; - int grid = gridDim.x; - - while (block < num_blocks) { - AllReduceKenel::run(A, B, N, block, num_blocks, world_size, rank, - dbuffer_list, data_offset, flag_color); - block += grid; - } -} - - // ============================================================ - // DISPATCH - // ============================================================ - #define TWOSHOT_DISPATCH(__codec) \ - if (world_size == 2) { \ - using LineCodec = __codec<2>; \ - using AllReduceKernel = AllReduceTwoshot; \ - hipLaunchKernelGGL((allreduce_prototype), dim3(grid), \ - dim3(kBlock), 0, stream, A, B, N, num_blocks, \ - world_size, rank, dbuffer_list, data_offset, \ - flag_color); \ - } else if (world_size == 4) { \ - using LineCodec = __codec<4>; \ - using AllReduceKernel = AllReduceTwoshot; \ - hipLaunchKernelGGL((allreduce_prototype), dim3(grid), \ - dim3(kBlock), 0, stream, A, B, N, num_blocks, \ - world_size, rank, dbuffer_list, data_offset, \ - flag_color); \ - } else if (world_size == 8) { \ - using LineCodec = __codec<8>; \ - using AllReduceKernel = AllReduceTwoshot; \ - hipLaunchKernelGGL((allreduce_prototype), dim3(grid), \ - dim3(kBlock), 0, stream, A, B, N, num_blocks, \ - world_size, rank, dbuffer_list, data_offset, \ - flag_color); \ - } - -void 
DeviceComms::allreduce(int profile, hipStream_t stream, half const* A, - half* B, int N) { - if (world_size != 2 && world_size != 4 && world_size != 8) { - throw std::runtime_error("All Reduce not supported for world_size = " + - std::to_string(world_size)); - } - - // Configuration. - long msg_size = N * sizeof(half); - unsigned long num_blocks = divceil(msg_size, kTileSize); - unsigned long grid = min(304 * 4, num_blocks); - // ------------------------------------------------- - // All reduce dispatch. - QuickReduceProfile dprofile = static_cast(profile); - - switch (dprofile) { - case QuickReduceProfile::ONESHOT_FP16: - using AllReduceKernel = AllReduceOneshot; - hipLaunchKernelGGL((allreduce_prototype), dim3(grid), - dim3(kBlock), 0, stream, A, B, N, num_blocks, - world_size, rank, dbuffer_list, data_offset, - flag_color); - break; - case QuickReduceProfile::TWOSHOT_FP8: - throw std::runtime_error("FP8 is not supported"); - // TWOSHOT_DISPATCH(TwoshotFP8LineCodec) - break; - case QuickReduceProfile::TWOSHOT_Q8: - TWOSHOT_DISPATCH(TwoshotQ8LineCodec) - break; - case QuickReduceProfile::TWOSHOT_MAX_MIN_Q8: - TWOSHOT_DISPATCH(TwoshotMaxMinQ8LineCodec) - break; - case QuickReduceProfile::TWOSHOT_Q6: - TWOSHOT_DISPATCH(TwoshotQ6LineCodec) - break; - case QuickReduceProfile::TWOSHOT_Q4: - TWOSHOT_DISPATCH(TwoshotQ4LineCodec) - break; - default: - TWOSHOT_DISPATCH(TwoshotFP16LineCodec) - break; - } - HIP_CHECK(cudaGetLastError()); - - // ------------------------------------------------- - // Rotate the flag color. 
- flag_color++; -} - -#endif // USE_ROCM \ No newline at end of file diff --git a/csrc/quickreduce/quick_reduce.h b/csrc/quickreduce/quick_reduce.h index 9636533e0abc..6f24f139d705 100644 --- a/csrc/quickreduce/quick_reduce.h +++ b/csrc/quickreduce/quick_reduce.h @@ -2,7 +2,7 @@ #include #include -#include +#include "quick_reduce_impl.cuh" #define HIP_CHECK(err) \ do { \ @@ -14,19 +14,63 @@ } \ } while (0) +namespace quickreduce { using fptr_t = int64_t; static_assert(sizeof(void*) == sizeof(fptr_t)); -enum QuickReduceProfile { +enum QuickReduceAlgo { ONESHOT_FP16 = 0, TWOSHOT_FP16 = 1, - TWOSHOT_FP8 = 2, - TWOSHOT_Q8 = 3, - TWOSHOT_Q6 = 4, - TWOSHOT_Q4 = 5, - TWOSHOT_MAX_MIN_Q8 = 6, + TWOSHOT_Q8 = 2, + TWOSHOT_Q6 = 3, + TWOSHOT_Q4 = 4, + TWOSHOT_MAX_MIN_Q8 = 5, }; +// ============================================================ +// KERNEL +// ============================================================ +template +__global__ __quickreduce_launch_bounds__ static void allreduce_prototype( + T const* A, T* B, int N, int num_blocks, int world_size, int rank, + uint8_t** dbuffer_list, long data_offset, int flag_color) { + int block = blockIdx.x; + int grid = gridDim.x; + + while (block < num_blocks) { + AllReduceKernel::run(A, B, N, block, num_blocks, world_size, rank, + dbuffer_list, data_offset, flag_color); + block += grid; + } +} + +// ============================================================ +// DISPATCH +// ============================================================ +#define TWOSHOT_DISPATCH(__codec) \ + if (world_size == 2) { \ + using LineCodec = __codec; \ + using AllReduceKernel = AllReduceTwoshot; \ + hipLaunchKernelGGL((allreduce_prototype), dim3(grid), \ + dim3(kBlock), 0, stream, A, B, N, num_blocks, \ + world_size, rank, dbuffer_list, data_offset, \ + flag_color); \ + } else if (world_size == 4) { \ + using LineCodec = __codec; \ + using AllReduceKernel = AllReduceTwoshot; \ + hipLaunchKernelGGL((allreduce_prototype), dim3(grid), \ + dim3(kBlock), 0, 
stream, A, B, N, num_blocks, \ + world_size, rank, dbuffer_list, data_offset, \ + flag_color); \ + } else if (world_size == 8) { \ + using LineCodec = __codec; \ + using AllReduceKernel = AllReduceTwoshot; \ + hipLaunchKernelGGL((allreduce_prototype), dim3(grid), \ + dim3(kBlock), 0, stream, A, B, N, num_blocks, \ + world_size, rank, dbuffer_list, data_offset, \ + flag_color); \ + } + /* =============================================================== Desc: @@ -59,14 +103,119 @@ struct DeviceComms { DeviceComms() : initialized(false), world_size(1), rank(0) {} ~DeviceComms() { destroy(); } - void init(int world_size, int rank); + void init(int world_size, int rank) { + destroy(); + this->world_size = world_size; + this->rank = rank; + + // Allocate buffer size for worst case: Twoshot FP16 2-stage buffer. + long flags_buffer_size = 2 * world_size * kMaxTiles * sizeof(int); + long data_buffer_size = 2 * kMaxProblemSize; + long total_buffer_size = flags_buffer_size + data_buffer_size; + data_offset = flags_buffer_size; + HIP_CHECK(hipExtMallocWithFlags((void**)&dbuffer, total_buffer_size, + hipDeviceMallocUncached)); + + // Clear the flags buffer. + HIP_CHECK(hipMemset(dbuffer, 0, flags_buffer_size)); + + // Device-side list of IPC buffers. + buffer_list.resize(world_size); + HIP_CHECK(hipMalloc(&dbuffer_list, world_size * sizeof(uint8_t*))); + + // Create IPC handles for rank's communication buffer. 
+ all_buffer_ipc_handles.resize(world_size); + HIP_CHECK(hipIpcGetMemHandle(&buffer_ipc_handle, dbuffer)); + + initialized = true; + } int get_world_size() { return world_size; } int get_rank() { return rank; } bool status() { return initialized; } - void destroy(); - hipIpcMemHandle_t const get_handle() { return buffer_ipc_handle; } - void open_ipc_handles(std::vector const& ipc_handles); - void allreduce(int profile, hipStream_t stream, half const* A, half* B, - int N); + + void destroy() { + if (initialized) { + for (int i = 0; i < world_size; i++) { + if (i != rank) { + HIP_CHECK(hipIpcCloseMemHandle(dbuffer_list[i])); + } + } + + HIP_CHECK(hipFree(dbuffer)); + HIP_CHECK(hipFree(dbuffer_list)); + + initialized = false; + } + } + + void open_ipc_handles(std::vector const& ipc_handles) { + assert(ipc_handles.size() == all_buffer_ipc_handles.size()); + for (int i = 0; i < world_size; i++) { + all_buffer_ipc_handles[i] = ipc_handles[i]; + } + + // Open device memory access to the IPC communication buffers. + // Note: For our own rank, we do not need to open a handle. + for (int i = 0; i < world_size; i++) { + if (i != rank) { + HIP_CHECK(hipIpcOpenMemHandle((void**)&buffer_list[i], + all_buffer_ipc_handles[i], + hipIpcMemLazyEnablePeerAccess)); + } else { + buffer_list[i] = dbuffer; + } + } + + HIP_CHECK(hipMemcpy(dbuffer_list, buffer_list.data(), + world_size * sizeof(uint8_t*), hipMemcpyHostToDevice)); + } + + template + void allreduce(int profile, hipStream_t stream, T const* A, T* B, int N) { + if (world_size != 2 && world_size != 4 && world_size != 8) { + throw std::runtime_error("All Reduce not supported for world_size = " + + std::to_string(world_size)); + } + + // Configuration. + long msg_size = N * sizeof(T); + unsigned long num_blocks = divceil(msg_size, kTileSize); + unsigned long grid = min(304 * 4, num_blocks); + // ------------------------------------------------- + // All reduce dispatch. 
+ QuickReduceAlgo algo = static_cast(profile); + + switch (algo) { + case QuickReduceAlgo::ONESHOT_FP16: + using AllReduceKernel = AllReduceOneshot; + hipLaunchKernelGGL((allreduce_prototype), + dim3(grid), dim3(kBlock), 0, stream, A, B, N, + num_blocks, world_size, rank, dbuffer_list, + data_offset, flag_color); + break; + case QuickReduceAlgo::TWOSHOT_Q8: + TWOSHOT_DISPATCH(TwoshotQ8LineCodec) + break; + case QuickReduceAlgo::TWOSHOT_MAX_MIN_Q8: + TWOSHOT_DISPATCH(TwoshotMaxMinQ8LineCodec) + break; + case QuickReduceAlgo::TWOSHOT_Q6: + TWOSHOT_DISPATCH(TwoshotQ6LineCodec) + break; + case QuickReduceAlgo::TWOSHOT_Q4: + TWOSHOT_DISPATCH(TwoshotQ4LineCodec) + break; + default: + TWOSHOT_DISPATCH(TwoshotFP16LineCodec) + break; + } + HIP_CHECK(cudaGetLastError()); + + // ------------------------------------------------- + // Rotate the flag color. + flag_color++; + } }; + +} // namespace quickreduce \ No newline at end of file diff --git a/csrc/quickreduce/quick_reduce_impl.cuh b/csrc/quickreduce/quick_reduce_impl.cuh index f29082d1e9e7..b9715dc40823 100644 --- a/csrc/quickreduce/quick_reduce_impl.cuh +++ b/csrc/quickreduce/quick_reduce_impl.cuh @@ -3,14 +3,17 @@ #include #include "base.h" +namespace quickreduce { // ============================================================ // Oneshot // ============================================================ // MARK: Oneshot All Reduce +template struct AllReduceOneshot { // Fixed magic implementation. // We will use a workgroup of 256 threads (standard kBlock) across 8 atoms of // work. + static_assert(sizeof(T) == 2); static int constexpr kAtoms = 8; // Size and atom stride of data that the workgroup will process. 
@@ -18,8 +21,8 @@ struct AllReduceOneshot { static int constexpr kAtomStride = 256; __device__ static void run( - half const* __restrict__ A, // input - half* __restrict__ B, // output + T const* __restrict__ A, // input + T* __restrict__ B, // output int const N, // number of elements int const block, // this block's index int const num_blocks, // total number of blocks @@ -41,7 +44,8 @@ struct AllReduceOneshot { // Read A into registers int32x4_t tA[kAtoms]; - BufferResource src_buffer(const_cast(A), N * sizeof(half)); + BufferResource src_buffer(const_cast(A), N * sizeof(T)); + int src_offset = block * kTileSize + thread * sizeof(int32x4_t); for (int i = 0; i < kAtoms; i++) { @@ -130,20 +134,7 @@ struct AllReduceOneshot { // Reduce. for (int i = 0; i < kAtoms; i++) { - int32x4_t& tA_fragment = tA[i]; - int32x4_t& tB_fragment = tB[i]; - asm volatile("v_pk_add_f16 %0, %1, %2" - : "=v"(tB_fragment[0]) - : "v"(tB_fragment[0]), "v"(tA_fragment[0])); - asm volatile("v_pk_add_f16 %0, %1, %2" - : "=v"(tB_fragment[1]) - : "v"(tB_fragment[1]), "v"(tA_fragment[1])); - asm volatile("v_pk_add_f16 %0, %1, %2" - : "=v"(tB_fragment[2]) - : "v"(tB_fragment[2]), "v"(tA_fragment[2])); - asm volatile("v_pk_add_f16 %0, %1, %2" - : "=v"(tB_fragment[3]) - : "v"(tB_fragment[3]), "v"(tA_fragment[3])); + packed_assign_add(&tB[i], &tA[i]); } } @@ -157,7 +148,7 @@ struct AllReduceOneshot { // -------------------------------------------------------- // Write the result to B. - BufferResource dst_buffer(B, N * sizeof(half)); + BufferResource dst_buffer(B, N * sizeof(T)); int dst_offset = block * kTileSize + thread * sizeof(int32x4_t); for (int i = 0; i < kAtoms; i++) { @@ -171,7 +162,7 @@ struct AllReduceOneshot { // Twoshot // ============================================================ // MARK: FP16 Line Codec -template +template struct TwoshotFP16LineCodec { /* Default FP16 line codec for Twoshot collectives. 
@@ -216,187 +207,8 @@ struct TwoshotFP16LineCodec { } }; -// MARK: FP8 Line Codec -// template -// struct TwoshotFP8LineCodec { -// /* -// FP8 Line codec for Twoshot collectives. -// We quantize the FP16 data to block-scaled FP8 in blocks of 32. -// */ - -// static int constexpr kAtoms = 8; -// static int constexpr kAtomStride = 256; -// static int constexpr kWorldSize = world_size; - -// // Codec tile size process by this workgroup. -// // Each threads processes a fragment of fp16x8_t (16B), -// // into a fp8x8_t (8B) and a fp16 scale shared among 32 values. -// static int constexpr kRankAtoms = kAtoms / kWorldSize; -// static int constexpr kRankTileStride = 2176; -// static int constexpr kRankTileScaleOffset = 2048; -// static int constexpr kRankTileSize = kRankTileStride * kRankAtoms; - -// static int constexpr kRankBufferTileStride = -// kRankTileStride / sizeof(int32x4_t); - -// // Total tile size for the collective communication. -// static int constexpr kTileSize = kRankTileSize * kWorldSize; - -// // FP8 Maximum value (on AMD Instinct MI300X - float8_e4m3fnuz) -// static float constexpr kFP8Max = 240.0f; -// static int constexpr kScaleFactor = 0x1C441C44; // {1/240.0h, 1/240.0h} -// static int constexpr kScaleEpsilon = 0x00010001; // {1e-7, 1e-7} - -// int const thread; -// int const rank; -// int const group_leader; - -// __device_inline__ TwoshotFP8LineCodec(int thread, int rank) -// : thread(thread), rank(rank), group_leader((threadIdx.x / 8) * 8) { -// static_assert(kRankTileSize % 16 == 0, -// "kRankTileSize must be 16B aligned."); -// set_fp16_ovfl(true); -// } - -// __device_inline__ void send(int32x4_t* __restrict__ send_buffer, -// int32x4_t const* __restrict__ data) { -// for (int k = 0; k < kRankAtoms; k++) { -// int32x4_t const atom = data[k]; - -// // abs(w) -// int32x4_t w; -// { -// half const* x = reinterpret_cast(&atom); -// half* y = reinterpret_cast(&w); -// for (int i = 0; i < 8; i++) { -// y[i] = __habs(x[i]); -// } -// } - -// // 
max(w) -// int wmax; -// { -// int a, b; -// int* dw = reinterpret_cast(&w); -// asm volatile("v_pk_max_f16 %0, %1, %2" -// : "=v"(a) -// : "v"(dw[0]), "v"(dw[1])); -// asm volatile("v_pk_max_f16 %0, %1, %2" -// : "=v"(b) -// : "v"(dw[2]), "v"(dw[3])); -// asm volatile("v_pk_max_f16 %0, %1, %2" : "=v"(wmax) : "v"(a), -// "v"(b)); - -// // Reduce the max among a group of 8 threads -// // Note: This is basically 2 blocks of 32 values setup as the -// // upper/lower halves of the fp16x2_t -// for (int i = 1; i < 8; i <<= 1) { -// int x = __shfl_down(wmax, i); -// asm volatile("v_pk_max_f16 %0, %1, %2" -// : "=v"(wmax) -// : "v"(wmax), "v"(x)); -// } - -// // Share with the cohort -// wmax = __shfl(wmax, group_leader); -// } - -// // Derive scales -// int decoding_scale; -// int encoding_scale; -// asm volatile("v_pk_mul_f16 %0, %1, %2" -// : "=v"(decoding_scale) -// : "v"(wmax), "v"(kScaleFactor)); -// asm volatile("v_pk_add_f16 %0, %1, %2" -// : "=v"(encoding_scale) -// : "v"(decoding_scale), "v"(kScaleEpsilon)); -// encoding_scale = __builtin_bit_cast( -// int, h2rcp(__builtin_bit_cast(half2, encoding_scale))); - -// // Apply scales to get quantized values -// for (int i = 0; i < 4; i++) { -// asm volatile("v_pk_mul_f16 %0, %1, %2" -// : "=v"(w[i]) -// : "v"(atom[i]), "v"(encoding_scale)); -// } - -// // Convert to packed FP8 -// fp32x8_t wf; -// { -// half2 const* x = reinterpret_cast(&w); -// float2* y = reinterpret_cast(&wf); -// for (int i = 0; i < 4; i++) { -// y[i] = __half22float2(x[i]); -// } -// } - -// int32x2_t qw; -// qw[0] = __builtin_amdgcn_cvt_pk_fp8_f32(wf[0], wf[1], qw[0], 0); -// qw[0] = __builtin_amdgcn_cvt_pk_fp8_f32(wf[2], wf[3], qw[0], 1); -// qw[1] = __builtin_amdgcn_cvt_pk_fp8_f32(wf[4], wf[5], qw[1], 0); -// qw[1] = __builtin_amdgcn_cvt_pk_fp8_f32(wf[6], wf[7], qw[1], 1); - -// // Write quantized atom to send_buffer -// // note: only the group leader stores the scale -// uint8_t* atom_ptr = -// reinterpret_cast(send_buffer + k * -// 
kRankBufferTileStride); -// int32x2_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; -// int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + -// (thread / 8); - -// __builtin_nontemporal_store(qw, qw_ptr); -// if (threadIdx.x == group_leader) { -// __builtin_nontemporal_store(decoding_scale, qs_ptr); -// } -// } -// } - -// __device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, -// int32x4_t* __restrict__ data) { -// for (int k = 0; k < kRankAtoms; k++) { -// // Directly read quantized atom from recv_buffer -// uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); -// int32x2_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; -// int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + -// (thread / 8); - -// int32x2_t qw = __builtin_nontemporal_load(qw_ptr); -// int qs = __builtin_nontemporal_load(qs_ptr); - -// *recv_buffer += kRankBufferTileStride; - -// // Unpack FP8 -// int32x4_t w; -// { -// for (int i = 0; i < 2; i++) { -// fp32x2_t wf0 = __builtin_amdgcn_cvt_pk_f32_fp8(qw[i], 0); -// fp32x2_t wf1 = __builtin_amdgcn_cvt_pk_f32_fp8(qw[i], 1); - -// asm volatile("v_cvt_pkrtz_f16_f32 %0, %1, %2" -// : "=v"(w[i * 2 + 0]) -// : "v"(wf0[0]), "v"(wf0[1])); -// asm volatile("v_cvt_pkrtz_f16_f32 %0, %1, %2" -// : "=v"(w[i * 2 + 1]) -// : "v"(wf1[0]), "v"(wf1[1])); -// } -// } - -// // Apply decoding scales -// for (int i = 0; i < 4; i++) { -// asm volatile("v_pk_mul_f16 %0, %1, %2" -// : "=v"(w[i]) -// : "v"(w[i]), "v"(qs)); -// } - -// // That's pretty much it... -// data[k] = w; -// } -// } -// }; - // MARK: Q4 Line Codec -template +template struct TwoshotQ4LineCodec { /* Int4-blocking Line codec for Twoshot collectives. @@ -421,13 +233,26 @@ struct TwoshotQ4LineCodec { // Total tile size for the collective communication. 
static int constexpr kTileSize = kRankTileSize * kWorldSize; - // Q4 configuration + // Constants configuration + + // {-1/8.0h, -1/8.0h}, f16x2_t static int constexpr kScaleFactor = - 0xB000B000; // {-1/8.0h, -1/8.0h}, fp16x2_t - static int constexpr kScaleEpsilon = 0x00010001; // {1e-7, 1e-7}, fp16x2_t - static int constexpr kRangeMin = 0xC800C800; // {-8, -8}, fp16x2_t - static int constexpr kRangeMax = 0x47004700; // {+7, +7}, fp16x2_t - static int constexpr kRangeBias = 0x00080008; // {+8, +8}, int16x2_t + std::is_same::value ? 0xB000B000 : 0xBE00BE00; + + // {1e-7, 1e-7}, f16x2_t + static int constexpr kScaleEpsilon = + std::is_same::value ? 0x00010001 : 0x33D733D7; + + // {-8, -8}, f16x2_t + static int constexpr kRangeMin = + std::is_same::value ? 0xC800C800 : 0xC100C100; + + // {+7, +7}, f16x2_t + static int constexpr kRangeMax = + std::is_same::value ? 0x47004700 : 0x40E040E0; + + // {+8, +8}, int16x2_t + static int constexpr kRangeBias = 0x00080008; int const thread; int const rank; @@ -449,50 +274,25 @@ struct TwoshotQ4LineCodec { int wmax, wmin, wblockmax; { int a, b; - asm volatile("v_pk_max_f16 %0, %1, %2" - : "=v"(a) - : "v"(atom[0]), "v"(atom[1])); - asm volatile("v_pk_max_f16 %0, %1, %2" - : "=v"(b) - : "v"(atom[2]), "v"(atom[3])); - asm volatile("v_pk_max_f16 %0, %1, %2" : "=v"(wmax) : "v"(a), "v"(b)); - - asm volatile("v_pk_min_f16 %0, %1, %2" - : "=v"(a) - : "v"(atom[0]), "v"(atom[1])); - asm volatile("v_pk_min_f16 %0, %1, %2" - : "=v"(b) - : "v"(atom[2]), "v"(atom[3])); - asm volatile("v_pk_min_f16 %0, %1, %2" : "=v"(wmin) : "v"(a), "v"(b)); + a = packed_max(atom[0], atom[1]); + b = packed_max(atom[2], atom[3]); + wmax = packed_max(a, b); + + a = packed_min(atom[0], atom[1]); + b = packed_min(atom[2], atom[3]); + wmin = packed_min(a, b); // Reduce the max among a group of 8 threads // Note: This is basically 2 blocks of 32 values setup as the // upper/lower halves of the fp16x2_t for (int i = 1; i < 8; i <<= 1) { int x = __shfl_down(wmax, i); - 
asm volatile("v_pk_max_f16 %0, %1, %2" - : "=v"(wmax) - : "v"(wmax), "v"(x)); + wmax = packed_max(wmax, x); int y = __shfl_down(wmin, i); - asm volatile("v_pk_min_f16 %0, %1, %2" - : "=v"(wmin) - : "v"(wmin), "v"(y)); + wmin = packed_min(wmin, y); } - - half2 wmaxh2 = __builtin_bit_cast(half2, wmax); - half2 wminh2 = __builtin_bit_cast(half2, wmin); - half2 wblockmaxh2; - - wblockmaxh2.x = - __half2float(__habs(wmaxh2.x)) > __half2float(__habs(wminh2.x)) - ? wmaxh2.x - : wminh2.x; - wblockmaxh2.y = - __half2float(__habs(wmaxh2.y)) > __half2float(__habs(wminh2.y)) - ? wmaxh2.y - : wminh2.y; - wblockmax = __builtin_bit_cast(int, wblockmaxh2); + wblockmax = packed_abs_max(wmax, wmin); // Share with the cohort wblockmax = __shfl(wblockmax, group_leader); @@ -501,40 +301,28 @@ struct TwoshotQ4LineCodec { // Derive scales int decoding_scale; int encoding_scale; - asm volatile("v_pk_mul_f16 %0, %1, %2" - : "=v"(decoding_scale) - : "v"(wblockmax), "v"(kScaleFactor)); - asm volatile("v_pk_add_f16 %0, %1, %2" - : "=v"(encoding_scale) - : "v"(decoding_scale), "v"(kScaleEpsilon)); - encoding_scale = __builtin_bit_cast( - int, h2rcp(__builtin_bit_cast(half2, encoding_scale))); + decoding_scale = packed_mul(wblockmax, kScaleFactor); + encoding_scale = packed_add(decoding_scale, kScaleEpsilon); + encoding_scale = packed_rcp(encoding_scale); // Apply scales to get quantized values int32x4_t w; for (int i = 0; i < 4; i++) { - asm volatile("v_pk_mul_f16 %0, %1, %2" - : "=v"(w[i]) - : "v"(atom[i]), "v"(encoding_scale)); - asm volatile("v_pk_max_f16 %0, %1, %2" - : "=v"(w[i]) - : "v"(w[i]), "v"(kRangeMin)); - asm volatile("v_pk_min_f16 %0, %1, %2" - : "=v"(w[i]) - : "v"(w[i]), "v"(kRangeMax)); + w[i] = packed_mul(atom[i], encoding_scale); + w[i] = packed_max(w[i], kRangeMin); + w[i] = packed_min(w[i], kRangeMax); } - // Convert from fp16x2_t to uint16x2_t + // Convert from f16x2_t to uint16x2_t int32x4_t q; { int16_t* qi = reinterpret_cast(&q); - half* wh = reinterpret_cast(&w); - for 
(int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(__half2float(wh[i])); + T* wh = reinterpret_cast(&w); + for (int i = 0; i < 8; i++) + qi[i] = (int16_t)rintf(T2float_cast(wh[i])); for (int i = 0; i < 4; i++) { - asm volatile("v_pk_add_i16 %0, %1, %2" - : "=v"(q[i]) - : "v"(q[i]), "v"(kRangeBias)); + q[i] = packed_add(q[i], kRangeBias); } } @@ -574,24 +362,22 @@ struct TwoshotQ4LineCodec { int32x4_t w; { static uint constexpr kMask000F = 0x000F000F; - static uint constexpr kHalf2_1024 = - 0x64006400; // {1024.0, 1024.0}, fp16x2_t - static uint constexpr kHalf2_1032 = - 0xE408E408; // {-1032.0, -1032.0}, fp16x2_t + // {1024.0, 1024.0}, f16x2_t + static uint constexpr kF162_1024 = + std::is_same::value ? 0x64006400 : 0x44804480; + // {-1032.0, -1032.0}, f16x2_t + static uint constexpr kF162_1032 = + std::is_same::value ? 0xE408E408 : 0xC481C481; for (int i = 0; i < 4; i++) { - int32_t q4 = ((qw >> (i * 4)) & kMask000F) | kHalf2_1024; - asm volatile("v_pk_add_f16 %0, %1, %2" - : "=v"(w[i]) - : "v"(q4), "v"(kHalf2_1032)); + int32_t q4 = ((qw >> (i * 4)) & kMask000F) | kF162_1024; + w[i] = packed_add(q4, kF162_1032); } } // Apply decoding scales for (int i = 0; i < 4; i++) { - asm volatile("v_pk_mul_f16 %0, %1, %2" - : "=v"(w[i]) - : "v"(w[i]), "v"(qs)); + w[i] = packed_mul(w[i], qs); } // That's pretty much it... @@ -600,12 +386,11 @@ struct TwoshotQ4LineCodec { } }; -// MARK: Q6 Line Codec -template +template struct TwoshotQ6LineCodec { /* Int6-blocking Line codec for Twoshot collectives. - We quantize the FP16 data to block-scaled Int64 in blocks of 32. + We quantize the FP16/BF16 data to block-scaled Int64 in blocks of 32. */ static int constexpr kAtoms = 8; @@ -613,7 +398,7 @@ struct TwoshotQ6LineCodec { static int constexpr kWorldSize = world_size; // Codec tile size process by this workgroup. 
- // Each threads processes a fragment of fp16x8_t (16B), + // Each threads processes a fragment of f16x8_t (16B), // into a int6x8_t (4B + 2B) and a fp16 scale shared among 32 values. static int constexpr kRankAtoms = kAtoms / kWorldSize; static int constexpr kRankTileStride = 1664; @@ -627,13 +412,25 @@ struct TwoshotQ6LineCodec { // Total tile size for the collective communication. static int constexpr kTileSize = kRankTileSize * kWorldSize; - // Q4 configuration + // Constants configuration + // {-1/32.0h, -1/32.0h}, f16x2_t static int constexpr kScaleFactor = - 0xA800A800; // {-1/32.0h, -1/32.0h}, fp16x2_t - static int constexpr kScaleEpsilon = 0x00010001; // {1e-7, 1e-7}, fp16x2_t - static int constexpr kRangeMin = 0xD000D000; // {-32, -32}, fp16x2_t - static int constexpr kRangeMax = 0x4FC04FC0; // {+31, +31}, fp16x2_t - static int constexpr kRangeBias = 0x00200020; // {+32, +32}, int16x2_t + std::is_same::value ? 0xA800A800 : 0xBD00BD00; + + // {1e-7, 1e-7}, f16x2_t + static int constexpr kScaleEpsilon = + std::is_same::value ? 0x00010001 : 0x33D733D7; + + // {-32, -32}, fp16x2_t + static int constexpr kRangeMin = + std::is_same::value ? 0xD000D000 : 0xC200C200; + + // {+31, +31}, fp16x2_t + static int constexpr kRangeMax = + std::is_same::value ? 
0x4FC04FC0 : 0x41F841F8; + + // {+32, +32}, int16x2_t + static int constexpr kRangeBias = 0x00200020; int const thread; int const rank; @@ -655,50 +452,25 @@ struct TwoshotQ6LineCodec { int wmax, wmin, wblockmax; { int a, b; - asm volatile("v_pk_max_f16 %0, %1, %2" - : "=v"(a) - : "v"(atom[0]), "v"(atom[1])); - asm volatile("v_pk_max_f16 %0, %1, %2" - : "=v"(b) - : "v"(atom[2]), "v"(atom[3])); - asm volatile("v_pk_max_f16 %0, %1, %2" : "=v"(wmax) : "v"(a), "v"(b)); - - asm volatile("v_pk_min_f16 %0, %1, %2" - : "=v"(a) - : "v"(atom[0]), "v"(atom[1])); - asm volatile("v_pk_min_f16 %0, %1, %2" - : "=v"(b) - : "v"(atom[2]), "v"(atom[3])); - asm volatile("v_pk_min_f16 %0, %1, %2" : "=v"(wmin) : "v"(a), "v"(b)); + a = packed_max(atom[0], atom[1]); + b = packed_max(atom[2], atom[3]); + wmax = packed_max(a, b); + + a = packed_min(atom[0], atom[1]); + b = packed_min(atom[2], atom[3]); + wmin = packed_min(a, b); // Reduce the max among a group of 8 threads // Note: This is basically 2 blocks of 32 values setup as the // upper/lower halves of the fp16x2_t for (int i = 1; i < 8; i <<= 1) { int x = __shfl_down(wmax, i); - asm volatile("v_pk_max_f16 %0, %1, %2" - : "=v"(wmax) - : "v"(wmax), "v"(x)); + wmax = packed_max(wmax, x); int y = __shfl_down(wmin, i); - asm volatile("v_pk_min_f16 %0, %1, %2" - : "=v"(wmin) - : "v"(wmin), "v"(y)); + wmin = packed_min(wmin, y); } - - half2 wmaxh2 = __builtin_bit_cast(half2, wmax); - half2 wminh2 = __builtin_bit_cast(half2, wmin); - half2 wblockmaxh2; - - wblockmaxh2.x = - __half2float(__habs(wmaxh2.x)) > __half2float(__habs(wminh2.x)) - ? wmaxh2.x - : wminh2.x; - wblockmaxh2.y = - __half2float(__habs(wmaxh2.y)) > __half2float(__habs(wminh2.y)) - ? 
wmaxh2.y - : wminh2.y; - wblockmax = __builtin_bit_cast(int, wblockmaxh2); + wblockmax = packed_abs_max(wmax, wmin); // Share with the cohort wblockmax = __shfl(wblockmax, group_leader); @@ -707,40 +479,28 @@ struct TwoshotQ6LineCodec { // Derive scales int decoding_scale; int encoding_scale; - asm volatile("v_pk_mul_f16 %0, %1, %2" - : "=v"(decoding_scale) - : "v"(wblockmax), "v"(kScaleFactor)); - asm volatile("v_pk_add_f16 %0, %1, %2" - : "=v"(encoding_scale) - : "v"(decoding_scale), "v"(kScaleEpsilon)); - encoding_scale = __builtin_bit_cast( - int, h2rcp(__builtin_bit_cast(half2, encoding_scale))); + decoding_scale = packed_mul(wblockmax, kScaleFactor); + encoding_scale = packed_add(decoding_scale, kScaleEpsilon); + encoding_scale = packed_rcp(encoding_scale); // Apply scales to get quantized values int32x4_t w; for (int i = 0; i < 4; i++) { - asm volatile("v_pk_mul_f16 %0, %1, %2" - : "=v"(w[i]) - : "v"(atom[i]), "v"(encoding_scale)); - asm volatile("v_pk_max_f16 %0, %1, %2" - : "=v"(w[i]) - : "v"(w[i]), "v"(kRangeMin)); - asm volatile("v_pk_min_f16 %0, %1, %2" - : "=v"(w[i]) - : "v"(w[i]), "v"(kRangeMax)); + w[i] = packed_mul(atom[i], encoding_scale); + w[i] = packed_max(w[i], kRangeMin); + w[i] = packed_min(w[i], kRangeMax); } // Convert from fp16x2_t to uint16x2_t int32x4_t q; { int16_t* qi = reinterpret_cast(&q); - half* wh = reinterpret_cast(&w); - for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(__half2float(wh[i])); + T* wh = reinterpret_cast(&w); + for (int i = 0; i < 8; i++) + qi[i] = (int16_t)rintf(T2float_cast(wh[i])); for (int i = 0; i < 4; i++) { - asm volatile("v_pk_add_i16 %0, %1, %2" - : "=v"(q[i]) - : "v"(q[i]), "v"(kRangeBias)); + q[i] = packed_add(q[i], kRangeBias); } } @@ -796,10 +556,12 @@ struct TwoshotQ6LineCodec { int32x4_t w; { static uint constexpr kMask000F = 0x000F000F; - static uint constexpr kHalf2_1024 = - 0x64006400; // {1024.0, 1024.0}, fp16x2_t - static uint constexpr kHalf2_1056 = - 0xE420E420; // {-1056.0, -1056.0}, 
fp16x2_t + // {1024.0, 1024.0}, f16x2_t + static uint constexpr kf162_1024 = + std::is_same::value ? 0x64006400 : 0x44804480; + // {-1056.0, -1056.0}, f16x2_t + static uint constexpr kF162_1056 = + std::is_same::value ? 0xE420E420 : 0xC484C484; #pragma unroll for (int i = 0; i < 4; i++) { @@ -808,18 +570,14 @@ struct TwoshotQ6LineCodec { q4w >>= 4; q2w >>= 4; - int32_t q6 = q4 | (q2 << 4) | kHalf2_1024; - asm volatile("v_pk_add_f16 %0, %1, %2" - : "=v"(w[i]) - : "v"(q6), "v"(kHalf2_1056)); + int32_t q6 = q4 | (q2 << 4) | kf162_1024; + w[i] = packed_add(q6, kF162_1056); } } // Apply decoding scales for (int i = 0; i < 4; i++) { - asm volatile("v_pk_mul_f16 %0, %1, %2" - : "=v"(w[i]) - : "v"(w[i]), "v"(qs)); + w[i] = packed_mul(w[i], qs); } // That's pretty much it... @@ -829,11 +587,11 @@ struct TwoshotQ6LineCodec { }; // MARK: Q8 Line Codec -template +template struct TwoshotQ8LineCodec { /* Int8-blocking Line codec for Twoshot collectives. - We quantize the FP16 data to block-scaled Int8 in blocks of 32. + We quantize the FP16/BF16 data to block-scaled Int8 in blocks of 32. */ static int constexpr kAtoms = 8; @@ -841,8 +599,8 @@ struct TwoshotQ8LineCodec { static int constexpr kWorldSize = world_size; // Codec tile size process by this workgroup. - // Each threads processes a fragment of fp16x8_t (16B), - // into a int8x8_t (8B) and a fp16 scale shared among 32 values. + // Each threads processes a fragment of f16x8_t (16B), + // into a int8x8_t (8B) and a f16 scale shared among 32 values. static int constexpr kRankAtoms = kAtoms / kWorldSize; static int constexpr kRankTileStride = 2176; static int constexpr kRankTileScaleOffset = 2048; @@ -854,13 +612,25 @@ struct TwoshotQ8LineCodec { // Total tile size for the collective communication. 
static int constexpr kTileSize = kRankTileSize * kWorldSize; - // Q4 configuration + // Constants configuration + + // {-1/128.0h, -1/128.0h}, f16x2_t static int constexpr kScaleFactor = - 0xA000A000; // {-1/128.0h, -1/128.0h}, fp16x2_t - static int constexpr kScaleEpsilon = 0x00010001; // {1e-7, 1e-7}, fp16x2_t - static int constexpr kRangeMin = 0xD800D800; // {-128, -128}, fp16x2_t - static int constexpr kRangeMax = 0x57F057F0; // {+127, +127}, fp16x2_t - static int constexpr kRangeBias = 0x00800080; // {+128, +128}, int16x2_t + std::is_same::value ? 0xA000A000 : 0xBC00BC00; + + // {1e-7, 1e-7}, f16x2_t + static int constexpr kScaleEpsilon = + std::is_same::value ? 0x00010001 : 0x33D733D7; + + // {-128, -128}, f16x2_t + static int constexpr kRangeMin = + std::is_same::value ? 0xD800D800 : 0xC300C300; + // {+127, +127}, f16x2_t + static int constexpr kRangeMax = + std::is_same::value ? 0x57F057F0 : 0x42FE42FE; + + // {+128, +128}, int16x2_t + static int constexpr kRangeBias = 0x00800080; int const thread; int const rank; @@ -882,50 +652,27 @@ struct TwoshotQ8LineCodec { int wmax, wmin, wblockmax; { int a, b; - asm volatile("v_pk_max_f16 %0, %1, %2" - : "=v"(a) - : "v"(atom[0]), "v"(atom[1])); - asm volatile("v_pk_max_f16 %0, %1, %2" - : "=v"(b) - : "v"(atom[2]), "v"(atom[3])); - asm volatile("v_pk_max_f16 %0, %1, %2" : "=v"(wmax) : "v"(a), "v"(b)); - - asm volatile("v_pk_min_f16 %0, %1, %2" - : "=v"(a) - : "v"(atom[0]), "v"(atom[1])); - asm volatile("v_pk_min_f16 %0, %1, %2" - : "=v"(b) - : "v"(atom[2]), "v"(atom[3])); - asm volatile("v_pk_min_f16 %0, %1, %2" : "=v"(wmin) : "v"(a), "v"(b)); + a = packed_max(atom[0], atom[1]); + b = packed_max(atom[2], atom[3]); + + wmax = packed_max(a, b); + + a = packed_min(atom[0], atom[1]); + b = packed_min(atom[2], atom[3]); + + wmin = packed_min(a, b); // Reduce the max among a group of 8 threads // Note: This is basically 2 blocks of 32 values setup as the // upper/lower halves of the fp16x2_t for (int i = 1; i < 8; i <<= 1) 
{ int x = __shfl_down(wmax, i); - asm volatile("v_pk_max_f16 %0, %1, %2" - : "=v"(wmax) - : "v"(wmax), "v"(x)); + wmax = packed_max(wmax, x); int y = __shfl_down(wmin, i); - asm volatile("v_pk_min_f16 %0, %1, %2" - : "=v"(wmin) - : "v"(wmin), "v"(y)); + wmin = packed_min(wmin, y); } - - half2 wmaxh2 = __builtin_bit_cast(half2, wmax); - half2 wminh2 = __builtin_bit_cast(half2, wmin); - half2 wblockmaxh2; - - wblockmaxh2.x = - __half2float(__habs(wmaxh2.x)) > __half2float(__habs(wminh2.x)) - ? wmaxh2.x - : wminh2.x; - wblockmaxh2.y = - __half2float(__habs(wmaxh2.y)) > __half2float(__habs(wminh2.y)) - ? wmaxh2.y - : wminh2.y; - wblockmax = __builtin_bit_cast(int, wblockmaxh2); + wblockmax = packed_abs_max(wmax, wmin); // Share with the cohort wblockmax = __shfl(wblockmax, group_leader); @@ -934,40 +681,28 @@ struct TwoshotQ8LineCodec { // Derive scales int decoding_scale; int encoding_scale; - asm volatile("v_pk_mul_f16 %0, %1, %2" - : "=v"(decoding_scale) - : "v"(wblockmax), "v"(kScaleFactor)); - asm volatile("v_pk_add_f16 %0, %1, %2" - : "=v"(encoding_scale) - : "v"(decoding_scale), "v"(kScaleEpsilon)); - encoding_scale = __builtin_bit_cast( - int, h2rcp(__builtin_bit_cast(half2, encoding_scale))); + decoding_scale = packed_mul(wblockmax, kScaleFactor); + encoding_scale = packed_add(decoding_scale, kScaleEpsilon); + encoding_scale = packed_rcp(encoding_scale); // Apply scales to get quantized values int32x4_t w; for (int i = 0; i < 4; i++) { - asm volatile("v_pk_mul_f16 %0, %1, %2" - : "=v"(w[i]) - : "v"(atom[i]), "v"(encoding_scale)); - asm volatile("v_pk_max_f16 %0, %1, %2" - : "=v"(w[i]) - : "v"(w[i]), "v"(kRangeMin)); - asm volatile("v_pk_min_f16 %0, %1, %2" - : "=v"(w[i]) - : "v"(w[i]), "v"(kRangeMax)); + w[i] = packed_mul(atom[i], encoding_scale); + w[i] = packed_max(w[i], kRangeMin); + w[i] = packed_min(w[i], kRangeMax); } - // Convert from fp16x2_t to uint16x2_t + // Convert from f16x2_t to uint16x2_t int32x4_t q; { int16_t* qi = reinterpret_cast(&q); - 
half* wh = reinterpret_cast(&w); - for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(__half2float(wh[i])); + T* wh = reinterpret_cast(&w); + for (int i = 0; i < 8; i++) + qi[i] = (int16_t)rintf(T2float_cast(wh[i])); for (int i = 0; i < 4; i++) { - asm volatile("v_pk_add_i16 %0, %1, %2" - : "=v"(q[i]) - : "v"(q[i]), "v"(kRangeBias)); + q[i] = packed_add(q[i], kRangeBias); } } @@ -1009,25 +744,25 @@ struct TwoshotQ8LineCodec { int32x4_t w; { static uint constexpr kMask00FF = 0x00FF00FF; - static uint constexpr kHalf2_1024 = - 0x64006400; // {1024.0, 1024.0}, fp16x2_t - static uint constexpr kHalf2_1152 = - 0xE480E480; // {-1152.0, -1152.0}, fp16x2_t + + // {1024.0, 1024.0}, f16x2_t + static uint constexpr kF162_1024 = + std::is_same::value ? 0x64006400 : 0x44804480; + + // {-1152.0, -1152.0}, f16x2_t + static uint constexpr kF162_1152 = + std::is_same::value ? 0xE480E480 : 0xC490C490; #pragma unroll for (int i = 0; i < 4; i++) { - int32_t q8 = ((qw[i / 2] >> ((i % 2) * 8)) & kMask00FF) | kHalf2_1024; - asm volatile("v_pk_add_f16 %0, %1, %2" - : "=v"(w[i]) - : "v"(q8), "v"(kHalf2_1152)); + int32_t q8 = ((qw[i / 2] >> ((i % 2) * 8)) & kMask00FF) | kF162_1024; + w[i] = packed_add(q8, kF162_1152); } } // Apply decoding scales for (int i = 0; i < 4; i++) { - asm volatile("v_pk_mul_f16 %0, %1, %2" - : "=v"(w[i]) - : "v"(w[i]), "v"(qs)); + w[i] = packed_mul(w[i], qs); } // That's pretty much it... @@ -1036,7 +771,7 @@ struct TwoshotQ8LineCodec { } }; -template +template struct TwoshotMaxMinQ8LineCodec { /* Int8-blocking Line codec for Twoshot collectives. @@ -1064,14 +799,24 @@ struct TwoshotMaxMinQ8LineCodec { // Total tile size for the collective communication. 
static int constexpr kTileSize = kRankTileSize * kWorldSize; - // static int constexpr kScaleFactor = 0x1C001C00; // {1/256.0h, 1/256.0h}, - // fp16x2_t + // Constants configuration + // {1/255.0h, 1/255.0h}, f16x2_t static int constexpr kScaleFactor = - 0x1C041C04; // {1/255.0h, 1/255.0h}, fp16x2_t - static int constexpr kScaleEpsilon = 0x00010001; // {1e-7, 1e-7}, fp16x2_t - static int constexpr kRangeMin = 0x00000000; // {0, 0}, fp16x2_t - static int constexpr kRangeMax = 0x5BF85BF8; // {+255, +255}, fp16x2_t - static int constexpr kRangeBias = 0x00800080; // {+128, +128}, int16x2_t + std::is_same::value ? 0x1C041C04 : 0x3B813B81; + + // {1e-7, 1e-7}, fp16x2_t + static int constexpr kScaleEpsilon = + std::is_same::value ? 0x00010001 : 0x33D733D7; + + // {0, 0}, f16x2_t + static int constexpr kRangeMin = 0x00000000; + + // {+255, +255}, f16x2_t + static int constexpr kRangeMax = + std::is_same::value ? 0x5BF85BF8 : 0x437F437F; + + // {+128, +128}, int16x2_t + static int constexpr kRangeBias = 0x00800080; int const thread; int const rank; @@ -1092,35 +837,23 @@ struct TwoshotMaxMinQ8LineCodec { int wmax, wmin, wblockmax, wblockmin; { int a, b; - asm volatile("v_pk_max_f16 %0, %1, %2" - : "=v"(a) - : "v"(atom[0]), "v"(atom[1])); - asm volatile("v_pk_max_f16 %0, %1, %2" - : "=v"(b) - : "v"(atom[2]), "v"(atom[3])); - asm volatile("v_pk_max_f16 %0, %1, %2" : "=v"(wmax) : "v"(a), "v"(b)); - - asm volatile("v_pk_min_f16 %0, %1, %2" - : "=v"(a) - : "v"(atom[0]), "v"(atom[1])); - asm volatile("v_pk_min_f16 %0, %1, %2" - : "=v"(b) - : "v"(atom[2]), "v"(atom[3])); - asm volatile("v_pk_min_f16 %0, %1, %2" : "=v"(wmin) : "v"(a), "v"(b)); + a = packed_max(atom[0], atom[1]); + b = packed_max(atom[2], atom[3]); + wmax = packed_max(a, b); + + a = packed_min(atom[0], atom[1]); + b = packed_min(atom[2], atom[3]); + wmin = packed_min(a, b); // Reduce the max among a group of 8 threads // Note: This is basically 2 blocks of 32 values setup as the // upper/lower halves of the fp16x2_t 
for (int i = 1; i < 8; i <<= 1) { int x = __shfl_down(wmax, i); - asm volatile("v_pk_max_f16 %0, %1, %2" - : "=v"(wmax) - : "v"(wmax), "v"(x)); + wmax = packed_max(wmax, x); int y = __shfl_down(wmin, i); - asm volatile("v_pk_min_f16 %0, %1, %2" - : "=v"(wmin) - : "v"(wmin), "v"(y)); + wmin = packed_min(wmin, y); } // Share with the cohort @@ -1132,44 +865,28 @@ struct TwoshotMaxMinQ8LineCodec { int decoding_zero = wblockmin; int decoding_scale; int encoding_scale; - static int constexpr kNegOne = 0xBC00BC00; // {-1, -1}, fp16x2_t - - // MI300 lacks packed fp16 sub instruction. So we do -1 * min + max - asm volatile("v_pk_fma_f16 %0, %1, %2 %3" - : "=v"(decoding_scale) - : "v"(kNegOne), "v"(decoding_zero), "v"(wblockmax)); - asm volatile("v_pk_mul_f16 %0, %1, %2" - : "=v"(decoding_scale) - : "v"(decoding_scale), "v"(kScaleFactor)); - asm volatile("v_pk_add_f16 %0, %1, %2" - : "=v"(encoding_scale) - : "v"(decoding_scale), "v"(kScaleEpsilon)); - encoding_scale = __builtin_bit_cast( - int, h2rcp(__builtin_bit_cast(half2, encoding_scale))); - - // // Apply scales to get quantized values + + decoding_scale = packed_sub(wblockmax, decoding_zero); + decoding_scale = packed_mul(decoding_scale, kScaleFactor); + encoding_scale = packed_add(decoding_scale, kScaleEpsilon); + encoding_scale = packed_rcp(encoding_scale); + + // Apply scales to get quantized values int32x4_t w; for (int i = 0; i < 4; i++) { - asm volatile("v_pk_fma_f16 %0, %1, %2 %3" - : "=v"(w[i]) - : "v"(kNegOne), "v"(decoding_zero), "v"(atom[i])); - asm volatile("v_pk_mul_f16 %0, %1, %2" - : "=v"(w[i]) - : "v"(w[i]), "v"(encoding_scale)); - asm volatile("v_pk_max_f16 %0, %1, %2" - : "=v"(w[i]) - : "v"(w[i]), "v"(kRangeMin)); - asm volatile("v_pk_min_f16 %0, %1, %2" - : "=v"(w[i]) - : "v"(w[i]), "v"(kRangeMax)); + w[i] = packed_sub(atom[i], decoding_zero); + w[i] = packed_mul(w[i], encoding_scale); + w[i] = packed_max(w[i], kRangeMin); + w[i] = packed_min(w[i], kRangeMax); } // Convert from fp16x8_t to uint8x8_t 
and pack into int32x2_t int32x2_t qw; { unsigned char* qi = reinterpret_cast(&qw); - half* wh = reinterpret_cast(&w); - for (int i = 0; i < 8; i++) qi[i] = (unsigned char)__half2float(wh[i]); + T* wh = reinterpret_cast(&w); + for (int i = 0; i < 8; i++) + qi[i] = (unsigned char)T2float_cast(wh[i]); } // Write quantized atom to send_buffer @@ -1210,22 +927,18 @@ struct TwoshotMaxMinQ8LineCodec { // Unpack uint8x8_t into fp16x8_t int32x4_t w; { - half* wh = reinterpret_cast(&w); + T* wh = reinterpret_cast(&w); unsigned char* qi = reinterpret_cast(&qw); #pragma unroll for (int i = 0; i < 8; i++) { - wh[i] = __float2half((float)qi[i]); + wh[i] = float2T_cast((float)qi[i]); } } // Apply decoding scales and zeros for (int i = 0; i < 4; i++) { - asm volatile("v_pk_mul_f16 %0, %1, %2" - : "=v"(w[i]) - : "v"(w[i]), "v"(qs)); - asm volatile("v_pk_add_f16 %0, %1, %2" - : "=v"(w[i]) - : "v"(w[i]), "v"(qz)); + w[i] = packed_mul(w[i], qs); + w[i] = packed_add(w[i], qz); } data[k] = w; @@ -1234,7 +947,7 @@ struct TwoshotMaxMinQ8LineCodec { }; // MARK: Twoshot All Reduce -template +template struct AllReduceTwoshot { // Fixed magic implementation. 
// We will use a workgroup of 256 threads (standard kBlock) across 8 atoms of @@ -1249,13 +962,13 @@ struct AllReduceTwoshot { static int constexpr kWorldSize = LineCodec::kWorldSize; __device__ static void run( - half const* __restrict__ A, // input - half* __restrict__ B, // output - int const N, // number of elements - int const block, // block index - int const num_blocks, // number of blocks - int const world_size, // unused - only kept around for API consistency - int const rank, // rank index + T const* __restrict__ A, // input + T* __restrict__ B, // output + int const N, // number of elements + int const block, // block index + int const num_blocks, // number of blocks + int const world_size, // unused - only kept around for API consistency + int const rank, // rank index uint8_t** __restrict__ buffer_list, // communication buffers long const data_offset, // offset to start of the data buffer int flag_color) { @@ -1268,9 +981,9 @@ struct AllReduceTwoshot { // Read A into registers int32x4_t tA[kAtoms]; - BufferResource src_buffer(const_cast(A), N * sizeof(half)); + BufferResource src_buffer(const_cast(A), N * sizeof(T)); int src_offset = block * kTileSize + thread * sizeof(int32x4_t); - int32x4_t* src = reinterpret_cast(const_cast(A)); + int32x4_t* src = reinterpret_cast(const_cast(A)); for (int i = 0; i < kAtoms; i++) { tA[i] = buffer_load_dwordx4(src_buffer.descriptor, src_offset, 0, 0); @@ -1323,21 +1036,7 @@ struct AllReduceTwoshot { codec.recv(&recv_buffer, tA); for (int i = 0; i < LineCodec::kRankAtoms; i++) { - int32x4_t& tA_fragment = tA[i]; - int32x4_t& tR_fragment = tR[i]; - - asm volatile("v_pk_add_f16 %0, %1, %2" - : "=v"(tR_fragment[0]) - : "v"(tR_fragment[0]), "v"(tA_fragment[0])); - asm volatile("v_pk_add_f16 %0, %1, %2" - : "=v"(tR_fragment[1]) - : "v"(tR_fragment[1]), "v"(tA_fragment[1])); - asm volatile("v_pk_add_f16 %0, %1, %2" - : "=v"(tR_fragment[2]) - : "v"(tR_fragment[2]), "v"(tA_fragment[2])); - asm volatile("v_pk_add_f16 %0, %1, 
%2" - : "=v"(tR_fragment[3]) - : "v"(tR_fragment[3]), "v"(tA_fragment[3])); + packed_assign_add(&tR[i], &tA[i]); } } } @@ -1383,7 +1082,7 @@ struct AllReduceTwoshot { // -------------------------------------------------------- // Write the result to B. - BufferResource dst_buffer(B, N * sizeof(half)); + BufferResource dst_buffer(B, N * sizeof(T)); int dst_offset = block * kTileSize + thread * sizeof(int32x4_t); int32x4_t* dst = reinterpret_cast(B); @@ -1393,3 +1092,5 @@ struct AllReduceTwoshot { } } }; + +} // namespace quickreduce \ No newline at end of file diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 26294187dbb8..689b1262b2ca 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -1,12 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import logging import os from typing import Optional import torch from torch.distributed import ProcessGroup +from vllm.platforms import current_platform + import vllm.envs as envs from vllm.logger import init_logger @@ -61,11 +64,20 @@ def __init__(self, ) self.use_quick_allreduce = os.environ.get("VLLM_USE_QUICK_ALLREDUCE", "0") == "1" + if self.use_quick_allreduce and not current_platform.is_rocm(): + logger.warning( + "Environment variable VLLM_USE_QUICK_ALLREDUCE is set to 1," \ + " but QuickReduce is only supported on ROCm platform." + ) self.qr_comm: Optional[QuickAllReduce] = None - if self.use_quick_allreduce and self.world_size > 1: + if self.use_quick_allreduce and self.world_size > 1 and \ + current_platform.is_rocm(): # Initialize a custom fast all-reduce implementation. 
qr_comm_algo = os.environ.get("VLLM_QUICK_ALLREDUCE_ALGO", "TwoShot") + assert qr_comm_algo in QuickReduceAlgo.__members__, \ + "Unknown QuickReduce algorithm: {}".format( + qr_comm_algo) self.qr_comm_algo = QuickReduceAlgo[qr_comm_algo] self.qr_comm = QuickAllReduce(group=self.cpu_group, device=self.device, @@ -93,7 +105,7 @@ def __init__(self, raise ValueError(f"Unknown all2all backend: {all2all_backend}") def all_reduce(self, input_): - # always try custom allreduce first, + # always try quick reduce first, then custom allreduce, # and then pynccl. qr_comm = self.qr_comm if qr_comm is not None and not qr_comm.disabled and \ diff --git a/vllm/distributed/device_communicators/quick_all_reduce.py b/vllm/distributed/device_communicators/quick_all_reduce.py index 40300c797d48..e69d130f83f0 100644 --- a/vllm/distributed/device_communicators/quick_all_reduce.py +++ b/vllm/distributed/device_communicators/quick_all_reduce.py @@ -8,11 +8,12 @@ from torch.distributed import ProcessGroup from vllm import _custom_ops as ops +from vllm.platforms import current_platform logger = logging.getLogger(__name__) try: - ops.meta_size() + ops.qr_max_size() ops_available = True except Exception: # For CPUs @@ -22,15 +23,15 @@ class QuickReduceAlgo(Enum): OneShot = 0 TwoShot = 1 - TwoShot_FP8 = 2 - TwoShot_Q8 = 3 - TwoShot_Q6 = 4 - TwoShot_Q4 = 5 - TwoShot_MAX_MIN_Q8 = 6 + TwoShot_Q8 = 2 + TwoShot_Q6 = 3 + TwoShot_Q4 = 4 + TwoShot_MAX_MIN_Q8 = 5 class QuickAllReduce: - _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] + _SUPPORTED_WORLD_SIZES = [2, 4, 8] + _SUPPORTED_DTYPES = [torch.float16, torch.bfloat16] def __init__(self, group: ProcessGroup, @@ -38,20 +39,22 @@ def __init__(self, algo: QuickReduceAlgo = QuickReduceAlgo.TwoShot) -> None: self.disabled = True if not ops_available: - # disable because of missing custom allreduce library + # disable because of missing quick reduce library # e.g. 
in a non-cuda environment + logger.info("Custom allreduce is disabled because " + "of missing custom allreduce library") return + self.max_size = ops.qr_max_size() self.group = group self.algo = algo assert dist.get_backend(group) != dist.Backend.NCCL, ( "QuickReduce should be attached to a non-NCCL group.") - rank = dist.get_rank(group=self.group) world_size = dist.get_world_size(group=self.group) if world_size == 1: - # No need to initialize custom allreduce for single GPU case. + # No need to initialize QuickReduce for single GPU case. return if world_size not in QuickAllReduce._SUPPORTED_WORLD_SIZES: @@ -62,6 +65,9 @@ def __init__(self, world_size, str(QuickAllReduce._SUPPORTED_WORLD_SIZES)) return + assert current_platform.is_rocm(), ( + "QuickReduce is only supported on ROCm platform.") + if isinstance(device, int): device = torch.device(f"cuda:{device}") elif isinstance(device, str): @@ -71,9 +77,9 @@ def __init__(self, self.device = device torch.cuda.set_device(self.device) - self.disabled = False self._ptr = ops.init_custom_qr(rank, world_size) self.create_shared_buffer() + self.disabled = False def create_shared_buffer(self): """ @@ -87,11 +93,7 @@ def create_shared_buffer(self): ops.qr_open_handles(self._ptr, handles) def all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None): - """Performs an out-of-place all reduce. - - If registered is True, this assumes inp's pointer is already - IPC-registered. Otherwise, inp is first copied into a pre-registered - buffer. + """Performs an out-of-place all reduce. 
""" inp_size = inp.numel() * inp.element_size() if inp_size >= self.max_size: @@ -103,13 +105,6 @@ def all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None): ops.qr_all_reduce(self._ptr, inp, out, self.algo.value) return out - def is_enabled(self): - return not self.disabled - - @staticmethod - def is_available(): - return ops_available - def close(self): if not self.disabled and getattr(self, "_ptr", None): ops.qr_destroy(self._ptr) @@ -122,7 +117,8 @@ def should_quick_allreduce(self, inp: torch.Tensor): if self.disabled: return False inp_size = inp.numel() * inp.element_size() - # custom allreduce requires input byte size to be multiples of 16 + # QuickReduce requires input byte size to be multiples of 16 if inp_size % 16 != 0: return False - return inp.dtype == torch.float16 and inp_size < self.max_size + return inp.dtype in QuickAllReduce._SUPPORTED_DTYPES and \ + inp_size < self.max_size From 7099c3beff34c47d389ed072341a14c35b50de62 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 15 May 2025 12:03:14 +0000 Subject: [PATCH 04/28] WIP Signed-off-by: ilmarkov --- csrc/quickreduce/quick_reduce.h | 12 +- csrc/quickreduce/quick_reduce_impl.cuh | 121 ++++++++---------- .../device_communicators/quick_all_reduce.py | 6 +- 3 files changed, 60 insertions(+), 79 deletions(-) diff --git a/csrc/quickreduce/quick_reduce.h b/csrc/quickreduce/quick_reduce.h index 6f24f139d705..b294cf5421b9 100644 --- a/csrc/quickreduce/quick_reduce.h +++ b/csrc/quickreduce/quick_reduce.h @@ -22,9 +22,9 @@ enum QuickReduceAlgo { ONESHOT_FP16 = 0, TWOSHOT_FP16 = 1, TWOSHOT_Q8 = 2, - TWOSHOT_Q6 = 3, - TWOSHOT_Q4 = 4, - TWOSHOT_MAX_MIN_Q8 = 5, + TWOSHOT_Q4 = 3, + TWOSHOT_MAX_MIN_Q8 = 4, + TWOSHOT_MAX_MIN_Q4 = 5, }; // ============================================================ @@ -200,12 +200,12 @@ struct DeviceComms { case QuickReduceAlgo::TWOSHOT_MAX_MIN_Q8: TWOSHOT_DISPATCH(TwoshotMaxMinQ8LineCodec) break; - case QuickReduceAlgo::TWOSHOT_Q6: - TWOSHOT_DISPATCH(TwoshotQ6LineCodec) 
- break; case QuickReduceAlgo::TWOSHOT_Q4: TWOSHOT_DISPATCH(TwoshotQ4LineCodec) break; + case QuickReduceAlgo::TWOSHOT_MAX_MIN_Q4: + TWOSHOT_DISPATCH(TwoshotMaxMinQ4LineCodec) + break; default: TWOSHOT_DISPATCH(TwoshotFP16LineCodec) break; diff --git a/csrc/quickreduce/quick_reduce_impl.cuh b/csrc/quickreduce/quick_reduce_impl.cuh index b9715dc40823..dd7c254772a7 100644 --- a/csrc/quickreduce/quick_reduce_impl.cuh +++ b/csrc/quickreduce/quick_reduce_impl.cuh @@ -387,10 +387,10 @@ struct TwoshotQ4LineCodec { }; template -struct TwoshotQ6LineCodec { +struct TwoshotMaxMinQ4LineCodec { /* - Int6-blocking Line codec for Twoshot collectives. - We quantize the FP16/BF16 data to block-scaled Int64 in blocks of 32. + Int4-blocking Line codec for Twoshot collectives. + We quantize the FP16/BF16 data to block-scaled Int4 in blocks of 32. */ static int constexpr kAtoms = 8; @@ -399,11 +399,13 @@ struct TwoshotQ6LineCodec { // Codec tile size process by this workgroup. // Each threads processes a fragment of f16x8_t (16B), - // into a int6x8_t (4B + 2B) and a fp16 scale shared among 32 values. + // into a int4x8_t (4B) and a 2 f16 scale shared among 32 values. static int constexpr kRankAtoms = kAtoms / kWorldSize; - static int constexpr kRankTileStride = 1664; - static int constexpr kRankTileQ2Offset = 1024; - static int constexpr kRankTileScaleOffset = 1536; + + // 1024 + 128 + 128 + static int constexpr kRankTileStride = 1280; + static int constexpr kRankTileScaleOffset = 1024; + static int constexpr kRankTileZeroOffset = 1152; static int constexpr kRankTileSize = kRankTileStride * kRankAtoms; static int constexpr kRankBufferTileStride = @@ -413,30 +415,29 @@ struct TwoshotQ6LineCodec { static int constexpr kTileSize = kRankTileSize * kWorldSize; // Constants configuration - // {-1/32.0h, -1/32.0h}, f16x2_t + + // {-1/16.0h, -1/16.0h}, f16x2_t static int constexpr kScaleFactor = - std::is_same::value ? 0xA800A800 : 0xBD00BD00; + std::is_same::value ? 
0xAC00AC00 : 0xBD80BD80; // {1e-7, 1e-7}, f16x2_t static int constexpr kScaleEpsilon = std::is_same::value ? 0x00010001 : 0x33D733D7; - // {-32, -32}, fp16x2_t - static int constexpr kRangeMin = - std::is_same::value ? 0xD000D000 : 0xC200C200; + // {0, 0}, f16x2_t + static int constexpr kRangeMin = 0x00000000; - // {+31, +31}, fp16x2_t + // {+15, +15}, f16x2_t static int constexpr kRangeMax = - std::is_same::value ? 0x4FC04FC0 : 0x41F841F8; + std::is_same::value ? 0x4B804B80 : 0x41704170; - // {+32, +32}, int16x2_t - static int constexpr kRangeBias = 0x00200020; + static unsigned char constexpr kMask0F = 0x0F; int const thread; int const rank; int const group_leader; - __device_inline__ TwoshotQ6LineCodec(int thread, int rank) + __device_inline__ TwoshotMaxMinQ4LineCodec(int thread, int rank) : thread(thread), rank(rank), group_leader((threadIdx.x / 8) * 8) { static_assert(kRankTileSize % 16 == 0, "kRankTileSize must be 16B aligned."); @@ -449,7 +450,7 @@ struct TwoshotQ6LineCodec { int32x4_t const atom = data[k]; // max(w), min(w) - int wmax, wmin, wblockmax; + int wmax, wmin, wblockmax, wblockmin; { int a, b; a = packed_max(atom[0], atom[1]); @@ -460,7 +461,7 @@ struct TwoshotQ6LineCodec { b = packed_min(atom[2], atom[3]); wmin = packed_min(a, b); - // Reduce the max among a group of 8 threads + // Reduce the max and min among a group of 8 threads // Note: This is basically 2 blocks of 32 values setup as the // upper/lower halves of the fp16x2_t for (int i = 1; i < 8; i <<= 1) { @@ -470,50 +471,39 @@ struct TwoshotQ6LineCodec { int y = __shfl_down(wmin, i); wmin = packed_min(wmin, y); } - wblockmax = packed_abs_max(wmax, wmin); // Share with the cohort - wblockmax = __shfl(wblockmax, group_leader); + wblockmax = __shfl(wmax, group_leader); + wblockmin = __shfl(wmin, group_leader); } - // Derive scales + // Derive zeros and scales + int decoding_zero = wblockmin; int decoding_scale; int encoding_scale; - decoding_scale = packed_mul(wblockmax, kScaleFactor); + + 
decoding_scale = packed_sub(wblockmax, decoding_zero); + decoding_scale = packed_mul(decoding_scale, kScaleFactor); encoding_scale = packed_add(decoding_scale, kScaleEpsilon); encoding_scale = packed_rcp(encoding_scale); // Apply scales to get quantized values int32x4_t w; for (int i = 0; i < 4; i++) { - w[i] = packed_mul(atom[i], encoding_scale); + w[i] = packed_sub(atom[i], decoding_zero); + w[i] = packed_mul(w[i], encoding_scale); w[i] = packed_max(w[i], kRangeMin); w[i] = packed_min(w[i], kRangeMax); } - // Convert from fp16x2_t to uint16x2_t - int32x4_t q; + // Convert from f16x2_t to uint16x2_t + int32_t qw = 0; { - int16_t* qi = reinterpret_cast(&q); + unsigned char* qi = reinterpret_cast(&qw); T* wh = reinterpret_cast(&w); - for (int i = 0; i < 8; i++) - qi[i] = (int16_t)rintf(T2float_cast(wh[i])); - - for (int i = 0; i < 4; i++) { - q[i] = packed_add(q[i], kRangeBias); - } - } - - // Pack 8 x q6 into int32_t + int16_t - uint32_t q4w; - uint16_t q2w = 0; - q4w = (q[0] & 0x000F000F) | ((q[1] & 0x000F000F) << 4) | - ((q[2] & 0x000F000F) << 8) | ((q[3] & 0x000F000F) << 12); - { - int16_t* tw = reinterpret_cast(&q); -#pragma unroll for (int i = 0; i < 8; i++) { - q2w |= (tw[i] >> 4) << (i * 2); + auto val = (unsigned char)T2float_cast(wh[i]) & kMask0F; + qi[i / 2] |= val << (4 * (i & 1)); } } @@ -521,16 +511,16 @@ struct TwoshotQ6LineCodec { // note: only the group leader stores the scale uint8_t* atom_ptr = reinterpret_cast(send_buffer + k * kRankBufferTileStride); - uint32_t* q4w_ptr = reinterpret_cast(atom_ptr) + thread; - uint16_t* q2w_ptr = - reinterpret_cast(atom_ptr + kRankTileQ2Offset) + thread; + int32_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + (thread / 8); + int* qz_ptr = + reinterpret_cast(atom_ptr + kRankTileZeroOffset) + (thread / 8); - __builtin_nontemporal_store(q4w, q4w_ptr); - __builtin_nontemporal_store(q2w, q2w_ptr); + __builtin_nontemporal_store(qw, qw_ptr); if 
(threadIdx.x == group_leader) { __builtin_nontemporal_store(decoding_scale, qs_ptr); + __builtin_nontemporal_store(decoding_zero, qz_ptr); } } } @@ -540,44 +530,35 @@ struct TwoshotQ6LineCodec { for (int k = 0; k < kRankAtoms; k++) { // Directly read quantized atom from recv_buffer uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); - uint32_t* q4w_ptr = reinterpret_cast(atom_ptr) + thread; - uint16_t* q2w_ptr = - reinterpret_cast(atom_ptr + kRankTileQ2Offset) + thread; + int32_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + (thread / 8); + int* qz_ptr = + reinterpret_cast(atom_ptr + kRankTileZeroOffset) + (thread / 8); - uint32_t q4w = __builtin_nontemporal_load(q4w_ptr); - uint16_t q2w = __builtin_nontemporal_load(q2w_ptr); + int32_t qw = __builtin_nontemporal_load(qw_ptr); int qs = __builtin_nontemporal_load(qs_ptr); + int qz = __builtin_nontemporal_load(qz_ptr); *recv_buffer += kRankBufferTileStride; - // Unpack q6 into fp16x8_t + // Unpack 8xq4 into f16x8_t int32x4_t w; { - static uint constexpr kMask000F = 0x000F000F; - // {1024.0, 1024.0}, f16x2_t - static uint constexpr kf162_1024 = - std::is_same::value ? 0x64006400 : 0x44804480; - // {-1056.0, -1056.0}, f16x2_t - static uint constexpr kF162_1056 = - std::is_same::value ? 0xE420E420 : 0xC484C484; + T* wh = reinterpret_cast(&w); + unsigned char* qi = reinterpret_cast(&qw); #pragma unroll - for (int i = 0; i < 4; i++) { - int32_t q4 = q4w & kMask000F; - int32_t q2 = (q2w & 0x3) | ((q2w & 0xC) << 14); - q4w >>= 4; - q2w >>= 4; - - int32_t q6 = q4 | (q2 << 4) | kf162_1024; - w[i] = packed_add(q6, kF162_1056); + for (int i = 0; i < 8; i++) { + auto val = (qi[i / 2] >> (4 * (i & 1))) & kMask0F; + wh[i] = float2T_cast((float)val); } } // Apply decoding scales for (int i = 0; i < 4; i++) { w[i] = packed_mul(w[i], qs); + w[i] = packed_add(w[i], qz); } // That's pretty much it... 
diff --git a/vllm/distributed/device_communicators/quick_all_reduce.py b/vllm/distributed/device_communicators/quick_all_reduce.py index e69d130f83f0..70828b461ddb 100644 --- a/vllm/distributed/device_communicators/quick_all_reduce.py +++ b/vllm/distributed/device_communicators/quick_all_reduce.py @@ -24,9 +24,9 @@ class QuickReduceAlgo(Enum): OneShot = 0 TwoShot = 1 TwoShot_Q8 = 2 - TwoShot_Q6 = 3 - TwoShot_Q4 = 4 - TwoShot_MAX_MIN_Q8 = 5 + TwoShot_Q4 = 3 + TwoShot_MAX_MIN_Q8 = 4 + TwoShot_MAX_MIN_Q4 = 5 class QuickAllReduce: From 982349c351e913587212999d96727575c1aa3285 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Mon, 19 May 2025 09:41:54 +0000 Subject: [PATCH 05/28] Refactor QuickReduce Signed-off-by: ilmarkov --- csrc/custom_quickreduce.cu | 5 + csrc/ops.h | 1 + csrc/quickreduce/base.h | 94 ++- csrc/quickreduce/quick_reduce.h | 10 +- csrc/quickreduce/quick_reduce_impl.cuh | 796 +++++++----------- csrc/torch_bindings.cpp | 6 + tests/distributed/test_quick_reduce.py | 129 +++ vllm/_custom_ops.py | 4 + vllm/config.py | 9 + .../device_communicators/cuda_communicator.py | 35 +- .../device_communicators/quick_all_reduce.py | 32 +- vllm/distributed/parallel_state.py | 10 + vllm/engine/arg_utils.py | 8 + vllm/entrypoints/llm.py | 4 + vllm/v1/worker/gpu_worker.py | 4 +- vllm/worker/worker.py | 3 +- 16 files changed, 630 insertions(+), 520 deletions(-) create mode 100644 tests/distributed/test_quick_reduce.py diff --git a/csrc/custom_quickreduce.cu b/csrc/custom_quickreduce.cu index e3b7a6f69d7b..75e226f0328a 100644 --- a/csrc/custom_quickreduce.cu +++ b/csrc/custom_quickreduce.cu @@ -79,4 +79,9 @@ int64_t qr_max_size() { return static_cast(quickreduce::DeviceComms::kMaxProblemSize); } +int64_t qr_min_size() { + return static_cast(quickreduce::kBlockSize * quickreduce::kAtoms * + sizeof(quickreduce::int32x4_t)); +} + #endif // USE_ROCM \ No newline at end of file diff --git a/csrc/ops.h b/csrc/ops.h index 2d92f9b2fcb6..fd9f3ac1d670 100644 --- a/csrc/ops.h +++ 
b/csrc/ops.h @@ -369,4 +369,5 @@ void qr_open_handles(fptr_t _fa, const std::vector& handles); void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, int64_t algo_int); int64_t qr_max_size(); +int64_t qr_min_size(); #endif \ No newline at end of file diff --git a/csrc/quickreduce/base.h b/csrc/quickreduce/base.h index 3cfb7ae2cbb8..749cb1af0e69 100644 --- a/csrc/quickreduce/base.h +++ b/csrc/quickreduce/base.h @@ -44,14 +44,29 @@ using fp32x4_t = __attribute__((__vector_size__(4 * sizeof(float)))) float; using fp32x8_t = __attribute__((__vector_size__(8 * sizeof(float)))) float; using fp32x16_t = __attribute__((__vector_size__(16 * sizeof(float)))) float; -static int constexpr kNegOne = 0xBC00BC00; // {-1, -1}, fp16x2_t +static constexpr int kNegOne = 0xBC00BC00; // {-1, -1}, fp16x2_t + +// Number of atoms (4xf16x2_t) processed by a single thread +static constexpr int kAtoms = 8; + +// We use a workgroup of 256 threads +static constexpr int kBlockSize = 256; +static constexpr int kAtomStride = kBlockSize; + +// Size and atom stride of source/destination data that the block will +// process. +static constexpr int kTileSize = kBlockSize * kAtoms * sizeof(int32x4_t); // Standard CDNA wavefront size. -static int constexpr kWavefront = 64; +static constexpr int kWavefront = 64; // 256 thread, 4 wavefronts. static dim3 constexpr kBlock = {64, 4, 1}; +// Number of threads in a group for quantization +// It corresponds to 32 F16 elements in quantization block +static constexpr int kThreadGroupSize = 8; + // Methods __device_inline__ __host__ unsigned long divceil(unsigned long x, unsigned long y) { @@ -82,6 +97,9 @@ __device_inline__ static void buffer_store_dwordx4( int32x4_t data, int32x4_t srsrc, int32_t voffset, int32_t soffset, int32_t aux) __asm("llvm.amdgcn.raw.buffer.store.v4i32"); +// Setting fp16 flag does not seem to have an effect for gfx942 +// The register offset has to be validated +// So we don't use it in Codecs for now. 
__device_inline__ static void set_fp16_ovfl(bool const value) { // short size = 0b00001; // Specifies the bit size to modify // const short offset = 0b10111; // Corrected offset to 23, which is the bit @@ -302,4 +320,76 @@ __device_inline__ float T2float_cast(nv_bfloat16 a) { return __bfloat162float(a); } +template +__device_inline__ int group_abs_max(int32x4_t atom) { + const int group_leader = (threadIdx.x / kThreadGroupSize) * kThreadGroupSize; + + int wmax, wmin, wblockmax; + int a, b; + a = packed_max(atom[0], atom[1]); + b = packed_max(atom[2], atom[3]); + + wmax = packed_max(a, b); + + a = packed_min(atom[0], atom[1]); + b = packed_min(atom[2], atom[3]); + + wmin = packed_min(a, b); + + // Reduce the max among a group of threads + // Note: This is basically 2 blocks of values setup as the + // upper/lower halves of the f16x2_t + for (int i = 1; i < kThreadGroupSize; i <<= 1) { + int x = __shfl_down(wmax, i); + wmax = packed_max(wmax, x); + + int y = __shfl_down(wmin, i); + wmin = packed_min(wmin, y); + } + wblockmax = packed_abs_max(wmax, wmin); + // Share with the cohort + wblockmax = __shfl(wblockmax, group_leader); + return wblockmax; +} + +template +__device_inline__ void group_max_min(int32x4_t atom, int& wblockmax, + int& wblockmin) { + const int group_leader = (threadIdx.x / kThreadGroupSize) * kThreadGroupSize; + + int wmax, wmin; + int a, b; + a = packed_max(atom[0], atom[1]); + b = packed_max(atom[2], atom[3]); + wmax = packed_max(a, b); + + a = packed_min(atom[0], atom[1]); + b = packed_min(atom[2], atom[3]); + wmin = packed_min(a, b); + + // Reduce the max and min among a group of threads + // Note: This is basically 2 blocks of values setup as the + // upper/lower halves of the f16x2_t + for (int i = 1; i < kThreadGroupSize; i <<= 1) { + int x = __shfl_down(wmax, i); + wmax = packed_max(wmax, x); + + int y = __shfl_down(wmin, i); + wmin = packed_min(wmin, y); + } + + // Share with the cohort + wblockmax = __shfl(wmax, group_leader); + wblockmin 
= __shfl(wmin, group_leader); +} + +__device_inline__ void set_sync_flag(int* flag_ptr, int flag) { + __atomic_store_n(flag_ptr, flag, __ATOMIC_RELEASE); +} + +__device_inline__ void wait_sync_flag(int* flag_ptr, int flag) { + while (__atomic_load_n(flag_ptr, __ATOMIC_RELAXED) != flag) { + } +} + } // namespace quickreduce \ No newline at end of file diff --git a/csrc/quickreduce/quick_reduce.h b/csrc/quickreduce/quick_reduce.h index b294cf5421b9..178631dac7bc 100644 --- a/csrc/quickreduce/quick_reduce.h +++ b/csrc/quickreduce/quick_reduce.h @@ -195,19 +195,19 @@ struct DeviceComms { data_offset, flag_color); break; case QuickReduceAlgo::TWOSHOT_Q8: - TWOSHOT_DISPATCH(TwoshotQ8LineCodec) + TWOSHOT_DISPATCH(CodecQ8Symm) break; case QuickReduceAlgo::TWOSHOT_MAX_MIN_Q8: - TWOSHOT_DISPATCH(TwoshotMaxMinQ8LineCodec) + TWOSHOT_DISPATCH(CodecQ8Asymm) break; case QuickReduceAlgo::TWOSHOT_Q4: - TWOSHOT_DISPATCH(TwoshotQ4LineCodec) + TWOSHOT_DISPATCH(CodecQ4Symm) break; case QuickReduceAlgo::TWOSHOT_MAX_MIN_Q4: - TWOSHOT_DISPATCH(TwoshotMaxMinQ4LineCodec) + TWOSHOT_DISPATCH(CodecQ4Asymm) break; default: - TWOSHOT_DISPATCH(TwoshotFP16LineCodec) + TWOSHOT_DISPATCH(CodecFP16) break; } HIP_CHECK(cudaGetLastError()); diff --git a/csrc/quickreduce/quick_reduce_impl.cuh b/csrc/quickreduce/quick_reduce_impl.cuh index dd7c254772a7..54f5d209a507 100644 --- a/csrc/quickreduce/quick_reduce_impl.cuh +++ b/csrc/quickreduce/quick_reduce_impl.cuh @@ -4,191 +4,35 @@ #include "base.h" namespace quickreduce { -// ============================================================ -// Oneshot -// ============================================================ -// MARK: Oneshot All Reduce -template -struct AllReduceOneshot { - // Fixed magic implementation. - // We will use a workgroup of 256 threads (standard kBlock) across 8 atoms of - // work. - static_assert(sizeof(T) == 2); - static int constexpr kAtoms = 8; - - // Size and atom stride of data that the workgroup will process. 
- static int constexpr kTileSize = 256 * kAtoms * sizeof(int32x4_t); - static int constexpr kAtomStride = 256; - - __device__ static void run( - T const* __restrict__ A, // input - T* __restrict__ B, // output - int const N, // number of elements - int const block, // this block's index - int const num_blocks, // total number of blocks - int const world_size, // total number of ranks - int const rank, // this rank's index - uint8_t** __restrict__ buffer_list, // communication buffers - long const data_offset, // offset to start of the data buffer - int flag_color // Flag color for the network barrier - ) { - // Topology - int thread = threadIdx.x + threadIdx.y * kWavefront; - - long data_stride = num_blocks * kTileSize; - long flags_stride = num_blocks * sizeof(int); - - uint8_t* rank_buffer = buffer_list[rank]; - - // -------------------------------------------------------- - // Read A into registers - int32x4_t tA[kAtoms]; - - BufferResource src_buffer(const_cast(A), N * sizeof(T)); - - int src_offset = block * kTileSize + thread * sizeof(int32x4_t); - - for (int i = 0; i < kAtoms; i++) { - tA[i] = buffer_load_dwordx4(src_buffer.descriptor, src_offset, 0, 0); - src_offset += kAtomStride * sizeof(int32x4_t); - } - - // -------------------------------------------------------- - // Write rank data into this rank segment of every rank's communication - // buffer. 
- long comm_data_offset = - data_offset + rank * data_stride + block * kTileSize; - long comm_flags_offset = rank * flags_stride + block * sizeof(int); - - if (thread < world_size) { - int r = thread; - int* flag_ptr = - reinterpret_cast(buffer_list[r] + comm_flags_offset); - while (__atomic_load_n(flag_ptr, __ATOMIC_RELAXED) != flag_color - 1) { - } - } - __syncthreads(); - - for (int r = 0; r < world_size; r++) { - int32x4_t* send_buffer = - reinterpret_cast(buffer_list[r] + comm_data_offset); - for (int i = 0; i < kAtoms; i++) { - __builtin_nontemporal_store(tA[i], send_buffer + thread); - send_buffer += kAtomStride; - } - } - - // Inform the other ranks that th data has been posted. - __syncthreads(); - if (thread < world_size) { - int r = thread; - int* flag_ptr = - reinterpret_cast(buffer_list[r] + comm_flags_offset); - __atomic_store_n(flag_ptr, flag_color, __ATOMIC_RELEASE); - } - - // -------------------------------------------------------- - // Read and reduce the data from this rank's communication buffer. - int32x4_t tB[kAtoms]; - - { - int r = 0; - - // Wait for the flags to be set. - int* flag_ptr = reinterpret_cast(rank_buffer + r * flags_stride + - block * sizeof(int)); - if (thread == 0) { - while (__atomic_load_n(flag_ptr, __ATOMIC_RELAXED) != flag_color) { - } - } - __syncthreads(); - - // Read posted data from the rank's communication buffer. - int32x4_t* recv_buffer = reinterpret_cast( - rank_buffer + data_offset + r * data_stride + block * kTileSize); - - for (int i = 0; i < kAtoms; i++) { - tB[i] = __builtin_nontemporal_load(recv_buffer + thread); - recv_buffer += kAtomStride; - } - } - - for (int r = 1; r < world_size; r++) { - // Wait for the flags to be set. - int* flag_ptr = reinterpret_cast(rank_buffer + r * flags_stride + - block * sizeof(int)); - if (thread == 0) { - while (__atomic_load_n(flag_ptr, __ATOMIC_RELAXED) != flag_color) { - } - } - __syncthreads(); - - // Read posted data from the rank's communication buffer. 
- int32x4_t* recv_buffer = reinterpret_cast( - rank_buffer + data_offset + r * data_stride + block * kTileSize); - - for (int i = 0; i < kAtoms; i++) { - tA[i] = __builtin_nontemporal_load(recv_buffer + thread); - recv_buffer += kAtomStride; - } - - // Reduce. - for (int i = 0; i < kAtoms; i++) { - packed_assign_add(&tB[i], &tA[i]); - } - } - - __syncthreads(); - if (thread < world_size) { - int r = thread; - int* flag_ptr = reinterpret_cast(rank_buffer + r * flags_stride + - block * sizeof(int)); - __atomic_store_n(flag_ptr, flag_color, __ATOMIC_RELAXED); - } - - // -------------------------------------------------------- - // Write the result to B. - BufferResource dst_buffer(B, N * sizeof(T)); - int dst_offset = block * kTileSize + thread * sizeof(int32x4_t); - for (int i = 0; i < kAtoms; i++) { - buffer_store_dwordx4(tB[i], dst_buffer.descriptor, dst_offset, 0, 0); - dst_offset += kAtomStride * sizeof(int32x4_t); - } - } +struct CodecBase { + const int thread; + const int rank; + const int group_leader; + __device_inline__ CodecBase(int thread, int rank) + : thread(thread), + rank(rank), + group_leader((threadIdx.x / kThreadGroupSize) * kThreadGroupSize) {} }; -// ============================================================ -// Twoshot -// ============================================================ -// MARK: FP16 Line Codec +// Default full precision codec. template -struct TwoshotFP16LineCodec { - /* - Default FP16 line codec for Twoshot collectives. - No actual compression is involved. - */ - - static int constexpr kAtoms = 8; - static int constexpr kAtomStride = 256; - static int constexpr kWorldSize = world_size; +struct CodecFP16 : public CodecBase { + static constexpr int kWorldSize = world_size; + static constexpr int kRankAtoms = kAtoms / kWorldSize; // Codec tile size process by this workgroup. - // Each thread processes atoms of fp16x8_t (16B). 
- static int constexpr kRankAtoms = kAtoms / kWorldSize; - static int constexpr kRankTileSize = 256 * kRankAtoms * sizeof(int32x4_t); + // Each thread processes atoms of f16x8_t (16B). + static constexpr int kRankTransmittedTileSize = + kBlockSize * kRankAtoms * sizeof(int32x4_t); + static_assert(kRankTransmittedTileSize % 16 == 0, + "kRankTransmittedTileSize must be 16B aligned."); // Total tile size for the collective communication. - static int constexpr kTileSize = kRankTileSize * kWorldSize; + static constexpr int kTransmittedTileSize = + kRankTransmittedTileSize * kWorldSize; - int const thread; - int const rank; - - __device_inline__ TwoshotFP16LineCodec(int thread, int rank) - : thread(thread), rank(rank) { - static_assert(kRankTileSize % 16 == 0, - "kRankTileSize must be 16B aligned."); - } + __device_inline__ CodecFP16(int thread, int rank) : CodecBase(thread, rank) {} __device_inline__ void send(int32x4_t* __restrict__ send_buffer, int32x4_t const* __restrict__ data) { @@ -207,62 +51,55 @@ struct TwoshotFP16LineCodec { } }; -// MARK: Q4 Line Codec +// Int4 symmetric quantization codec. +// We quantize the FP16 data to block-scaled Int4 in blocks of 4 * +// kThreadGroupSize. template -struct TwoshotQ4LineCodec { - /* - Int4-blocking Line codec for Twoshot collectives. - We quantize the FP16 data to block-scaled Int4 in blocks of 32. - */ - - static int constexpr kAtoms = 8; - static int constexpr kAtomStride = 256; - static int constexpr kWorldSize = world_size; +struct CodecQ4Symm : public CodecBase { + static constexpr int kWorldSize = world_size; // Codec tile size process by this workgroup. // Each threads processes a fragment of fp16x8_t (16B), // into a int4x8_t (4B) and a fp16 scale shared among 32 values. 
- static int constexpr kRankAtoms = kAtoms / kWorldSize; - static int constexpr kRankTileStride = 1152; - static int constexpr kRankTileScaleOffset = 1024; - static int constexpr kRankTileSize = kRankTileStride * kRankAtoms; - - static int constexpr kRankBufferTileStride = + static constexpr int kRankAtoms = kAtoms / kWorldSize; + static constexpr int kRankTileStride = 1152; + static constexpr int kRankTileScaleOffset = 1024; + static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms; + static_assert(kRankTransmittedTileSize % 16 == 0, + "kRankTransmittedTileSize must be 16B aligned."); + + static constexpr int kRankBufferTileStride = kRankTileStride / sizeof(int32x4_t); // Total tile size for the collective communication. - static int constexpr kTileSize = kRankTileSize * kWorldSize; + static constexpr int kTransmittedTileSize = + kRankTransmittedTileSize * kWorldSize; // Constants configuration // {-1/8.0h, -1/8.0h}, f16x2_t - static int constexpr kScaleFactor = + static constexpr int kScaleFactor = std::is_same::value ? 0xB000B000 : 0xBE00BE00; // {1e-7, 1e-7}, f16x2_t - static int constexpr kScaleEpsilon = + static constexpr int kScaleEpsilon = std::is_same::value ? 0x00010001 : 0x33D733D7; // {-8, -8}, f16x2_t - static int constexpr kRangeMin = + static constexpr int kRangeMin = std::is_same::value ? 0xC800C800 : 0xC100C100; // {+7, +7}, f16x2_t - static int constexpr kRangeMax = + static constexpr int kRangeMax = std::is_same::value ? 
0x47004700 : 0x40E040E0; // {+8, +8}, int16x2_t - static int constexpr kRangeBias = 0x00080008; - - int const thread; - int const rank; - int const group_leader; + static constexpr int kRangeBias = 0x00080008; - __device_inline__ TwoshotQ4LineCodec(int thread, int rank) - : thread(thread), rank(rank), group_leader((threadIdx.x / 8) * 8) { - static_assert(kRankTileSize % 16 == 0, - "kRankTileSize must be 16B aligned."); - set_fp16_ovfl(true); + __device_inline__ CodecQ4Symm(int thread, int rank) + : CodecBase(thread, rank) { + // if constexpr (std::is_same::value) + // set_fp16_ovfl(true); } __device_inline__ void send(int32x4_t* __restrict__ send_buffer, @@ -270,33 +107,9 @@ struct TwoshotQ4LineCodec { for (int k = 0; k < kRankAtoms; k++) { int32x4_t const atom = data[k]; - // max(w), min(w) - int wmax, wmin, wblockmax; - { - int a, b; - a = packed_max(atom[0], atom[1]); - b = packed_max(atom[2], atom[3]); - wmax = packed_max(a, b); - - a = packed_min(atom[0], atom[1]); - b = packed_min(atom[2], atom[3]); - wmin = packed_min(a, b); - - // Reduce the max among a group of 8 threads - // Note: This is basically 2 blocks of 32 values setup as the - // upper/lower halves of the fp16x2_t - for (int i = 1; i < 8; i <<= 1) { - int x = __shfl_down(wmax, i); - wmax = packed_max(wmax, x); - - int y = __shfl_down(wmin, i); - wmin = packed_min(wmin, y); - } - wblockmax = packed_abs_max(wmax, wmin); - - // Share with the cohort - wblockmax = __shfl(wblockmax, group_leader); - } + // Compute the absolute maximum of the atom in the thread group + // In 2 blocks of values, upper/lower halves of the f16x2_t + int wblockmax = group_abs_max(atom); // Derive scales int decoding_scale; @@ -361,12 +174,12 @@ struct TwoshotQ4LineCodec { // Unpack q4 into fp16x8_t int32x4_t w; { - static uint constexpr kMask000F = 0x000F000F; + static constexpr uint kMask000F = 0x000F000F; // {1024.0, 1024.0}, f16x2_t - static uint constexpr kF162_1024 = + static constexpr uint kF162_1024 = 
std::is_same::value ? 0x64006400 : 0x44804480; // {-1032.0, -1032.0}, f16x2_t - static uint constexpr kF162_1032 = + static constexpr uint kF162_1032 = std::is_same::value ? 0xE408E408 : 0xC481C481; for (int i = 0; i < 4; i++) { @@ -380,68 +193,60 @@ struct TwoshotQ4LineCodec { w[i] = packed_mul(w[i], qs); } - // That's pretty much it... data[k] = w; } } }; +// Int4 Asymm quantization codec. +// We quantize the FP16/BF16 data to block-scaled Int4 in blocks of 4 * +// kThreadGroupSize. template -struct TwoshotMaxMinQ4LineCodec { - /* - Int4-blocking Line codec for Twoshot collectives. - We quantize the FP16/BF16 data to block-scaled Int4 in blocks of 32. - */ +struct CodecQ4Asymm : public CodecBase { + static constexpr int kWorldSize = world_size; - static int constexpr kAtoms = 8; - static int constexpr kAtomStride = 256; - static int constexpr kWorldSize = world_size; + static constexpr int kRankAtoms = kAtoms / kWorldSize; // Codec tile size process by this workgroup. - // Each threads processes a fragment of f16x8_t (16B), - // into a int4x8_t (4B) and a 2 f16 scale shared among 32 values. - static int constexpr kRankAtoms = kAtoms / kWorldSize; - - // 1024 + 128 + 128 - static int constexpr kRankTileStride = 1280; - static int constexpr kRankTileScaleOffset = 1024; - static int constexpr kRankTileZeroOffset = 1152; - static int constexpr kRankTileSize = kRankTileStride * kRankAtoms; - - static int constexpr kRankBufferTileStride = + // Each thread processes a fragment of 4xf16x2_t (16B) values, + // into a int4x8_t (4B), 2 zeros and 2 scales shared by thread group (4 * + // kThreadGroupSize) values. 
1024 + 128 + 128 + static constexpr int kRankTileStride = 1280; + static constexpr int kRankTileScaleOffset = 1024; + static constexpr int kRankTileZeroOffset = 1152; + static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms; + static_assert(kRankTransmittedTileSize % 16 == 0, + "kRankTransmittedTileSize must be 16B aligned."); + + static constexpr int kRankBufferTileStride = kRankTileStride / sizeof(int32x4_t); - // Total tile size for the collective communication. - static int constexpr kTileSize = kRankTileSize * kWorldSize; + // Total size of transmitted tile. + static constexpr int kTransmittedTileSize = + kRankTransmittedTileSize * kWorldSize; // Constants configuration // {-1/16.0h, -1/16.0h}, f16x2_t - static int constexpr kScaleFactor = + static constexpr int kScaleFactor = std::is_same::value ? 0xAC00AC00 : 0xBD80BD80; // {1e-7, 1e-7}, f16x2_t - static int constexpr kScaleEpsilon = + static constexpr int kScaleEpsilon = std::is_same::value ? 0x00010001 : 0x33D733D7; // {0, 0}, f16x2_t - static int constexpr kRangeMin = 0x00000000; + static constexpr int kRangeMin = 0x00000000; // {+15, +15}, f16x2_t - static int constexpr kRangeMax = + static constexpr int kRangeMax = std::is_same::value ? 
0x4B804B80 : 0x41704170; static unsigned char constexpr kMask0F = 0x0F; - - int const thread; - int const rank; - int const group_leader; - - __device_inline__ TwoshotMaxMinQ4LineCodec(int thread, int rank) - : thread(thread), rank(rank), group_leader((threadIdx.x / 8) * 8) { - static_assert(kRankTileSize % 16 == 0, - "kRankTileSize must be 16B aligned."); - set_fp16_ovfl(true); + __device_inline__ CodecQ4Asymm(int thread, int rank) + : CodecBase(thread, rank) { + // if constexpr (std::is_same::value) + // set_fp16_ovfl(true); } __device_inline__ void send(int32x4_t* __restrict__ send_buffer, @@ -449,33 +254,10 @@ struct TwoshotMaxMinQ4LineCodec { for (int k = 0; k < kRankAtoms; k++) { int32x4_t const atom = data[k]; - // max(w), min(w) - int wmax, wmin, wblockmax, wblockmin; - { - int a, b; - a = packed_max(atom[0], atom[1]); - b = packed_max(atom[2], atom[3]); - wmax = packed_max(a, b); - - a = packed_min(atom[0], atom[1]); - b = packed_min(atom[2], atom[3]); - wmin = packed_min(a, b); - - // Reduce the max and min among a group of 8 threads - // Note: This is basically 2 blocks of 32 values setup as the - // upper/lower halves of the fp16x2_t - for (int i = 1; i < 8; i <<= 1) { - int x = __shfl_down(wmax, i); - wmax = packed_max(wmax, x); - - int y = __shfl_down(wmin, i); - wmin = packed_min(wmin, y); - } - - // Share with the cohort - wblockmax = __shfl(wmax, group_leader); - wblockmin = __shfl(wmin, group_leader); - } + int wblockmax, wblockmin; + // Find max/min in thread group. + // In 2 blocks of values, upper/lower halves of the f16x2_t + group_max_min(atom, wblockmax, wblockmin); // Derive zeros and scales int decoding_zero = wblockmin; @@ -496,7 +278,7 @@ struct TwoshotMaxMinQ4LineCodec { w[i] = packed_min(w[i], kRangeMax); } - // Convert from f16x2_t to uint16x2_t + // Convert from f16 to a byte and pack into 4 bits. 
int32_t qw = 0; { unsigned char* qi = reinterpret_cast(&qw); @@ -555,73 +337,64 @@ struct TwoshotMaxMinQ4LineCodec { } } - // Apply decoding scales + // Apply decoding scales and zeros for (int i = 0; i < 4; i++) { w[i] = packed_mul(w[i], qs); w[i] = packed_add(w[i], qz); } - // That's pretty much it... data[k] = w; } } }; -// MARK: Q8 Line Codec +// Int8-blocking Line codec for Twoshot collectives. +// We quantize the FP16/BF16 data to block-scaled Int8 in blocks of 32. template -struct TwoshotQ8LineCodec { - /* - Int8-blocking Line codec for Twoshot collectives. - We quantize the FP16/BF16 data to block-scaled Int8 in blocks of 32. - */ - - static int constexpr kAtoms = 8; - static int constexpr kAtomStride = 256; - static int constexpr kWorldSize = world_size; +struct CodecQ8Symm : public CodecBase { + static constexpr int kWorldSize = world_size; // Codec tile size process by this workgroup. // Each threads processes a fragment of f16x8_t (16B), // into a int8x8_t (8B) and a f16 scale shared among 32 values. - static int constexpr kRankAtoms = kAtoms / kWorldSize; - static int constexpr kRankTileStride = 2176; - static int constexpr kRankTileScaleOffset = 2048; - static int constexpr kRankTileSize = kRankTileStride * kRankAtoms; - - static int constexpr kRankBufferTileStride = + static constexpr int kRankAtoms = kAtoms / kWorldSize; + static constexpr int kRankTileStride = 2176; + static constexpr int kRankTileScaleOffset = 2048; + static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms; + static_assert(kRankTransmittedTileSize % 16 == 0, + "kRankTransmittedTileSize must be 16B aligned."); + + static constexpr int kRankBufferTileStride = kRankTileStride / sizeof(int32x4_t); // Total tile size for the collective communication. 
- static int constexpr kTileSize = kRankTileSize * kWorldSize; + static constexpr int kTransmittedTileSize = + kRankTransmittedTileSize * kWorldSize; // Constants configuration // {-1/128.0h, -1/128.0h}, f16x2_t - static int constexpr kScaleFactor = + static constexpr int kScaleFactor = std::is_same::value ? 0xA000A000 : 0xBC00BC00; // {1e-7, 1e-7}, f16x2_t - static int constexpr kScaleEpsilon = + static constexpr int kScaleEpsilon = std::is_same::value ? 0x00010001 : 0x33D733D7; // {-128, -128}, f16x2_t - static int constexpr kRangeMin = + static constexpr int kRangeMin = std::is_same::value ? 0xD800D800 : 0xC300C300; // {+127, +127}, f16x2_t - static int constexpr kRangeMax = + static constexpr int kRangeMax = std::is_same::value ? 0x57F057F0 : 0x42FE42FE; // {+128, +128}, int16x2_t - static int constexpr kRangeBias = 0x00800080; + static constexpr int kRangeBias = 0x00800080; - int const thread; - int const rank; - int const group_leader; - - __device_inline__ TwoshotQ8LineCodec(int thread, int rank) - : thread(thread), rank(rank), group_leader((threadIdx.x / 8) * 8) { - static_assert(kRankTileSize % 16 == 0, - "kRankTileSize must be 16B aligned."); - set_fp16_ovfl(true); + __device_inline__ CodecQ8Symm(int thread, int rank) + : CodecBase(thread, rank) { + // if constexpr (std::is_same::value) + // set_fp16_ovfl(true); } __device_inline__ void send(int32x4_t* __restrict__ send_buffer, @@ -629,36 +402,8 @@ struct TwoshotQ8LineCodec { for (int k = 0; k < kRankAtoms; k++) { int32x4_t const atom = data[k]; - // max(w), min(w) - int wmax, wmin, wblockmax; - { - int a, b; - a = packed_max(atom[0], atom[1]); - b = packed_max(atom[2], atom[3]); - - wmax = packed_max(a, b); - - a = packed_min(atom[0], atom[1]); - b = packed_min(atom[2], atom[3]); - - wmin = packed_min(a, b); - - // Reduce the max among a group of 8 threads - // Note: This is basically 2 blocks of 32 values setup as the - // upper/lower halves of the fp16x2_t - for (int i = 1; i < 8; i <<= 1) { - int x = 
__shfl_down(wmax, i); - wmax = packed_max(wmax, x); - - int y = __shfl_down(wmin, i); - wmin = packed_min(wmin, y); - } - wblockmax = packed_abs_max(wmax, wmin); - - // Share with the cohort - wblockmax = __shfl(wblockmax, group_leader); - } - + // Find abs max in thread group + int wblockmax = group_abs_max(atom); // Derive scales int decoding_scale; int encoding_scale; @@ -692,15 +437,15 @@ struct TwoshotQ8LineCodec { qw[0] = q[0] | (q[1] << 8); qw[1] = q[2] | (q[3] << 8); - // Write quantized atom to send_buffer - // note: only the group leader stores the scale uint8_t* atom_ptr = reinterpret_cast(send_buffer + k * kRankBufferTileStride); int32x2_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + (thread / 8); + // Write quantized atom to send_buffer __builtin_nontemporal_store(qw, qw_ptr); + // note: only the group leader stores the scale if (threadIdx.x == group_leader) { __builtin_nontemporal_store(decoding_scale, qs_ptr); } @@ -708,7 +453,7 @@ struct TwoshotQ8LineCodec { } __device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, - int32x4_t* __restrict__ data) { + int32x4_t* __restrict__ output) { for (int k = 0; k < kRankAtoms; k++) { // Directly read quantized atom from recv_buffer uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); @@ -724,14 +469,14 @@ struct TwoshotQ8LineCodec { // Unpack q8 into fp16x8_t int32x4_t w; { - static uint constexpr kMask00FF = 0x00FF00FF; + static constexpr uint kMask00FF = 0x00FF00FF; // {1024.0, 1024.0}, f16x2_t - static uint constexpr kF162_1024 = + static constexpr uint kF162_1024 = std::is_same::value ? 0x64006400 : 0x44804480; // {-1152.0, -1152.0}, f16x2_t - static uint constexpr kF162_1152 = + static constexpr uint kF162_1152 = std::is_same::value ? 0xE480E480 : 0xC490C490; #pragma unroll @@ -745,102 +490,71 @@ struct TwoshotQ8LineCodec { for (int i = 0; i < 4; i++) { w[i] = packed_mul(w[i], qs); } - - // That's pretty much it... 
- data[k] = w; + output[k] = w; } } }; +// Int8-blocking Line codec for Twoshot collectives. +// We quantize the FP16 data to block-scaled Int8 in blocks of 32. template -struct TwoshotMaxMinQ8LineCodec { - /* - Int8-blocking Line codec for Twoshot collectives. - We quantize the FP16 data to block-scaled Int8 in blocks of 32. - */ - - static int constexpr kAtoms = 8; - static int constexpr kAtomStride = 256; - static int constexpr kWorldSize = world_size; +struct CodecQ8Asymm : public CodecBase { + static constexpr int kWorldSize = world_size; // Codec tile size process by this workgroup. // Each thread processes a fragment of fp16x8_t (16B), // into a int8x8_t (8B) and a fp16 zero and a fp16 scale shared among 32 // values. - static int constexpr kRankAtoms = kAtoms / kWorldSize; + static constexpr int kRankAtoms = kAtoms / kWorldSize; // 2048 + 128 + 128 - static int constexpr kRankTileStride = 2304; - static int constexpr kRankTileScaleOffset = 2048; - static int constexpr kRankTileZeroOffset = 2176; - static int constexpr kRankTileSize = kRankTileStride * kRankAtoms; - - static int constexpr kRankBufferTileStride = + static constexpr int kRankTileStride = 2304; + static constexpr int kRankTileScaleOffset = 2048; + static constexpr int kRankTileZeroOffset = 2176; + static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms; + static_assert(kRankTransmittedTileSize % 16 == 0, + "kRankTransmittedTileSize must be 16B aligned."); + + static constexpr int kRankBufferTileStride = kRankTileStride / sizeof(int32x4_t); // Total tile size for the collective communication. - static int constexpr kTileSize = kRankTileSize * kWorldSize; + static constexpr int kTransmittedTileSize = + kRankTransmittedTileSize * kWorldSize; // Constants configuration // {1/255.0h, 1/255.0h}, f16x2_t - static int constexpr kScaleFactor = + static constexpr int kScaleFactor = std::is_same::value ? 
0x1C041C04 : 0x3B813B81; // {1e-7, 1e-7}, fp16x2_t - static int constexpr kScaleEpsilon = + static constexpr int kScaleEpsilon = std::is_same::value ? 0x00010001 : 0x33D733D7; // {0, 0}, f16x2_t - static int constexpr kRangeMin = 0x00000000; + static constexpr int kRangeMin = 0x00000000; // {+255, +255}, f16x2_t - static int constexpr kRangeMax = + static constexpr int kRangeMax = std::is_same::value ? 0x5BF85BF8 : 0x437F437F; // {+128, +128}, int16x2_t - static int constexpr kRangeBias = 0x00800080; + static constexpr int kRangeBias = 0x00800080; - int const thread; - int const rank; - int const group_leader; - - __device_inline__ TwoshotMaxMinQ8LineCodec(int thread, int rank) - : thread(thread), rank(rank), group_leader((threadIdx.x / 8) * 8) { - static_assert(kRankTileSize % 16 == 0, - "kRankTileSize must be 16B aligned."); - set_fp16_ovfl(true); + __device_inline__ CodecQ8Asymm(int thread, int rank) + : CodecBase(thread, rank) { + // if constexpr (std::is_same::value) + // set_fp16_ovfl(true); } __device_inline__ void send(int32x4_t* __restrict__ send_buffer, - int32x4_t const* __restrict__ data) { + int32x4_t const* __restrict__ input) { for (int k = 0; k < kRankAtoms; k++) { - int32x4_t const atom = data[k]; - // max(w), min(w) - int wmax, wmin, wblockmax, wblockmin; - { - int a, b; - a = packed_max(atom[0], atom[1]); - b = packed_max(atom[2], atom[3]); - wmax = packed_max(a, b); - - a = packed_min(atom[0], atom[1]); - b = packed_min(atom[2], atom[3]); - wmin = packed_min(a, b); - - // Reduce the max among a group of 8 threads - // Note: This is basically 2 blocks of 32 values setup as the - // upper/lower halves of the fp16x2_t - for (int i = 1; i < 8; i <<= 1) { - int x = __shfl_down(wmax, i); - wmax = packed_max(wmax, x); - - int y = __shfl_down(wmin, i); - wmin = packed_min(wmin, y); - } + int32x4_t const atom = input[k]; - // Share with the cohort - wblockmax = __shfl(wmax, group_leader); - wblockmin = __shfl(wmin, group_leader); - } + int wblockmax, 
wblockmin; + // Find max/min in thread group. + // In 2 blocks of values, upper/lower halves of the f16x2_t + group_max_min(atom, wblockmax, wblockmin); // Derive zeros and scales int decoding_zero = wblockmin; @@ -927,44 +641,173 @@ struct TwoshotMaxMinQ8LineCodec { } }; -// MARK: Twoshot All Reduce -template -struct AllReduceTwoshot { - // Fixed magic implementation. - // We will use a workgroup of 256 threads (standard kBlock) across 8 atoms of - // work. - static int constexpr kAtoms = 8; +// Oneshot AllReduce +template +struct AllReduceOneshot { + static_assert(sizeof(T) == 2); + + __device__ static void run( + T const* __restrict__ input, // input + T* __restrict__ output, // output + int const N, // number of elements + int const block, // this block's index + int const num_blocks, // total number of blocks + int const world_size, // total number of ranks + int const rank, // this rank's index + uint8_t** __restrict__ buffer_list, // communication buffers + long const data_offset, // offset to start of the data buffer + int flag_color // Flag color for the network barrier + ) { + // Topology + int thread = threadIdx.x + threadIdx.y * kWavefront; + + long data_stride = num_blocks * kTileSize; + long flags_stride = num_blocks * sizeof(int); + + uint8_t* rank_buffer = buffer_list[rank]; + + // -------------------------------------------------------- + // Read input into registers + int32x4_t tA[kAtoms]; + + BufferResource src_buffer(const_cast(input), N * sizeof(T)); + + int src_offset = block * kTileSize + thread * sizeof(int32x4_t); + + for (int i = 0; i < kAtoms; i++) { + tA[i] = buffer_load_dwordx4(src_buffer.descriptor, src_offset, 0, 0); + src_offset += kAtomStride * sizeof(int32x4_t); + } + + // -------------------------------------------------------- + // Write rank data into this rank segment of every rank's communication + // buffer. 
+ long comm_data_offset = + data_offset + rank * data_stride + block * kTileSize; + long comm_flags_offset = rank * flags_stride + block * sizeof(int); + + if (thread < world_size) { + int r = thread; + int* flag_ptr = + reinterpret_cast(buffer_list[r] + comm_flags_offset); + wait_sync_flag(flag_ptr, flag_color - 1); + } + __syncthreads(); + + for (int r = 0; r < world_size; r++) { + int32x4_t* send_buffer = + reinterpret_cast(buffer_list[r] + comm_data_offset); + for (int i = 0; i < kAtoms; i++) { + __builtin_nontemporal_store(tA[i], send_buffer + thread); + send_buffer += kAtomStride; + } + } + + // Inform the other ranks that th data has been posted. + __syncthreads(); + if (thread < world_size) { + int r = thread; + int* flag_ptr = + reinterpret_cast(buffer_list[r] + comm_flags_offset); + set_sync_flag(flag_ptr, flag_color); + } + + // Read and reduce the data from this rank's communication buffer. + int32x4_t tB[kAtoms]; + + { + int r = 0; + + // Wait for the flags to be set. + int* flag_ptr = reinterpret_cast(rank_buffer + r * flags_stride + + block * sizeof(int)); + if (thread == 0) { + wait_sync_flag(flag_ptr, flag_color); + } + __syncthreads(); + + // Read posted data from the rank's communication buffer. + int32x4_t* recv_buffer = reinterpret_cast( + rank_buffer + data_offset + r * data_stride + block * kTileSize); + + for (int i = 0; i < kAtoms; i++) { + tB[i] = __builtin_nontemporal_load(recv_buffer + thread); + recv_buffer += kAtomStride; + } + } + + for (int r = 1; r < world_size; r++) { + // Wait for the flags to be set. + int* flag_ptr = reinterpret_cast(rank_buffer + r * flags_stride + + block * sizeof(int)); + if (thread == 0) { + wait_sync_flag(flag_ptr, flag_color); + } + __syncthreads(); + + // Read posted data from the rank's communication buffer. 
+ int32x4_t* recv_buffer = reinterpret_cast( + rank_buffer + data_offset + r * data_stride + block * kTileSize); - // Size and atom stride of source/destination data that the workgroup will - // process. - static int constexpr kTileSize = 256 * kAtoms * sizeof(int32x4_t); - static int constexpr kAtomStride = 256; + for (int i = 0; i < kAtoms; i++) { + tA[i] = __builtin_nontemporal_load(recv_buffer + thread); + recv_buffer += kAtomStride; + } - static int constexpr kWorldSize = LineCodec::kWorldSize; + // Reduce. + for (int i = 0; i < kAtoms; i++) { + packed_assign_add(&tB[i], &tA[i]); + } + } + + __syncthreads(); + if (thread < world_size) { + int r = thread; + int* flag_ptr = reinterpret_cast(rank_buffer + r * flags_stride + + block * sizeof(int)); + set_sync_flag(flag_ptr, flag_color); + } + + // Write the result to output. + BufferResource dst_buffer(output, N * sizeof(T)); + int dst_offset = block * kTileSize + thread * sizeof(int32x4_t); + + for (int i = 0; i < kAtoms; i++) { + buffer_store_dwordx4(tB[i], dst_buffer.descriptor, dst_offset, 0, 0); + dst_offset += kAtomStride * sizeof(int32x4_t); + } + } +}; + +// Twoshot All Reduce +template +struct AllReduceTwoshot { + static_assert(sizeof(T) == 2); + + static constexpr int kWorldSize = Codec::kWorldSize; __device__ static void run( - T const* __restrict__ A, // input - T* __restrict__ B, // output - int const N, // number of elements - int const block, // block index - int const num_blocks, // number of blocks - int const world_size, // unused - only kept around for API consistency - int const rank, // rank index + T const* __restrict__ input, T* __restrict__ output, + int const N, // number of elements + int const block, // block index + int const num_blocks, // number of blocks + int const world_size, // unused - only kept around for API consistency + int const rank, // rank index uint8_t** __restrict__ buffer_list, // communication buffers long const data_offset, // offset to start of the data buffer int 
flag_color) { // Topology int thread = threadIdx.x + threadIdx.y * kWavefront; uint8_t* rank_buffer = buffer_list[rank]; - LineCodec codec(thread, rank); + Codec codec(thread, rank); // -------------------------------------------------------- - // Read A into registers + // Read input into registers int32x4_t tA[kAtoms]; - BufferResource src_buffer(const_cast(A), N * sizeof(T)); + BufferResource src_buffer(const_cast(input), N * sizeof(T)); int src_offset = block * kTileSize + thread * sizeof(int32x4_t); - int32x4_t* src = reinterpret_cast(const_cast(A)); + int32x4_t* src = reinterpret_cast(const_cast(input)); for (int i = 0; i < kAtoms; i++) { tA[i] = buffer_load_dwordx4(src_buffer.descriptor, src_offset, 0, 0); @@ -974,18 +817,19 @@ struct AllReduceTwoshot { // -------------------------------------------------------- // Phase-1A: Write segment data into the communication buffer of the target // rank responsible for this segment. - long comm_data0_offset = data_offset + block * LineCodec::kTileSize; + long comm_data0_offset = data_offset + block * Codec::kTransmittedTileSize; long comm_data1_offset = - num_blocks * LineCodec::kTileSize + comm_data0_offset; + num_blocks * Codec::kTransmittedTileSize + comm_data0_offset; long comm_flags0_offset = block * (kWorldSize * sizeof(int)); long comm_flags1_offset = num_blocks * (kWorldSize * sizeof(int)) + comm_flags0_offset; for (int r = 0; r < kWorldSize; r++) { - int32x4_t* send_buffer = reinterpret_cast( - buffer_list[r] + comm_data0_offset + rank * LineCodec::kRankTileSize); - codec.send(send_buffer, &tA[r * LineCodec::kRankAtoms]); + int32x4_t* send_buffer = + reinterpret_cast(buffer_list[r] + comm_data0_offset + + rank * Codec::kRankTransmittedTileSize); + codec.send(send_buffer, &tA[r * Codec::kRankAtoms]); } __syncthreads(); @@ -993,11 +837,11 @@ struct AllReduceTwoshot { int r = thread; int* flag_ptr = reinterpret_cast( buffer_list[r] + comm_flags0_offset + rank * sizeof(int)); - __atomic_store_n(flag_ptr, 
flag_color, __ATOMIC_RELEASE); + set_sync_flag(flag_ptr, flag_color); } // -------------------------------------------------------- // Phase-1B: Reduce the segment data from the communication buffers. - int32x4_t tR[LineCodec::kRankAtoms] = {}; + int32x4_t tR[Codec::kRankAtoms] = {}; { // Read the data from the communication buffer. int32x4_t* recv_buffer = @@ -1007,27 +851,24 @@ struct AllReduceTwoshot { for (int r = 0; r < kWorldSize; r++) { // Wait for the flags to be set. if (thread == 0) { - while (__atomic_load_n(&flag_ptr[r], __ATOMIC_RELAXED) != - flag_color) { - } + wait_sync_flag(&flag_ptr[r], flag_color); } __syncthreads(); // note: we reuse tA as temp buffer here codec.recv(&recv_buffer, tA); - for (int i = 0; i < LineCodec::kRankAtoms; i++) { + for (int i = 0; i < Codec::kRankAtoms; i++) { packed_assign_add(&tR[i], &tA[i]); } } } - // -------------------------------------------------------- // Phase-2: Write the reduced segment to every other rank - // This is basically an all-gather. for (int r = 0; r < kWorldSize; r++) { - int32x4_t* send_buffer = reinterpret_cast( - buffer_list[r] + comm_data1_offset + rank * LineCodec::kRankTileSize); + int32x4_t* send_buffer = + reinterpret_cast(buffer_list[r] + comm_data1_offset + + rank * Codec::kRankTransmittedTileSize); codec.send(send_buffer, tR); } @@ -1036,10 +877,9 @@ struct AllReduceTwoshot { int r = thread; int* flag_ptr = reinterpret_cast( buffer_list[r] + comm_flags1_offset + rank * sizeof(int)); - __atomic_store_n(flag_ptr, flag_color, __ATOMIC_RELEASE); + set_sync_flag(flag_ptr, flag_color); } - // -------------------------------------------------------- // Phase-2: Read the gather segments from the rank's communication buffer. { // Read the data from the communication buffer. @@ -1050,22 +890,20 @@ struct AllReduceTwoshot { for (int r = 0; r < kWorldSize; r++) { // Wait for the flags to be set. 
if (thread == 0) { - while (__atomic_load_n(&flag_ptr[r], __ATOMIC_RELAXED) != - flag_color) { - } + wait_sync_flag(&flag_ptr[r], flag_color); } __syncthreads(); // Gather all reduced and final rank segments into tA. - codec.recv(&recv_buffer, &tA[r * LineCodec::kRankAtoms]); + codec.recv(&recv_buffer, &tA[r * Codec::kRankAtoms]); } } // -------------------------------------------------------- - // Write the result to B. - BufferResource dst_buffer(B, N * sizeof(T)); + // Write the result to output. + BufferResource dst_buffer(output, N * sizeof(T)); int dst_offset = block * kTileSize + thread * sizeof(int32x4_t); - int32x4_t* dst = reinterpret_cast(B); + int32x4_t* dst = reinterpret_cast(output); for (int i = 0; i < kAtoms; i++) { buffer_store_dwordx4(tA[i], dst_buffer.descriptor, dst_offset, 0, 0); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 98f51c9cd88e..f2ccdab0cf8f 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -726,6 +726,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { custom_ar.def("free_shared_buffer", &free_shared_buffer); #ifdef USE_ROCM + // Quick Reduce all-reduce kernels custom_ar.def( "qr_all_reduce(int fa, Tensor inp, Tensor out, int algo_int) -> ()"); custom_ar.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce); @@ -738,7 +739,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { custom_ar.def("qr_open_handles(int _fa, Tensor[](b!) handles) -> ()"); custom_ar.impl("qr_open_handles", torch::kCPU, &qr_open_handles); + // Max input size in bytes custom_ar.def("qr_max_size", &qr_max_size); + // Minimal input size in bytes. 
+ // For inputs size less than this value + // AllReduce collective is ineffective + custom_ar.def("qr_min_size", &qr_min_size); #endif } diff --git a/tests/distributed/test_quick_reduce.py b/tests/distributed/test_quick_reduce.py new file mode 100644 index 000000000000..f4c213981f4e --- /dev/null +++ b/tests/distributed/test_quick_reduce.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: Apache-2.0 + +import random + +import pytest +import ray +import torch +import torch.distributed as dist + +from vllm.distributed.communication_op import ( # noqa + tensor_model_parallel_all_reduce) +from vllm.distributed.device_communicators.quick_all_reduce import ( + QuickReduceAlgo) +from vllm.distributed.parallel_state import (get_tensor_model_parallel_group, + get_tp_group, graph_capture, + set_quick_reduce_algo) + +from ..utils import init_test_distributed_environment, multi_process_parallel + +random.seed(42) +test_sizes = [random.randint(256 * 8 * 4, 2048 * 1024) for _ in range(8)] +for i, v in enumerate(test_sizes): + test_sizes[i] -= v % 8 + + +# Same as in custom all-reduce +# Only enable QuickReduce +@ray.remote(num_gpus=1, max_calls=1) +def graph_allreduce( + monkeypatch: pytest.MonkeyPatch, + tp_size, + pp_size, + rank, + distributed_init_port, +): + with monkeypatch.context() as m: + m.delenv("CUDA_VISIBLE_DEVICES", raising=False) + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + set_quick_reduce_algo(QuickReduceAlgo.TwoShotFP) + init_test_distributed_environment(tp_size, pp_size, rank, + distributed_init_port) + + group = get_tensor_model_parallel_group().device_group + + # A small all_reduce for warmup. + # this is needed because device communicators might be created lazily + # (e.g. NCCL). This will ensure that the communicator is initialized + # before any communication happens, so that this group can be used for + # graph capture immediately. 
+ data = torch.zeros(1) + data = data.to(device=device) + torch.distributed.all_reduce(data, group=group) + torch.cuda.synchronize() + del data + + # we use the first group to communicate once + # and the second group to communicate twice + # and so on + # this is used to demonstrate that each group can + # communicate independently + num_communication = rank // tp_size + 1 + + for sz in test_sizes: + for dtype in [torch.float16, torch.bfloat16]: + with graph_capture(device=device) as graph_capture_context: + # use integers so result matches NCCL exactly + inp1 = torch.randint(1, + 16, (sz, ), + dtype=dtype, + device=torch.cuda.current_device()) + inp2 = torch.randint(1, + 16, (sz, ), + dtype=dtype, + device=torch.cuda.current_device()) + torch.cuda.synchronize() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, + stream=graph_capture_context.stream): + for i in range(num_communication): + out1 = tensor_model_parallel_all_reduce(inp1) + # the input buffer is immediately modified to test + # synchronization + dist.all_reduce(inp1, group=group) + out2 = tensor_model_parallel_all_reduce(inp2) + dist.all_reduce(inp2, group=group) + graph.replay() + torch.testing.assert_close(out1, inp1) + torch.testing.assert_close(out2, inp2) + + +@ray.remote(num_gpus=1, max_calls=1) +def eager_quick_allreduce( + monkeypatch: pytest.MonkeyPatch, + tp_size, + pp_size, + rank, + distributed_init_port, +): + with monkeypatch.context() as m: + m.delenv("CUDA_VISIBLE_DEVICES", raising=False) + set_quick_reduce_algo(QuickReduceAlgo.TwoShotFP) + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + init_test_distributed_environment(tp_size, pp_size, rank, + distributed_init_port) + for dtype in [torch.float16, torch.bfloat16]: + + num_communication = rank // tp_size + 1 + sz = 256 * 8 * 8 + qr_comm = get_tp_group().device_communicator.qr_comm + inp = torch.ones(sz, dtype=dtype, device=device) + out = inp + for _ in range(num_communication): + out = 
qr_comm.all_reduce(out) + torch.testing.assert_close(out, inp * (tp_size**num_communication)) + + +@pytest.mark.parametrize("tp_size", [2]) +@pytest.mark.parametrize("pipeline_parallel_size", [1, 2]) +@pytest.mark.parametrize("test_target", + [eager_quick_allreduce, graph_allreduce]) +def test_quick_reduce_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size, + pipeline_parallel_size, test_target): + world_size = tp_size * pipeline_parallel_size + if world_size > torch.cuda.device_count(): + pytest.skip("Not enough GPUs to run the test.") + multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, + test_target) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 6ef4624b2dcd..8e44e306886c 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1784,6 +1784,10 @@ def qr_max_size() -> int: return torch.ops._C_custom_ar.qr_max_size() +def qr_min_size() -> int: + return torch.ops._C_custom_ar.qr_min_size() + + def get_flash_mla_metadata( cache_seqlens: torch.Tensor, num_heads_per_head_k: int, diff --git a/vllm/config.py b/vllm/config.py index d986ab6b0edb..49eb71ad7f91 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -33,6 +33,8 @@ import vllm.envs as envs from vllm import version from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass +from vllm.distributed.device_communicators.quick_all_reduce import ( + QuickReduceAlgo) from vllm.logger import init_logger from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS, QuantizationMethods, @@ -1770,6 +1772,13 @@ class ParallelConfig: disable_custom_all_reduce: bool = False """Disable the custom all-reduce kernel and fall back to NCCL.""" + quick_reduce_allreduce_algo: Optional[QuickReduceAlgo] = None + """Enable alternative to custom all-reduce. + Supports asymmetric and symmetric quantization (4 and 8 bits) + for 2 Shot algorithm. Only supported on AMD, + for bf16 and fp16 input data types. 
+ """ + tokenizer_pool_config: Optional[TokenizerPoolConfig] = None """This parameter is deprecated and will be removed in a future release. Please remove it from your configs""" diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 689b1262b2ca..55e9c68805c6 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -2,12 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import logging -import os from typing import Optional import torch from torch.distributed import ProcessGroup +import vllm.envs as envs from vllm.platforms import current_platform import vllm.envs as envs @@ -29,10 +29,12 @@ def __init__(self, if "tp" not in unique_name: # only tp uses custom allreduce use_custom_allreduce = False + quick_reduce_algo = None else: from vllm.distributed.parallel_state import ( - _ENABLE_CUSTOM_ALL_REDUCE) + _ENABLE_CUSTOM_ALL_REDUCE, _QUICK_REDUCE_ALGO) use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE + quick_reduce_algo = _QUICK_REDUCE_ALGO # ep does not use pynccl use_pynccl = "ep" not in unique_name @@ -46,7 +48,7 @@ def __init__(self, from vllm.distributed.device_communicators.pynccl import ( PyNcclCommunicator) from vllm.distributed.device_communicators.quick_all_reduce import ( - QuickAllReduce, QuickReduceAlgo) + QuickAllReduce) self.pynccl_comm: Optional[PyNcclCommunicator] = None if use_pynccl and self.world_size > 1: @@ -62,26 +64,21 @@ def __init__(self, group=self.cpu_group, device=self.device, ) - self.use_quick_allreduce = os.environ.get("VLLM_USE_QUICK_ALLREDUCE", - "0") == "1" - if self.use_quick_allreduce and not current_platform.is_rocm(): + self.quick_reduce_comm_algo = quick_reduce_algo + if (self.quick_reduce_comm_algo is not None + and not current_platform.is_rocm()): logger.warning( - "Environment variable VLLM_USE_QUICK_ALLREDUCE is set to 1," \ - " 
but QuickReduce is only supported on ROCm platform." - ) + "quick_reduce_comm_algo is not None," + " but QuickReduce is only supported on ROCm platform.") self.qr_comm: Optional[QuickAllReduce] = None - if self.use_quick_allreduce and self.world_size > 1 and \ - current_platform.is_rocm(): - # Initialize a custom fast all-reduce implementation. - qr_comm_algo = os.environ.get("VLLM_QUICK_ALLREDUCE_ALGO", - "TwoShot") - assert qr_comm_algo in QuickReduceAlgo.__members__, \ - "Unknown QuickReduce algorithm: {}".format( - qr_comm_algo) - self.qr_comm_algo = QuickReduceAlgo[qr_comm_algo] + + if (self.quick_reduce_comm_algo is not None + and current_platform.is_rocm() and self.world_size > 1): + # Initialize a custom fast all-reduce implementation + # based on quick reduce (https://github.com/mk1-project/quickreduce). self.qr_comm = QuickAllReduce(group=self.cpu_group, device=self.device, - algo=self.qr_comm_algo) + algo=self.quick_reduce_comm_algo) if self.use_all2all: all2all_backend = envs.VLLM_ALL2ALL_BACKEND diff --git a/vllm/distributed/device_communicators/quick_all_reduce.py b/vllm/distributed/device_communicators/quick_all_reduce.py index 70828b461ddb..686444eb4237 100644 --- a/vllm/distributed/device_communicators/quick_all_reduce.py +++ b/vllm/distributed/device_communicators/quick_all_reduce.py @@ -21,12 +21,12 @@ class QuickReduceAlgo(Enum): - OneShot = 0 - TwoShot = 1 - TwoShot_Q8 = 2 - TwoShot_Q4 = 3 - TwoShot_MAX_MIN_Q8 = 4 - TwoShot_MAX_MIN_Q4 = 5 + OneShotFP = 0 + TwoShotFP = 1 + TwoShotQ8SYMM = 2 + TwoShotQ4SYMM = 3 + TwoShotQ8ASYMM = 4 + TwoShotQ4ASYMM = 5 class QuickAllReduce: @@ -36,7 +36,7 @@ class QuickAllReduce: def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device], - algo: QuickReduceAlgo = QuickReduceAlgo.TwoShot) -> None: + algo: QuickReduceAlgo = QuickReduceAlgo.TwoShotFP) -> None: self.disabled = True if not ops_available: # disable because of missing quick reduce library @@ -46,8 +46,14 @@ def __init__(self, return 
self.max_size = ops.qr_max_size() + self.min_size = ops.qr_min_size() self.group = group - self.algo = algo + if isinstance(algo, str): + assert algo in QuickReduceAlgo.__members__, \ + "Algo {} is not supported by QuickReduce".format(algo) + self.algo = QuickReduceAlgo[algo] + else: + self.algo = algo assert dist.get_backend(group) != dist.Backend.NCCL, ( "QuickReduce should be attached to a non-NCCL group.") @@ -60,14 +66,14 @@ def __init__(self, if world_size not in QuickAllReduce._SUPPORTED_WORLD_SIZES: logger.warning( "QuickReduce allreduce is disabled due to an unsupported world" - " size: %d. Supported world sizes: %s. To silence this " - "warning, specify disable_custom_all_reduce=True explicitly.", - world_size, str(QuickAllReduce._SUPPORTED_WORLD_SIZES)) + " size: %d. Supported world sizes: %s." + " To disable this warning set quick_reduce_allreduce_algo" + " to None", world_size, + str(QuickAllReduce._SUPPORTED_WORLD_SIZES)) return assert current_platform.is_rocm(), ( "QuickReduce is only supported on ROCm platform.") - if isinstance(device, int): device = torch.device(f"cuda:{device}") elif isinstance(device, str): @@ -121,4 +127,4 @@ def should_quick_allreduce(self, inp: torch.Tensor): if inp_size % 16 != 0: return False return inp.dtype in QuickAllReduce._SUPPORTED_DTYPES and \ - inp_size < self.max_size + inp_size < self.max_size and inp_size >= self.min_size diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 126160b09553..c9af2aad722b 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -40,6 +40,8 @@ import vllm.envs as envs from vllm.distributed.device_communicators.base_device_communicator import ( DeviceCommunicatorBase) +from vllm.distributed.device_communicators.quick_all_reduce import ( + QuickReduceAlgo) from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger from vllm.utils import (direct_register_custom_op, 
get_distributed_init_method, @@ -908,6 +910,14 @@ def set_custom_all_reduce(enable: bool): _ENABLE_CUSTOM_ALL_REDUCE = enable +_QUICK_REDUCE_ALGO: Optional[QuickReduceAlgo] = None + + +def set_quick_reduce_algo(algo: Optional[QuickReduceAlgo]): + global _QUICK_REDUCE_ALGO + _QUICK_REDUCE_ALGO = algo + + def init_distributed_environment( world_size: int = -1, rank: int = -1, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f599d7a3bb5e..0a0c699f7713 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -31,6 +31,8 @@ SchedulerConfig, SchedulerPolicy, SpeculativeConfig, TaskOption, TokenizerMode, TokenizerPoolConfig, VllmConfig, get_attr_docs, get_field) +from vllm.distributed.device_communicators.quick_all_reduce import ( + QuickReduceAlgo) from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationMethods @@ -339,6 +341,8 @@ class EngineArgs: enforce_eager: bool = ModelConfig.enforce_eager max_seq_len_to_capture: int = ModelConfig.max_seq_len_to_capture disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce + quick_reduce_allreduce_algo: Optional[ + QuickReduceAlgo] = ParallelConfig.quick_reduce_allreduce_algo # The following three fields are deprecated and will be removed in a future # release. Setting them will have no effect. Please remove them from your # configurations. 
@@ -662,6 +666,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parallel_group.add_argument( "--disable-custom-all-reduce", **parallel_kwargs["disable_custom_all_reduce"]) + parallel_group.add_argument( + "--quick-reduce-allreduce-algo", + **parallel_kwargs["quick_reduce_allreduce_algo"]) parallel_group.add_argument("--worker-cls", **parallel_kwargs["worker_cls"]) parallel_group.add_argument("--worker-extension-cls", @@ -1123,6 +1130,7 @@ def create_engine_config( enable_expert_parallel=self.enable_expert_parallel, max_parallel_loading_workers=self.max_parallel_loading_workers, disable_custom_all_reduce=self.disable_custom_all_reduce, + quick_reduce_allreduce_algo=self.quick_reduce_allreduce_algo, ray_workers_use_nsight=self.ray_workers_use_nsight, placement_group=placement_group, distributed_executor_backend=self.distributed_executor_backend, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index c11e627ee236..9fe28cda4372 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -18,6 +18,8 @@ BeamSearchSequence, get_beam_search_score) from vllm.config import (CompilationConfig, ModelDType, TokenizerMode, is_init_field) +from vllm.distributed.device_communicators.quick_all_reduce import ( + QuickReduceAlgo) from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig, TaskOption) from vllm.engine.llm_engine import LLMEngine @@ -175,6 +177,7 @@ def __init__( enforce_eager: bool = False, max_seq_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, + quick_reduce_allreduce_algo: Optional[QuickReduceAlgo] = None, disable_async_output_proc: bool = False, hf_token: Optional[Union[bool, str]] = None, hf_overrides: Optional[HfOverrides] = None, @@ -249,6 +252,7 @@ def __init__( enforce_eager=enforce_eager, max_seq_len_to_capture=max_seq_len_to_capture, disable_custom_all_reduce=disable_custom_all_reduce, + quick_reduce_allreduce_algo=quick_reduce_allreduce_algo, 
disable_async_output_proc=disable_async_output_proc, hf_token=hf_token, hf_overrides=hf_overrides, diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 58795e3fe292..594c203803a6 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -14,7 +14,7 @@ from vllm.device_allocator.cumem import CuMemAllocator from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, - set_custom_all_reduce) + set_custom_all_reduce, set_quick_reduce_algo) from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.logger import init_logger @@ -369,6 +369,8 @@ def init_worker_distributed_environment( parallel_config = vllm_config.parallel_config set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) + set_quick_reduce_algo(parallel_config.quick_reduce_allreduce_algo) + init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank, backend) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 9a928632688a..68bcc8fc9376 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -13,7 +13,7 @@ from vllm.device_allocator.cumem import CuMemAllocator from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, - set_custom_all_reduce) + set_custom_all_reduce, set_quick_reduce_algo) from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -528,6 +528,7 @@ def init_worker_distributed_environment( """Initialize the distributed environment.""" parallel_config = vllm_config.parallel_config set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) + set_quick_reduce_algo(parallel_config.quick_reduce_allreduce_algo) init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank) From 
dbdbd2d77a0a78317695a96417ddd6b3927e1363 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Tue, 20 May 2025 13:37:57 +0000 Subject: [PATCH 06/28] Cleanup Signed-off-by: ilmarkov --- csrc/custom_quickreduce.cu | 8 +- csrc/quickreduce/base.h | 109 ++++++++++-------- csrc/quickreduce/quick_reduce.h | 36 ++---- csrc/quickreduce/quick_reduce_impl.cuh | 53 ++++----- .../device_communicators/quick_all_reduce.py | 8 +- 5 files changed, 106 insertions(+), 108 deletions(-) diff --git a/csrc/custom_quickreduce.cu b/csrc/custom_quickreduce.cu index 75e226f0328a..b3c8f9ffee1e 100644 --- a/csrc/custom_quickreduce.cu +++ b/csrc/custom_quickreduce.cu @@ -66,9 +66,11 @@ void qr_all_reduce(quickreduce::fptr_t _fa, torch::Tensor& inp, reinterpret_cast(inp.data_ptr()), reinterpret_cast(out.data_ptr()), out.numel()); } else if (out.scalar_type() == at::ScalarType::BFloat16) { - fa->allreduce( - algo_int, stream, reinterpret_cast(inp.data_ptr()), - reinterpret_cast(out.data_ptr()), out.numel()); + fa->allreduce( + algo_int, stream, + reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + out.numel()); } else { throw std::runtime_error( "quick allreduce only supports float16 and bfloat16"); diff --git a/csrc/quickreduce/base.h b/csrc/quickreduce/base.h index 749cb1af0e69..a54b1178a956 100644 --- a/csrc/quickreduce/base.h +++ b/csrc/quickreduce/base.h @@ -5,12 +5,13 @@ #include #include -typedef __hip_bfloat16 nv_bfloat16; -typedef __hip_bfloat162 nv_bfloat162; +#define __quickreduce_device_inline__ __device__ __forceinline__ +#define __quickreduce_launch_bounds__ __launch_bounds__(256, 4) + namespace quickreduce { -#define __device_inline__ __device__ __forceinline__ -#define __quickreduce_launch_bounds__ __launch_bounds__(256, 4) +typedef __hip_bfloat16 nv_bfloat16; +typedef __hip_bfloat162 nv_bfloat162; // Setup acquire-release semantics for vector memory reads (mubuf instruction) // as per architecture. 
@@ -57,27 +58,31 @@ static constexpr int kAtomStride = kBlockSize; // process. static constexpr int kTileSize = kBlockSize * kAtoms * sizeof(int32x4_t); +// Max number of blocks. 304 CUs on MI300 +static constexpr int kMaxNumBlocks = 304 * 4; + // Standard CDNA wavefront size. static constexpr int kWavefront = 64; // 256 thread, 4 wavefronts. -static dim3 constexpr kBlock = {64, 4, 1}; +static dim3 constexpr kBlock = {kWavefront, kBlockSize / kWavefront, 1}; // Number of threads in a group for quantization // It corresponds to 32 F16 elements in quantization block static constexpr int kThreadGroupSize = 8; // Methods -__device_inline__ __host__ unsigned long divceil(unsigned long x, - unsigned long y) { +__quickreduce_device_inline__ __host__ unsigned long divceil(unsigned long x, + unsigned long y) { return ((x + y - 1) / y); } union BufferResource { - __device_inline__ constexpr BufferResource() : config(0x00020000U) {} + __quickreduce_device_inline__ constexpr BufferResource() + : config(0x00020000U) {} - __device_inline__ constexpr BufferResource(void* buffer_address, - uint32_t buffer_size) + __quickreduce_device_inline__ constexpr BufferResource(void* buffer_address, + uint32_t buffer_size) : address(buffer_address), range(buffer_size), config(0x00020000U) {} int32x4_t descriptor; @@ -89,18 +94,18 @@ union BufferResource { }; }; -__device_inline__ static int32x4_t buffer_load_dwordx4( +__quickreduce_device_inline__ static int32x4_t buffer_load_dwordx4( int32x4_t srsrc, int32_t voffset, int32_t soffset, int32_t aux) __asm("llvm.amdgcn.raw.buffer.load.v4i32"); -__device_inline__ static void buffer_store_dwordx4( +__quickreduce_device_inline__ static void buffer_store_dwordx4( int32x4_t data, int32x4_t srsrc, int32_t voffset, int32_t soffset, int32_t aux) __asm("llvm.amdgcn.raw.buffer.store.v4i32"); -// Setting fp16 flag does not seem to have an effect for gfx942 +// NOTE: Setting fp16 flag does not seem to have an effect for gfx942 // The register offset has to 
be validated // So we don't use it in Codecs for now. -__device_inline__ static void set_fp16_ovfl(bool const value) { +__quickreduce_device_inline__ static void set_fp16_ovfl(bool const value) { // short size = 0b00001; // Specifies the bit size to modify // const short offset = 0b10111; // Corrected offset to 23, which is the bit // position of FP16_OVFL const short hwRegId = 0b000001; // HW register ID for @@ -117,10 +122,12 @@ __device_inline__ static void set_fp16_ovfl(bool const value) { } template -__device_inline__ void packed_assign_add(int32x4_t* A, int32x4_t* B); +__quickreduce_device_inline__ void packed_assign_add(int32x4_t* A, + int32x4_t* B); template <> -__device_inline__ void packed_assign_add(int32x4_t* A, int32x4_t* B) { +__quickreduce_device_inline__ void packed_assign_add(int32x4_t* A, + int32x4_t* B) { int32x4_t& tR_fragment = A[0]; int32x4_t& tA_fragment = B[0]; @@ -139,8 +146,8 @@ __device_inline__ void packed_assign_add(int32x4_t* A, int32x4_t* B) { } template <> -__device_inline__ void packed_assign_add(int32x4_t* A, - int32x4_t* B) { +__quickreduce_device_inline__ void packed_assign_add( + int32x4_t* A, int32x4_t* B) { nv_bfloat162* tA = reinterpret_cast(A); nv_bfloat162* tB = reinterpret_cast(B); #pragma unroll @@ -150,17 +157,17 @@ __device_inline__ void packed_assign_add(int32x4_t* A, } template -__device_inline__ int packed_max(int a, int b); +__quickreduce_device_inline__ int packed_max(int a, int b); template <> -__device_inline__ int packed_max(int a, int b) { +__quickreduce_device_inline__ int packed_max(int a, int b) { int result; asm volatile("v_pk_max_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b)); return result; } template <> -__device_inline__ int packed_max(int a, int b) { +__quickreduce_device_inline__ int packed_max(int a, int b) { nv_bfloat162* tA = reinterpret_cast(&a); nv_bfloat162* tB = reinterpret_cast(&b); nv_bfloat162 tR = __hmax2(*tA, *tB); @@ -168,17 +175,17 @@ __device_inline__ int packed_max(int a, int b) { } 
template -__device_inline__ int packed_min(int a, int b); +__quickreduce_device_inline__ int packed_min(int a, int b); template <> -__device_inline__ int packed_min(int a, int b) { +__quickreduce_device_inline__ int packed_min(int a, int b) { int result; asm volatile("v_pk_min_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b)); return result; } template <> -__device_inline__ int packed_min(int a, int b) { +__quickreduce_device_inline__ int packed_min(int a, int b) { nv_bfloat162* tA = reinterpret_cast(&a); nv_bfloat162* tB = reinterpret_cast(&b); nv_bfloat162 tR = __hmin2(*tA, *tB); @@ -186,10 +193,10 @@ __device_inline__ int packed_min(int a, int b) { } template -__device_inline__ int packed_abs_max(int a, int b); +__quickreduce_device_inline__ int packed_abs_max(int a, int b); template <> -__device_inline__ int packed_abs_max(int a, int b) { +__quickreduce_device_inline__ int packed_abs_max(int a, int b) { half2 wmaxh2 = __builtin_bit_cast(half2, a); half2 wminh2 = __builtin_bit_cast(half2, b); half2 wblockmaxh2; @@ -202,7 +209,7 @@ __device_inline__ int packed_abs_max(int a, int b) { } template <> -__device_inline__ int packed_abs_max(int a, int b) { +__quickreduce_device_inline__ int packed_abs_max(int a, int b) { nv_bfloat162 wmaxh2 = *(reinterpret_cast(&a)); nv_bfloat162 wminh2 = *(reinterpret_cast(&b)); nv_bfloat162 wblockmaxh2; @@ -215,17 +222,17 @@ __device_inline__ int packed_abs_max(int a, int b) { } template -__device_inline__ int packed_add(int a, int b); +__quickreduce_device_inline__ int packed_add(int a, int b); template <> -__device_inline__ int packed_add(int a, int b) { +__quickreduce_device_inline__ int packed_add(int a, int b) { int result; asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b)); return result; } template <> -__device_inline__ int packed_add(int a, int b) { +__quickreduce_device_inline__ int packed_add(int a, int b) { nv_bfloat162* tA = reinterpret_cast(&a); nv_bfloat162* tB = reinterpret_cast(&b); nv_bfloat162 tR 
= __hadd2(*tA, *tB); @@ -233,17 +240,17 @@ __device_inline__ int packed_add(int a, int b) { } template <> -__device_inline__ int packed_add(int a, int b) { +__quickreduce_device_inline__ int packed_add(int a, int b) { int result; asm volatile("v_pk_add_i16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b)); return result; } template -__device_inline__ int packed_sub(int a, int b); +__quickreduce_device_inline__ int packed_sub(int a, int b); template <> -__device_inline__ int packed_sub(int a, int b) { +__quickreduce_device_inline__ int packed_sub(int a, int b) { int result; // MI300 lacks packed fp16 sub instruction. So we do -1 * min + max @@ -254,7 +261,7 @@ __device_inline__ int packed_sub(int a, int b) { } template <> -__device_inline__ int packed_sub(int a, int b) { +__quickreduce_device_inline__ int packed_sub(int a, int b) { nv_bfloat162* tA = reinterpret_cast(&a); nv_bfloat162* tB = reinterpret_cast(&b); nv_bfloat162 tR = __hsub2(*tA, *tB); @@ -262,17 +269,17 @@ __device_inline__ int packed_sub(int a, int b) { } template -__device_inline__ int packed_mul(int a, int b); +__quickreduce_device_inline__ int packed_mul(int a, int b); template <> -__device_inline__ int packed_mul(int a, int b) { +__quickreduce_device_inline__ int packed_mul(int a, int b) { int result; asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b)); return result; } template <> -__device_inline__ int packed_mul(int a, int b) { +__quickreduce_device_inline__ int packed_mul(int a, int b) { nv_bfloat162* tA = reinterpret_cast(&a); nv_bfloat162* tB = reinterpret_cast(&b); nv_bfloat162 tR = __hmul2(*tA, *tB); @@ -280,48 +287,48 @@ __device_inline__ int packed_mul(int a, int b) { } template -__device_inline__ int packed_rcp(int a); +__quickreduce_device_inline__ int packed_rcp(int a); template <> -__device_inline__ int packed_rcp(int a) { +__quickreduce_device_inline__ int packed_rcp(int a) { return __builtin_bit_cast(int, h2rcp(__builtin_bit_cast(half2, a))); } template <> 
-__device_inline__ int packed_rcp(int a) { +__quickreduce_device_inline__ int packed_rcp(int a) { nv_bfloat162* tA = reinterpret_cast(&a); nv_bfloat162 tR = h2rcp(*tA); return *(reinterpret_cast(&tR)); } template -__device_inline__ T float2T_cast(float a); +__quickreduce_device_inline__ T float2T_cast(float a); template <> -__device_inline__ half float2T_cast(float a) { +__quickreduce_device_inline__ half float2T_cast(float a) { return __float2half(a); } template <> -__device_inline__ nv_bfloat16 float2T_cast(float a) { +__quickreduce_device_inline__ nv_bfloat16 float2T_cast(float a) { return __float2bfloat16(a); } template -__device_inline__ float T2float_cast(T a); +__quickreduce_device_inline__ float T2float_cast(T a); template <> -__device_inline__ float T2float_cast(half a) { +__quickreduce_device_inline__ float T2float_cast(half a) { return __half2float(a); } template <> -__device_inline__ float T2float_cast(nv_bfloat16 a) { +__quickreduce_device_inline__ float T2float_cast(nv_bfloat16 a) { return __bfloat162float(a); } template -__device_inline__ int group_abs_max(int32x4_t atom) { +__quickreduce_device_inline__ int group_abs_max(int32x4_t atom) { const int group_leader = (threadIdx.x / kThreadGroupSize) * kThreadGroupSize; int wmax, wmin, wblockmax; @@ -353,8 +360,8 @@ __device_inline__ int group_abs_max(int32x4_t atom) { } template -__device_inline__ void group_max_min(int32x4_t atom, int& wblockmax, - int& wblockmin) { +__quickreduce_device_inline__ void group_max_min(int32x4_t atom, int& wblockmax, + int& wblockmin) { const int group_leader = (threadIdx.x / kThreadGroupSize) * kThreadGroupSize; int wmax, wmin; @@ -383,11 +390,11 @@ __device_inline__ void group_max_min(int32x4_t atom, int& wblockmax, wblockmin = __shfl(wmin, group_leader); } -__device_inline__ void set_sync_flag(int* flag_ptr, int flag) { +__quickreduce_device_inline__ void set_sync_flag(int* flag_ptr, int flag) { __atomic_store_n(flag_ptr, flag, __ATOMIC_RELEASE); } -__device_inline__ 
void wait_sync_flag(int* flag_ptr, int flag) { +__quickreduce_device_inline__ void wait_sync_flag(int* flag_ptr, int flag) { while (__atomic_load_n(flag_ptr, __ATOMIC_RELAXED) != flag) { } } diff --git a/csrc/quickreduce/quick_reduce.h b/csrc/quickreduce/quick_reduce.h index 178631dac7bc..f58a1b88991c 100644 --- a/csrc/quickreduce/quick_reduce.h +++ b/csrc/quickreduce/quick_reduce.h @@ -21,15 +21,12 @@ static_assert(sizeof(void*) == sizeof(fptr_t)); enum QuickReduceAlgo { ONESHOT_FP16 = 0, TWOSHOT_FP16 = 1, - TWOSHOT_Q8 = 2, - TWOSHOT_Q4 = 3, - TWOSHOT_MAX_MIN_Q8 = 4, - TWOSHOT_MAX_MIN_Q4 = 5, + TWOSHOT_SYMM_Q8 = 2, + TWOSHOT_SYMM_Q4 = 3, + TWOSHOT_ASYMM_Q8 = 4, + TWOSHOT_ASYMM_Q4 = 5, }; -// ============================================================ -// KERNEL -// ============================================================ template __global__ __quickreduce_launch_bounds__ static void allreduce_prototype( T const* A, T* B, int N, int num_blocks, int world_size, int rank, @@ -44,9 +41,6 @@ __global__ __quickreduce_launch_bounds__ static void allreduce_prototype( } } -// ============================================================ -// DISPATCH -// ============================================================ #define TWOSHOT_DISPATCH(__codec) \ if (world_size == 2) { \ using LineCodec = __codec; \ @@ -71,11 +65,6 @@ __global__ __quickreduce_launch_bounds__ static void allreduce_prototype( flag_color); \ } -/* -=============================================================== -Desc: - Device Comms Handle -*/ struct DeviceComms { // Workgroup scope = Tile = (256 threads x 16B x 8 atoms) static long constexpr kTileSize = 256 * 16 * 8; @@ -172,7 +161,7 @@ struct DeviceComms { } template - void allreduce(int profile, hipStream_t stream, T const* A, T* B, int N) { + void allreduce(int algo_int, hipStream_t stream, T const* A, T* B, int N) { if (world_size != 2 && world_size != 4 && world_size != 8) { throw std::runtime_error("All Reduce not supported for world_size = " + 
std::to_string(world_size)); @@ -181,10 +170,10 @@ struct DeviceComms { // Configuration. long msg_size = N * sizeof(T); unsigned long num_blocks = divceil(msg_size, kTileSize); - unsigned long grid = min(304 * 4, num_blocks); - // ------------------------------------------------- + unsigned long grid = min(kMaxNumBlocks, num_blocks); + // All reduce dispatch. - QuickReduceAlgo algo = static_cast(profile); + QuickReduceAlgo algo = static_cast(algo_int); switch (algo) { case QuickReduceAlgo::ONESHOT_FP16: @@ -194,16 +183,16 @@ struct DeviceComms { num_blocks, world_size, rank, dbuffer_list, data_offset, flag_color); break; - case QuickReduceAlgo::TWOSHOT_Q8: + case QuickReduceAlgo::TWOSHOT_SYMM_Q8: TWOSHOT_DISPATCH(CodecQ8Symm) break; - case QuickReduceAlgo::TWOSHOT_MAX_MIN_Q8: + case QuickReduceAlgo::TWOSHOT_ASYMM_Q8: TWOSHOT_DISPATCH(CodecQ8Asymm) break; - case QuickReduceAlgo::TWOSHOT_Q4: + case QuickReduceAlgo::TWOSHOT_SYMM_Q4: TWOSHOT_DISPATCH(CodecQ4Symm) break; - case QuickReduceAlgo::TWOSHOT_MAX_MIN_Q4: + case QuickReduceAlgo::TWOSHOT_ASYMM_Q4: TWOSHOT_DISPATCH(CodecQ4Asymm) break; default: @@ -212,7 +201,6 @@ struct DeviceComms { } HIP_CHECK(cudaGetLastError()); - // ------------------------------------------------- // Rotate the flag color. 
flag_color++; } diff --git a/csrc/quickreduce/quick_reduce_impl.cuh b/csrc/quickreduce/quick_reduce_impl.cuh index 54f5d209a507..c546459d9490 100644 --- a/csrc/quickreduce/quick_reduce_impl.cuh +++ b/csrc/quickreduce/quick_reduce_impl.cuh @@ -9,7 +9,7 @@ struct CodecBase { const int thread; const int rank; const int group_leader; - __device_inline__ CodecBase(int thread, int rank) + __quickreduce_device_inline__ CodecBase(int thread, int rank) : thread(thread), rank(rank), group_leader((threadIdx.x / kThreadGroupSize) * kThreadGroupSize) {} @@ -32,18 +32,19 @@ struct CodecFP16 : public CodecBase { static constexpr int kTransmittedTileSize = kRankTransmittedTileSize * kWorldSize; - __device_inline__ CodecFP16(int thread, int rank) : CodecBase(thread, rank) {} + __quickreduce_device_inline__ CodecFP16(int thread, int rank) + : CodecBase(thread, rank) {} - __device_inline__ void send(int32x4_t* __restrict__ send_buffer, - int32x4_t const* __restrict__ data) { + __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, + int32x4_t const* __restrict__ data) { for (int i = 0; i < kRankAtoms; i++) { __builtin_nontemporal_store(data[i], send_buffer + thread); send_buffer += kAtomStride; } } - __device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, - int32x4_t* __restrict__ data) { + __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, + int32x4_t* __restrict__ data) { for (int i = 0; i < kRankAtoms; i++) { data[i] = __builtin_nontemporal_load(*recv_buffer + thread); *recv_buffer += kAtomStride; @@ -96,14 +97,14 @@ struct CodecQ4Symm : public CodecBase { // {+8, +8}, int16x2_t static constexpr int kRangeBias = 0x00080008; - __device_inline__ CodecQ4Symm(int thread, int rank) + __quickreduce_device_inline__ CodecQ4Symm(int thread, int rank) : CodecBase(thread, rank) { // if constexpr (std::is_same::value) // set_fp16_ovfl(true); } - __device_inline__ void send(int32x4_t* __restrict__ send_buffer, - int32x4_t const* 
__restrict__ data) { + __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, + int32x4_t const* __restrict__ data) { for (int k = 0; k < kRankAtoms; k++) { int32x4_t const atom = data[k]; @@ -157,8 +158,8 @@ struct CodecQ4Symm : public CodecBase { } } - __device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, - int32x4_t* __restrict__ data) { + __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, + int32x4_t* __restrict__ data) { for (int k = 0; k < kRankAtoms; k++) { // Directly read quantized atom from recv_buffer uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); @@ -243,14 +244,14 @@ struct CodecQ4Asymm : public CodecBase { std::is_same::value ? 0x4B804B80 : 0x41704170; static unsigned char constexpr kMask0F = 0x0F; - __device_inline__ CodecQ4Asymm(int thread, int rank) + __quickreduce_device_inline__ CodecQ4Asymm(int thread, int rank) : CodecBase(thread, rank) { // if constexpr (std::is_same::value) // set_fp16_ovfl(true); } - __device_inline__ void send(int32x4_t* __restrict__ send_buffer, - int32x4_t const* __restrict__ data) { + __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, + int32x4_t const* __restrict__ data) { for (int k = 0; k < kRankAtoms; k++) { int32x4_t const atom = data[k]; @@ -307,8 +308,8 @@ struct CodecQ4Asymm : public CodecBase { } } - __device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, - int32x4_t* __restrict__ data) { + __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, + int32x4_t* __restrict__ data) { for (int k = 0; k < kRankAtoms; k++) { // Directly read quantized atom from recv_buffer uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); @@ -391,14 +392,14 @@ struct CodecQ8Symm : public CodecBase { // {+128, +128}, int16x2_t static constexpr int kRangeBias = 0x00800080; - __device_inline__ CodecQ8Symm(int thread, int rank) + __quickreduce_device_inline__ CodecQ8Symm(int thread, int rank) : CodecBase(thread, 
rank) { // if constexpr (std::is_same::value) // set_fp16_ovfl(true); } - __device_inline__ void send(int32x4_t* __restrict__ send_buffer, - int32x4_t const* __restrict__ data) { + __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, + int32x4_t const* __restrict__ data) { for (int k = 0; k < kRankAtoms; k++) { int32x4_t const atom = data[k]; @@ -452,8 +453,8 @@ struct CodecQ8Symm : public CodecBase { } } - __device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, - int32x4_t* __restrict__ output) { + __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, + int32x4_t* __restrict__ output) { for (int k = 0; k < kRankAtoms; k++) { // Directly read quantized atom from recv_buffer uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); @@ -540,14 +541,14 @@ struct CodecQ8Asymm : public CodecBase { // {+128, +128}, int16x2_t static constexpr int kRangeBias = 0x00800080; - __device_inline__ CodecQ8Asymm(int thread, int rank) + __quickreduce_device_inline__ CodecQ8Asymm(int thread, int rank) : CodecBase(thread, rank) { // if constexpr (std::is_same::value) // set_fp16_ovfl(true); } - __device_inline__ void send(int32x4_t* __restrict__ send_buffer, - int32x4_t const* __restrict__ input) { + __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, + int32x4_t const* __restrict__ input) { for (int k = 0; k < kRankAtoms; k++) { int32x4_t const atom = input[k]; @@ -602,8 +603,8 @@ struct CodecQ8Asymm : public CodecBase { } } - __device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, - int32x4_t* __restrict__ data) { + __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, + int32x4_t* __restrict__ data) { for (int k = 0; k < kRankAtoms; k++) { // Directly read quantized atom from recv_buffer uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); diff --git a/vllm/distributed/device_communicators/quick_all_reduce.py b/vllm/distributed/device_communicators/quick_all_reduce.py 
index 686444eb4237..1f9314364595 100644 --- a/vllm/distributed/device_communicators/quick_all_reduce.py +++ b/vllm/distributed/device_communicators/quick_all_reduce.py @@ -23,10 +23,10 @@ class QuickReduceAlgo(Enum): OneShotFP = 0 TwoShotFP = 1 - TwoShotQ8SYMM = 2 - TwoShotQ4SYMM = 3 - TwoShotQ8ASYMM = 4 - TwoShotQ4ASYMM = 5 + TwoShotQ8Symm = 2 + TwoShotQ4Symm = 3 + TwoShotQ8Asymm = 4 + TwoShotQ4Asymm = 5 class QuickAllReduce: From c1986569515ce3657542a15a2f8e808b77a9f733 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Tue, 3 Jun 2025 15:49:30 +0000 Subject: [PATCH 07/28] Remove config param. Add faster low latency OneShot algo Signed-off-by: ilmarkov --- csrc/custom_quickreduce.cu | 17 +- csrc/ops.h | 3 +- csrc/quickreduce/base.h | 103 +-- csrc/quickreduce/quick_reduce.h | 148 ++-- csrc/quickreduce/quick_reduce_impl.cuh | 636 +++--------------- csrc/torch_bindings.cpp | 6 +- tests/distributed/test_quick_reduce.py | 7 +- vllm/_custom_ops.py | 8 +- vllm/config.py | 9 - .../device_communicators/cuda_communicator.py | 28 +- .../device_communicators/quick_all_reduce.py | 42 +- vllm/distributed/parallel_state.py | 10 - vllm/engine/arg_utils.py | 8 - vllm/entrypoints/llm.py | 4 - vllm/v1/worker/gpu_worker.py | 4 +- vllm/worker/worker.py | 3 +- 16 files changed, 259 insertions(+), 777 deletions(-) diff --git a/csrc/custom_quickreduce.cu b/csrc/custom_quickreduce.cu index b3c8f9ffee1e..27f018fa738d 100644 --- a/csrc/custom_quickreduce.cu +++ b/csrc/custom_quickreduce.cu @@ -52,7 +52,7 @@ void qr_open_handles(quickreduce::fptr_t _fa, } void qr_all_reduce(quickreduce::fptr_t _fa, torch::Tensor& inp, - torch::Tensor& out, int64_t algo_int) { + torch::Tensor& out, bool quantized) { auto fa = reinterpret_cast(_fa); const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); auto stream = at::cuda::getCurrentHIPStreamMasqueradingAsCUDA(); @@ -60,17 +60,15 @@ void qr_all_reduce(quickreduce::fptr_t _fa, torch::Tensor& inp, TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); 
TORCH_CHECK_EQ(inp.numel(), out.numel()); - auto algo = static_cast(algo_int); if (out.scalar_type() == at::ScalarType::Half) { - fa->allreduce(algo_int, stream, - reinterpret_cast(inp.data_ptr()), - reinterpret_cast(out.data_ptr()), out.numel()); + fa->allreduce(reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), out.numel(), + quantized, stream); } else if (out.scalar_type() == at::ScalarType::BFloat16) { fa->allreduce( - algo_int, stream, reinterpret_cast(inp.data_ptr()), reinterpret_cast(out.data_ptr()), - out.numel()); + out.numel(), quantized, stream); } else { throw std::runtime_error( "quick allreduce only supports float16 and bfloat16"); @@ -81,9 +79,4 @@ int64_t qr_max_size() { return static_cast(quickreduce::DeviceComms::kMaxProblemSize); } -int64_t qr_min_size() { - return static_cast(quickreduce::kBlockSize * quickreduce::kAtoms * - sizeof(quickreduce::int32x4_t)); -} - #endif // USE_ROCM \ No newline at end of file diff --git a/csrc/ops.h b/csrc/ops.h index fd9f3ac1d670..8b591aae4264 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -367,7 +367,6 @@ void qr_destroy(fptr_t _fa); torch::Tensor qr_get_handle(fptr_t _fa); void qr_open_handles(fptr_t _fa, const std::vector& handles); void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, - int64_t algo_int); + bool quantized); int64_t qr_max_size(); -int64_t qr_min_size(); #endif \ No newline at end of file diff --git a/csrc/quickreduce/base.h b/csrc/quickreduce/base.h index a54b1178a956..d833700391eb 100644 --- a/csrc/quickreduce/base.h +++ b/csrc/quickreduce/base.h @@ -6,12 +6,14 @@ #include #define __quickreduce_device_inline__ __device__ __forceinline__ -#define __quickreduce_launch_bounds__ __launch_bounds__(256, 4) +#define __quickreduce_launch_bounds_two_shot__ __launch_bounds__(256, 4) +#define __quickreduce_launch_bounds_one_shot__ __launch_bounds__(512, 4) namespace quickreduce { typedef __hip_bfloat16 nv_bfloat16; typedef __hip_bfloat162 nv_bfloat162; +using int32x4_t 
= __attribute__((__vector_size__(4 * sizeof(int)))) int; // Setup acquire-release semantics for vector memory reads (mubuf instruction) // as per architecture. @@ -25,26 +27,6 @@ typedef __hip_bfloat162 nv_bfloat162; #define MUBUF_RELEASE 0 #endif -// Vector types -using int8x8_t = __attribute__((__vector_size__(8 * sizeof(int8_t)))) int8_t; - -using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; -using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; -using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int; -using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int; - -using fp8_t = uint8_t; -using fp8x8_t = __attribute__((__vector_size__(8 * sizeof(uint8_t)))) uint8_t; - -using fp16x4_t = __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16; -using fp16x8_t = __attribute__((__vector_size__(8 * sizeof(__fp16)))) __fp16; -using fp16x16_t = __attribute__((__vector_size__(16 * sizeof(__fp16)))) __fp16; - -using fp32x2_t = __attribute__((__vector_size__(2 * sizeof(float)))) float; -using fp32x4_t = __attribute__((__vector_size__(4 * sizeof(float)))) float; -using fp32x8_t = __attribute__((__vector_size__(8 * sizeof(float)))) float; -using fp32x16_t = __attribute__((__vector_size__(16 * sizeof(float)))) float; - static constexpr int kNegOne = 0xBC00BC00; // {-1, -1}, fp16x2_t // Number of atoms (4xf16x2_t) processed by a single thread @@ -65,7 +47,10 @@ static constexpr int kMaxNumBlocks = 304 * 4; static constexpr int kWavefront = 64; // 256 thread, 4 wavefronts. 
-static dim3 constexpr kBlock = {kWavefront, kBlockSize / kWavefront, 1}; +static dim3 constexpr kBlockTwoShot = {kWavefront, kBlockSize / kWavefront, 1}; + +static constexpr int kThreadsOneShot = 512; +static dim3 constexpr kBlockOneShot = {kThreadsOneShot, 1, 1}; // Number of threads in a group for quantization // It corresponds to 32 F16 elements in quantization block @@ -102,25 +87,6 @@ __quickreduce_device_inline__ static void buffer_store_dwordx4( int32x4_t data, int32x4_t srsrc, int32_t voffset, int32_t soffset, int32_t aux) __asm("llvm.amdgcn.raw.buffer.store.v4i32"); -// NOTE: Setting fp16 flag does not seem to have an effect for gfx942 -// The register offset has to be validated -// So we don't use it in Codecs for now. -__quickreduce_device_inline__ static void set_fp16_ovfl(bool const value) { - // short size = 0b00001; // Specifies the bit size to modify - // const short offset = 0b10111; // Corrected offset to 23, which is the bit - // position of FP16_OVFL const short hwRegId = 0b000001; // HW register ID for - // MODE const short simm16 = (size << 11) | (offset << 6) | hwRegId; simm16 = - // 0xdc1 - -#if defined(__gfx942__) - if (value) { - asm volatile("s_setreg_imm32_b32 0xdc1, 1;" ::); - } else { - asm volatile("s_setreg_imm32_b32 0xdc1, 0;" ::); - } -#endif -} - template __quickreduce_device_inline__ void packed_assign_add(int32x4_t* A, int32x4_t* B); @@ -327,6 +293,47 @@ __quickreduce_device_inline__ float T2float_cast(nv_bfloat16 a) { return __bfloat162float(a); } +template +__quickreduce_device_inline__ unsigned char T2uchar_cast(T a); + +template <> +__quickreduce_device_inline__ unsigned char T2uchar_cast(half a) { + return static_cast(__half2ushort_rz(a)); +} + +template <> +__quickreduce_device_inline__ unsigned char T2uchar_cast( + nv_bfloat16 a) { + return static_cast(__bfloat16_as_ushort(a)); +} + +template +__quickreduce_device_inline__ T uchar2T_cast(unsigned char a); + +template <> +__quickreduce_device_inline__ half 
uchar2T_cast(unsigned char a) { + return __ushort2half_rz(static_cast(a)); +} + +template <> +__quickreduce_device_inline__ nv_bfloat16 +uchar2T_cast(unsigned char a) { + return __ushort_as_bfloat16(static_cast(a)); +} + +template +__quickreduce_device_inline__ int T2int_cast(T a); + +template <> +__quickreduce_device_inline__ int T2int_cast(half a) { + return __half2int_rz(a); +} + +template <> +__quickreduce_device_inline__ int T2int_cast(nv_bfloat16 a) { + return static_cast(__bfloat16_as_ushort(a)); +} + template __quickreduce_device_inline__ int group_abs_max(int32x4_t atom) { const int group_leader = (threadIdx.x / kThreadGroupSize) * kThreadGroupSize; @@ -361,18 +368,26 @@ __quickreduce_device_inline__ int group_abs_max(int32x4_t atom) { template __quickreduce_device_inline__ void group_max_min(int32x4_t atom, int& wblockmax, - int& wblockmin) { + int& wblockmin, + int valid_data) { const int group_leader = (threadIdx.x / kThreadGroupSize) * kThreadGroupSize; + static constexpr int FP_MAX = + std::is_same::value ? 0x7BFF7BFF : 0x7F7F7F7F; + static constexpr int FP_MIN = + std::is_same::value ? 
0xFBFFFBFF : 0xFF7FFF7F; int wmax, wmin; int a, b; a = packed_max(atom[0], atom[1]); b = packed_max(atom[2], atom[3]); - wmax = packed_max(a, b); + // In case the data was loaded out of range (and initialized to 0) + // we set max min values to sentinel values + // so that they do not spoil the group max min values + wmax = valid_data * packed_max(a, b) + (!valid_data) * FP_MIN; a = packed_min(atom[0], atom[1]); b = packed_min(atom[2], atom[3]); - wmin = packed_min(a, b); + wmin = valid_data * packed_min(a, b) + (!valid_data) * FP_MAX; // Reduce the max and min among a group of threads // Note: This is basically 2 blocks of values setup as the diff --git a/csrc/quickreduce/quick_reduce.h b/csrc/quickreduce/quick_reduce.h index f58a1b88991c..afa94db23465 100644 --- a/csrc/quickreduce/quick_reduce.h +++ b/csrc/quickreduce/quick_reduce.h @@ -18,51 +18,78 @@ namespace quickreduce { using fptr_t = int64_t; static_assert(sizeof(void*) == sizeof(fptr_t)); -enum QuickReduceAlgo { - ONESHOT_FP16 = 0, - TWOSHOT_FP16 = 1, - TWOSHOT_SYMM_Q8 = 2, - TWOSHOT_SYMM_Q4 = 3, - TWOSHOT_ASYMM_Q8 = 4, - TWOSHOT_ASYMM_Q4 = 5, -}; +static constexpr int kOneShotAllreduceMaxElemsWorldSize2 = 8192 * 12; +static constexpr int kOneShotAllreduceMaxElemsWorldSize4 = 8192 * 8; +static constexpr int kOneShotAllreduceMaxElemsWorldSize8 = 8192 * 4; +static constexpr int kOneShotAllreduceMaxSize = + std::max(kOneShotAllreduceMaxElemsWorldSize2 * 2, + std::max(kOneShotAllreduceMaxElemsWorldSize4 * 4, + kOneShotAllreduceMaxElemsWorldSize8 * 8)) * + sizeof(half); template -__global__ __quickreduce_launch_bounds__ static void allreduce_prototype( - T const* A, T* B, int N, int num_blocks, int world_size, int rank, - uint8_t** dbuffer_list, long data_offset, int flag_color) { +__global__ __quickreduce_launch_bounds_one_shot__ static void +allreduce_prototype_oneshot(T const* A, T* B, int N, int rank, + uint8_t** dbuffer_list, long data_offset, + int flag_color) { + AllReduceKernel::run(A, B, N, rank, 
dbuffer_list, data_offset, flag_color); +} + +template +__global__ __quickreduce_launch_bounds_two_shot__ static void +allreduce_prototype_twoshot(T const* A, T* B, int N, int num_blocks, int rank, + uint8_t** dbuffer_list, long data_offset, + int flag_color) { int block = blockIdx.x; int grid = gridDim.x; while (block < num_blocks) { - AllReduceKernel::run(A, B, N, block, num_blocks, world_size, rank, - dbuffer_list, data_offset, flag_color); + AllReduceKernel::run(A, B, N, block, num_blocks, rank, dbuffer_list, + data_offset, flag_color); block += grid; } } -#define TWOSHOT_DISPATCH(__codec) \ - if (world_size == 2) { \ - using LineCodec = __codec; \ - using AllReduceKernel = AllReduceTwoshot; \ - hipLaunchKernelGGL((allreduce_prototype), dim3(grid), \ - dim3(kBlock), 0, stream, A, B, N, num_blocks, \ - world_size, rank, dbuffer_list, data_offset, \ - flag_color); \ - } else if (world_size == 4) { \ - using LineCodec = __codec; \ - using AllReduceKernel = AllReduceTwoshot; \ - hipLaunchKernelGGL((allreduce_prototype), dim3(grid), \ - dim3(kBlock), 0, stream, A, B, N, num_blocks, \ - world_size, rank, dbuffer_list, data_offset, \ - flag_color); \ - } else if (world_size == 8) { \ - using LineCodec = __codec; \ - using AllReduceKernel = AllReduceTwoshot; \ - hipLaunchKernelGGL((allreduce_prototype), dim3(grid), \ - dim3(kBlock), 0, stream, A, B, N, num_blocks, \ - world_size, rank, dbuffer_list, data_offset, \ - flag_color); \ +#define ONESHOT_DISPATCH() \ + if (world_size == 2) { \ + using AllReduceKernel = AllReduceOneshot; \ + hipLaunchKernelGGL((allreduce_prototype_oneshot), \ + dim3(grid), dim3(kBlockOneShot), 0, stream, A, B, N, \ + rank, dbuffer_list, data_offset, flag_color); \ + } else if (world_size == 4) { \ + using AllReduceKernel = AllReduceOneshot; \ + hipLaunchKernelGGL((allreduce_prototype_oneshot), \ + dim3(grid), dim3(kBlockOneShot), 0, stream, A, B, N, \ + rank, dbuffer_list, data_offset, flag_color); \ + } else if (world_size == 8) { \ + using 
AllReduceKernel = AllReduceOneshot; \ + hipLaunchKernelGGL((allreduce_prototype_oneshot), \ + dim3(grid), dim3(kBlockOneShot), 0, stream, A, B, N, \ + rank, dbuffer_list, data_offset, flag_color); \ + } + +#define TWOSHOT_DISPATCH(__codec) \ + if (world_size == 2) { \ + using LineCodec = __codec; \ + using AllReduceKernel = AllReduceTwoshot; \ + hipLaunchKernelGGL((allreduce_prototype_twoshot), \ + dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \ + num_blocks, rank, dbuffer_list, data_offset, \ + flag_color); \ + } else if (world_size == 4) { \ + using LineCodec = __codec; \ + using AllReduceKernel = AllReduceTwoshot; \ + hipLaunchKernelGGL((allreduce_prototype_twoshot), \ + dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \ + num_blocks, rank, dbuffer_list, data_offset, \ + flag_color); \ + } else if (world_size == 8) { \ + using LineCodec = __codec; \ + using AllReduceKernel = AllReduceTwoshot; \ + hipLaunchKernelGGL((allreduce_prototype_twoshot), \ + dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \ + num_blocks, rank, dbuffer_list, data_offset, \ + flag_color); \ } struct DeviceComms { @@ -99,8 +126,9 @@ struct DeviceComms { // Allocate buffer size for worst case: Twoshot FP16 2-stage buffer. 
long flags_buffer_size = 2 * world_size * kMaxTiles * sizeof(int); - long data_buffer_size = 2 * kMaxProblemSize; - long total_buffer_size = flags_buffer_size + data_buffer_size; + long long data_buffer_size = max( + 2 * kMaxProblemSize, static_cast(kOneShotAllreduceMaxSize)); + long long total_buffer_size = flags_buffer_size + data_buffer_size; data_offset = flags_buffer_size; HIP_CHECK(hipExtMallocWithFlags((void**)&dbuffer, total_buffer_size, hipDeviceMallocUncached)); @@ -161,7 +189,7 @@ struct DeviceComms { } template - void allreduce(int algo_int, hipStream_t stream, T const* A, T* B, int N) { + void allreduce(T const* A, T* B, int N, bool quantized, hipStream_t stream) { if (world_size != 2 && world_size != 4 && world_size != 8) { throw std::runtime_error("All Reduce not supported for world_size = " + std::to_string(world_size)); @@ -169,38 +197,26 @@ struct DeviceComms { // Configuration. long msg_size = N * sizeof(T); - unsigned long num_blocks = divceil(msg_size, kTileSize); - unsigned long grid = min(kMaxNumBlocks, num_blocks); - - // All reduce dispatch. 
- QuickReduceAlgo algo = static_cast(algo_int); - - switch (algo) { - case QuickReduceAlgo::ONESHOT_FP16: - using AllReduceKernel = AllReduceOneshot; - hipLaunchKernelGGL((allreduce_prototype), - dim3(grid), dim3(kBlock), 0, stream, A, B, N, - num_blocks, world_size, rank, dbuffer_list, - data_offset, flag_color); - break; - case QuickReduceAlgo::TWOSHOT_SYMM_Q8: - TWOSHOT_DISPATCH(CodecQ8Symm) - break; - case QuickReduceAlgo::TWOSHOT_ASYMM_Q8: - TWOSHOT_DISPATCH(CodecQ8Asymm) - break; - case QuickReduceAlgo::TWOSHOT_SYMM_Q4: + bool use_one_shot_allreduce = + (world_size == 2 and N <= kOneShotAllreduceMaxElemsWorldSize2) or + (world_size == 4 and N <= kOneShotAllreduceMaxElemsWorldSize4) or + (world_size == 8 and N <= kOneShotAllreduceMaxElemsWorldSize8); + if (use_one_shot_allreduce) { + // Each thread processes blocks out of 4 elements + unsigned long num_blocks = divceil(N, (4 * kThreadsOneShot)); + unsigned long grid = min(kMaxNumBlocks, num_blocks); + ONESHOT_DISPATCH() + } else { + unsigned long num_blocks = divceil(msg_size, kTileSize); + unsigned long grid = min(kMaxNumBlocks, num_blocks); + + if (quantized) { TWOSHOT_DISPATCH(CodecQ4Symm) - break; - case QuickReduceAlgo::TWOSHOT_ASYMM_Q4: - TWOSHOT_DISPATCH(CodecQ4Asymm) - break; - default: + } else { TWOSHOT_DISPATCH(CodecFP16) - break; + } } HIP_CHECK(cudaGetLastError()); - // Rotate the flag color. 
flag_color++; } diff --git a/csrc/quickreduce/quick_reduce_impl.cuh b/csrc/quickreduce/quick_reduce_impl.cuh index c546459d9490..d544f46a8361 100644 --- a/csrc/quickreduce/quick_reduce_impl.cuh +++ b/csrc/quickreduce/quick_reduce_impl.cuh @@ -36,7 +36,8 @@ struct CodecFP16 : public CodecBase { : CodecBase(thread, rank) {} __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, - int32x4_t const* __restrict__ data) { + const int32x4_t* __restrict__ data, + const int* validity) { for (int i = 0; i < kRankAtoms; i++) { __builtin_nontemporal_store(data[i], send_buffer + thread); send_buffer += kAtomStride; @@ -98,13 +99,11 @@ struct CodecQ4Symm : public CodecBase { static constexpr int kRangeBias = 0x00080008; __quickreduce_device_inline__ CodecQ4Symm(int thread, int rank) - : CodecBase(thread, rank) { - // if constexpr (std::is_same::value) - // set_fp16_ovfl(true); - } + : CodecBase(thread, rank) {} __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, - int32x4_t const* __restrict__ data) { + const int32x4_t* __restrict__ data, + const int* validity) { for (int k = 0; k < kRankAtoms; k++) { int32x4_t const atom = data[k]; @@ -199,583 +198,106 @@ struct CodecQ4Symm : public CodecBase { } }; -// Int4 Asymm quantization codec. -// We quantize the FP16/BF16 data to block-scaled Int4 in blocks of 4 * -// kThreadGroupSize. -template -struct CodecQ4Asymm : public CodecBase { - static constexpr int kWorldSize = world_size; - - static constexpr int kRankAtoms = kAtoms / kWorldSize; - - // Codec tile size process by this workgroup. - // Each thread processes a fragment of 4xf16x2_t (16B) values, - // into a int4x8_t (4B), 2 zeros and 2 scales shared by thread group (4 * - // kThreadGroupSize) values. 
1024 + 128 + 128 - static constexpr int kRankTileStride = 1280; - static constexpr int kRankTileScaleOffset = 1024; - static constexpr int kRankTileZeroOffset = 1152; - static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms; - static_assert(kRankTransmittedTileSize % 16 == 0, - "kRankTransmittedTileSize must be 16B aligned."); - - static constexpr int kRankBufferTileStride = - kRankTileStride / sizeof(int32x4_t); - - // Total size of transmitted tile. - static constexpr int kTransmittedTileSize = - kRankTransmittedTileSize * kWorldSize; - - // Constants configuration - - // {-1/16.0h, -1/16.0h}, f16x2_t - static constexpr int kScaleFactor = - std::is_same::value ? 0xAC00AC00 : 0xBD80BD80; - - // {1e-7, 1e-7}, f16x2_t - static constexpr int kScaleEpsilon = - std::is_same::value ? 0x00010001 : 0x33D733D7; - - // {0, 0}, f16x2_t - static constexpr int kRangeMin = 0x00000000; - - // {+15, +15}, f16x2_t - static constexpr int kRangeMax = - std::is_same::value ? 0x4B804B80 : 0x41704170; - - static unsigned char constexpr kMask0F = 0x0F; - __quickreduce_device_inline__ CodecQ4Asymm(int thread, int rank) - : CodecBase(thread, rank) { - // if constexpr (std::is_same::value) - // set_fp16_ovfl(true); - } - - __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, - int32x4_t const* __restrict__ data) { - for (int k = 0; k < kRankAtoms; k++) { - int32x4_t const atom = data[k]; - - int wblockmax, wblockmin; - // Find max/min in thread group. 
- // In 2 blocks of values, upper/lower halves of the f16x2_t - group_max_min(atom, wblockmax, wblockmin); - - // Derive zeros and scales - int decoding_zero = wblockmin; - int decoding_scale; - int encoding_scale; - - decoding_scale = packed_sub(wblockmax, decoding_zero); - decoding_scale = packed_mul(decoding_scale, kScaleFactor); - encoding_scale = packed_add(decoding_scale, kScaleEpsilon); - encoding_scale = packed_rcp(encoding_scale); - - // Apply scales to get quantized values - int32x4_t w; - for (int i = 0; i < 4; i++) { - w[i] = packed_sub(atom[i], decoding_zero); - w[i] = packed_mul(w[i], encoding_scale); - w[i] = packed_max(w[i], kRangeMin); - w[i] = packed_min(w[i], kRangeMax); - } - - // Convert from f16 to a byte and pack into 4 bits. - int32_t qw = 0; - { - unsigned char* qi = reinterpret_cast(&qw); - T* wh = reinterpret_cast(&w); - for (int i = 0; i < 8; i++) { - auto val = (unsigned char)T2float_cast(wh[i]) & kMask0F; - qi[i / 2] |= val << (4 * (i & 1)); - } - } - - // Write quantized atom to send_buffer - // note: only the group leader stores the scale - uint8_t* atom_ptr = - reinterpret_cast(send_buffer + k * kRankBufferTileStride); - int32_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; - int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + - (thread / 8); - int* qz_ptr = - reinterpret_cast(atom_ptr + kRankTileZeroOffset) + (thread / 8); - - __builtin_nontemporal_store(qw, qw_ptr); - if (threadIdx.x == group_leader) { - __builtin_nontemporal_store(decoding_scale, qs_ptr); - __builtin_nontemporal_store(decoding_zero, qz_ptr); - } - } - } - - __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, - int32x4_t* __restrict__ data) { - for (int k = 0; k < kRankAtoms; k++) { - // Directly read quantized atom from recv_buffer - uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); - int32_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; - int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + - (thread / 8); 
- int* qz_ptr = - reinterpret_cast(atom_ptr + kRankTileZeroOffset) + (thread / 8); - - int32_t qw = __builtin_nontemporal_load(qw_ptr); - int qs = __builtin_nontemporal_load(qs_ptr); - int qz = __builtin_nontemporal_load(qz_ptr); - - *recv_buffer += kRankBufferTileStride; - - // Unpack 8xq4 into f16x8_t - int32x4_t w; - { - T* wh = reinterpret_cast(&w); - unsigned char* qi = reinterpret_cast(&qw); - -#pragma unroll - for (int i = 0; i < 8; i++) { - auto val = (qi[i / 2] >> (4 * (i & 1))) & kMask0F; - wh[i] = float2T_cast((float)val); - } - } - - // Apply decoding scales and zeros - for (int i = 0; i < 4; i++) { - w[i] = packed_mul(w[i], qs); - w[i] = packed_add(w[i], qz); - } - - data[k] = w; - } - } -}; - -// Int8-blocking Line codec for Twoshot collectives. -// We quantize the FP16/BF16 data to block-scaled Int8 in blocks of 32. -template -struct CodecQ8Symm : public CodecBase { - static constexpr int kWorldSize = world_size; - - // Codec tile size process by this workgroup. - // Each threads processes a fragment of f16x8_t (16B), - // into a int8x8_t (8B) and a f16 scale shared among 32 values. - static constexpr int kRankAtoms = kAtoms / kWorldSize; - static constexpr int kRankTileStride = 2176; - static constexpr int kRankTileScaleOffset = 2048; - static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms; - static_assert(kRankTransmittedTileSize % 16 == 0, - "kRankTransmittedTileSize must be 16B aligned."); - - static constexpr int kRankBufferTileStride = - kRankTileStride / sizeof(int32x4_t); - - // Total tile size for the collective communication. - static constexpr int kTransmittedTileSize = - kRankTransmittedTileSize * kWorldSize; - - // Constants configuration - - // {-1/128.0h, -1/128.0h}, f16x2_t - static constexpr int kScaleFactor = - std::is_same::value ? 0xA000A000 : 0xBC00BC00; - - // {1e-7, 1e-7}, f16x2_t - static constexpr int kScaleEpsilon = - std::is_same::value ? 
0x00010001 : 0x33D733D7; - - // {-128, -128}, f16x2_t - static constexpr int kRangeMin = - std::is_same::value ? 0xD800D800 : 0xC300C300; - // {+127, +127}, f16x2_t - static constexpr int kRangeMax = - std::is_same::value ? 0x57F057F0 : 0x42FE42FE; - - // {+128, +128}, int16x2_t - static constexpr int kRangeBias = 0x00800080; - - __quickreduce_device_inline__ CodecQ8Symm(int thread, int rank) - : CodecBase(thread, rank) { - // if constexpr (std::is_same::value) - // set_fp16_ovfl(true); - } - - __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, - int32x4_t const* __restrict__ data) { - for (int k = 0; k < kRankAtoms; k++) { - int32x4_t const atom = data[k]; - - // Find abs max in thread group - int wblockmax = group_abs_max(atom); - // Derive scales - int decoding_scale; - int encoding_scale; - decoding_scale = packed_mul(wblockmax, kScaleFactor); - encoding_scale = packed_add(decoding_scale, kScaleEpsilon); - encoding_scale = packed_rcp(encoding_scale); - - // Apply scales to get quantized values - int32x4_t w; - for (int i = 0; i < 4; i++) { - w[i] = packed_mul(atom[i], encoding_scale); - w[i] = packed_max(w[i], kRangeMin); - w[i] = packed_min(w[i], kRangeMax); - } - - // Convert from f16x2_t to uint16x2_t - int32x4_t q; - { - int16_t* qi = reinterpret_cast(&q); - T* wh = reinterpret_cast(&w); - for (int i = 0; i < 8; i++) - qi[i] = (int16_t)rintf(T2float_cast(wh[i])); - - for (int i = 0; i < 4; i++) { - q[i] = packed_add(q[i], kRangeBias); - } - } - - // Pack 8 x q8 into int32x2_t - int32x2_t qw; - qw[0] = q[0] | (q[1] << 8); - qw[1] = q[2] | (q[3] << 8); - - uint8_t* atom_ptr = - reinterpret_cast(send_buffer + k * kRankBufferTileStride); - int32x2_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; - int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + - (thread / 8); - - // Write quantized atom to send_buffer - __builtin_nontemporal_store(qw, qw_ptr); - // note: only the group leader stores the scale - if (threadIdx.x == 
group_leader) { - __builtin_nontemporal_store(decoding_scale, qs_ptr); - } - } - } - - __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, - int32x4_t* __restrict__ output) { - for (int k = 0; k < kRankAtoms; k++) { - // Directly read quantized atom from recv_buffer - uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); - int32x2_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; - int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + - (thread / 8); - - int32x2_t qw = __builtin_nontemporal_load(qw_ptr); - int qs = __builtin_nontemporal_load(qs_ptr); - - *recv_buffer += kRankBufferTileStride; - - // Unpack q8 into fp16x8_t - int32x4_t w; - { - static constexpr uint kMask00FF = 0x00FF00FF; - - // {1024.0, 1024.0}, f16x2_t - static constexpr uint kF162_1024 = - std::is_same::value ? 0x64006400 : 0x44804480; - - // {-1152.0, -1152.0}, f16x2_t - static constexpr uint kF162_1152 = - std::is_same::value ? 0xE480E480 : 0xC490C490; - -#pragma unroll - for (int i = 0; i < 4; i++) { - int32_t q8 = ((qw[i / 2] >> ((i % 2) * 8)) & kMask00FF) | kF162_1024; - w[i] = packed_add(q8, kF162_1152); - } - } - - // Apply decoding scales - for (int i = 0; i < 4; i++) { - w[i] = packed_mul(w[i], qs); - } - output[k] = w; - } - } -}; - -// Int8-blocking Line codec for Twoshot collectives. -// We quantize the FP16 data to block-scaled Int8 in blocks of 32. -template -struct CodecQ8Asymm : public CodecBase { - static constexpr int kWorldSize = world_size; - - // Codec tile size process by this workgroup. - // Each thread processes a fragment of fp16x8_t (16B), - // into a int8x8_t (8B) and a fp16 zero and a fp16 scale shared among 32 - // values. 
- static constexpr int kRankAtoms = kAtoms / kWorldSize; - // 2048 + 128 + 128 - static constexpr int kRankTileStride = 2304; - static constexpr int kRankTileScaleOffset = 2048; - static constexpr int kRankTileZeroOffset = 2176; - static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms; - static_assert(kRankTransmittedTileSize % 16 == 0, - "kRankTransmittedTileSize must be 16B aligned."); - - static constexpr int kRankBufferTileStride = - kRankTileStride / sizeof(int32x4_t); - - // Total tile size for the collective communication. - static constexpr int kTransmittedTileSize = - kRankTransmittedTileSize * kWorldSize; - - // Constants configuration - // {1/255.0h, 1/255.0h}, f16x2_t - static constexpr int kScaleFactor = - std::is_same::value ? 0x1C041C04 : 0x3B813B81; - - // {1e-7, 1e-7}, fp16x2_t - static constexpr int kScaleEpsilon = - std::is_same::value ? 0x00010001 : 0x33D733D7; - - // {0, 0}, f16x2_t - static constexpr int kRangeMin = 0x00000000; - - // {+255, +255}, f16x2_t - static constexpr int kRangeMax = - std::is_same::value ? 0x5BF85BF8 : 0x437F437F; - - // {+128, +128}, int16x2_t - static constexpr int kRangeBias = 0x00800080; - - __quickreduce_device_inline__ CodecQ8Asymm(int thread, int rank) - : CodecBase(thread, rank) { - // if constexpr (std::is_same::value) - // set_fp16_ovfl(true); - } - - __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, - int32x4_t const* __restrict__ input) { - for (int k = 0; k < kRankAtoms; k++) { - int32x4_t const atom = input[k]; - - int wblockmax, wblockmin; - // Find max/min in thread group. 
- // In 2 blocks of values, upper/lower halves of the f16x2_t - group_max_min(atom, wblockmax, wblockmin); - - // Derive zeros and scales - int decoding_zero = wblockmin; - int decoding_scale; - int encoding_scale; - - decoding_scale = packed_sub(wblockmax, decoding_zero); - decoding_scale = packed_mul(decoding_scale, kScaleFactor); - encoding_scale = packed_add(decoding_scale, kScaleEpsilon); - encoding_scale = packed_rcp(encoding_scale); - - // Apply scales to get quantized values - int32x4_t w; - for (int i = 0; i < 4; i++) { - w[i] = packed_sub(atom[i], decoding_zero); - w[i] = packed_mul(w[i], encoding_scale); - w[i] = packed_max(w[i], kRangeMin); - w[i] = packed_min(w[i], kRangeMax); - } - - // Convert from fp16x8_t to uint8x8_t and pack into int32x2_t - int32x2_t qw; - { - unsigned char* qi = reinterpret_cast(&qw); - T* wh = reinterpret_cast(&w); - for (int i = 0; i < 8; i++) - qi[i] = (unsigned char)T2float_cast(wh[i]); - } - - // Write quantized atom to send_buffer - // note: only the group leader stores the scale - uint8_t* atom_ptr = - reinterpret_cast(send_buffer + k * kRankBufferTileStride); - int32x2_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; - int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + - (thread / 8); - int* qz_ptr = - reinterpret_cast(atom_ptr + kRankTileZeroOffset) + (thread / 8); - - __builtin_nontemporal_store(qw, qw_ptr); - if (threadIdx.x == group_leader) { - __builtin_nontemporal_store(decoding_scale, qs_ptr); - __builtin_nontemporal_store(decoding_zero, qz_ptr); - } - } - } - - __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, - int32x4_t* __restrict__ data) { - for (int k = 0; k < kRankAtoms; k++) { - // Directly read quantized atom from recv_buffer - uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); - int32x2_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; - int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + - (thread / 8); - int* qz_ptr = - reinterpret_cast(atom_ptr 
+ kRankTileZeroOffset) + (thread / 8); - - int32x2_t qw = __builtin_nontemporal_load(qw_ptr); - int qs = __builtin_nontemporal_load(qs_ptr); - int qz = __builtin_nontemporal_load(qz_ptr); - - *recv_buffer += kRankBufferTileStride; - - // Unpack uint8x8_t into fp16x8_t - int32x4_t w; - { - T* wh = reinterpret_cast(&w); - unsigned char* qi = reinterpret_cast(&qw); -#pragma unroll - for (int i = 0; i < 8; i++) { - wh[i] = float2T_cast((float)qi[i]); - } - } - - // Apply decoding scales and zeros - for (int i = 0; i < 4; i++) { - w[i] = packed_mul(w[i], qs); - w[i] = packed_add(w[i], qz); - } - - data[k] = w; - } - } -}; - // Oneshot AllReduce -template +template struct AllReduceOneshot { static_assert(sizeof(T) == 2); __device__ static void run( - T const* __restrict__ input, // input - T* __restrict__ output, // output + T const* __restrict__ A, // input + T* __restrict__ B, // output int const N, // number of elements - int const block, // this block's index - int const num_blocks, // total number of blocks - int const world_size, // total number of ranks - int const rank, // this rank's index + int const rank, // rank index uint8_t** __restrict__ buffer_list, // communication buffers long const data_offset, // offset to start of the data buffer - int flag_color // Flag color for the network barrier - ) { - // Topology - int thread = threadIdx.x + threadIdx.y * kWavefront; - - long data_stride = num_blocks * kTileSize; - long flags_stride = num_blocks * sizeof(int); + int flag_color) { + BufferResource src_buffer(const_cast(A), N * sizeof(T)); + BufferResource dst_buffer(B, N * sizeof(T)); uint8_t* rank_buffer = buffer_list[rank]; - // -------------------------------------------------------- - // Read input into registers - int32x4_t tA[kAtoms]; + const int block_size = blockDim.x; + const int thread = threadIdx.x; + const int block = blockIdx.x; + const int problem_size = (N + 3) / 4; - BufferResource src_buffer(const_cast(input), N * sizeof(T)); - - int src_offset 
= block * kTileSize + thread * sizeof(int32x4_t); - - for (int i = 0; i < kAtoms; i++) { - tA[i] = buffer_load_dwordx4(src_buffer.descriptor, src_offset, 0, 0); - src_offset += kAtomStride * sizeof(int32x4_t); - } - - // -------------------------------------------------------- - // Write rank data into this rank segment of every rank's communication - // buffer. - long comm_data_offset = - data_offset + rank * data_stride + block * kTileSize; - long comm_flags_offset = rank * flags_stride + block * sizeof(int); + int32x4_t tA, tB; + long grid = gridDim.x; + long data_stride = grid * block_size * sizeof(int32x4_t); + long comm_flags0_offset = block * (world_size * sizeof(int)); + long comm_flags1_offset = + comm_flags0_offset + grid * (world_size * sizeof(int)); - if (thread < world_size) { - int r = thread; - int* flag_ptr = - reinterpret_cast(buffer_list[r] + comm_flags_offset); - wait_sync_flag(flag_ptr, flag_color - 1); - } - __syncthreads(); + for (int idx = block * block_size + thread; idx < problem_size; + idx += grid * block_size) { + // load values + tA = buffer_load_dwordx4(src_buffer.descriptor, idx * sizeof(int32x4_t), + 0, 0); - for (int r = 0; r < world_size; r++) { - int32x4_t* send_buffer = - reinterpret_cast(buffer_list[r] + comm_data_offset); - for (int i = 0; i < kAtoms; i++) { - __builtin_nontemporal_store(tA[i], send_buffer + thread); - send_buffer += kAtomStride; + // Write rank data into this rank segment of every rank's communication + // buffer. +#pragma unroll + for (int r = 0; r < world_size; r++) { + int32x4_t* send_buffer = reinterpret_cast( + buffer_list[r] + data_offset + rank * data_stride + + idx * sizeof(int32x4_t)); + __builtin_nontemporal_store(tA, send_buffer); } } - // Inform the other ranks that th data has been posted. 
__syncthreads(); if (thread < world_size) { int r = thread; - int* flag_ptr = - reinterpret_cast(buffer_list[r] + comm_flags_offset); - set_sync_flag(flag_ptr, flag_color); - } - - // Read and reduce the data from this rank's communication buffer. - int32x4_t tB[kAtoms]; - - { - int r = 0; + int* peer_flag_ptr = reinterpret_cast( + buffer_list[r] + comm_flags0_offset + rank * sizeof(int)); + __atomic_store_n(peer_flag_ptr, flag_color, __ATOMIC_RELEASE); + int* self_flag_ptr = reinterpret_cast( + rank_buffer + comm_flags0_offset + r * sizeof(int)); // Wait for the flags to be set. - int* flag_ptr = reinterpret_cast(rank_buffer + r * flags_stride + - block * sizeof(int)); - if (thread == 0) { - wait_sync_flag(flag_ptr, flag_color); - } - __syncthreads(); - - // Read posted data from the rank's communication buffer. - int32x4_t* recv_buffer = reinterpret_cast( - rank_buffer + data_offset + r * data_stride + block * kTileSize); - - for (int i = 0; i < kAtoms; i++) { - tB[i] = __builtin_nontemporal_load(recv_buffer + thread); - recv_buffer += kAtomStride; + while (__atomic_load_n(self_flag_ptr, __ATOMIC_ACQUIRE) != flag_color) { } } + __syncthreads(); - for (int r = 1; r < world_size; r++) { - // Wait for the flags to be set. - int* flag_ptr = reinterpret_cast(rank_buffer + r * flags_stride + - block * sizeof(int)); - if (thread == 0) { - wait_sync_flag(flag_ptr, flag_color); + for (int idx = block * block_size + thread; idx < problem_size; + idx += grid * block_size) { + { + int r = 0; + // Read posted data from the rank's communication buffer. + int32x4_t* recv_buffer = reinterpret_cast( + rank_buffer + data_offset + r * data_stride + + idx * sizeof(int32x4_t)); + tA = __builtin_nontemporal_load(recv_buffer); } - __syncthreads(); - - // Read posted data from the rank's communication buffer. 
- int32x4_t* recv_buffer = reinterpret_cast( - rank_buffer + data_offset + r * data_stride + block * kTileSize); - - for (int i = 0; i < kAtoms; i++) { - tA[i] = __builtin_nontemporal_load(recv_buffer + thread); - recv_buffer += kAtomStride; +#pragma unroll + for (int r = 1; r < world_size; r++) { + // Read posted data from the rank's communication buffer. + int32x4_t* recv_buffer = reinterpret_cast( + rank_buffer + data_offset + r * data_stride + + idx * sizeof(int32x4_t)); + tB = __builtin_nontemporal_load(recv_buffer); + + // Reduce the local data with the read data + packed_assign_add(&tA, &tB); } - // Reduce. - for (int i = 0; i < kAtoms; i++) { - packed_assign_add(&tB[i], &tA[i]); - } + buffer_store_dwordx4(tA, dst_buffer.descriptor, idx * sizeof(int32x4_t), + 0, 0); } __syncthreads(); if (thread < world_size) { int r = thread; - int* flag_ptr = reinterpret_cast(rank_buffer + r * flags_stride + - block * sizeof(int)); - set_sync_flag(flag_ptr, flag_color); - } - - // Write the result to output. - BufferResource dst_buffer(output, N * sizeof(T)); - int dst_offset = block * kTileSize + thread * sizeof(int32x4_t); + int* peer_flag_ptr = reinterpret_cast( + buffer_list[r] + comm_flags1_offset + rank * sizeof(int)); + __atomic_store_n(peer_flag_ptr, flag_color, __ATOMIC_RELAXED); + int* self_flag_ptr = reinterpret_cast( + rank_buffer + comm_flags1_offset + r * sizeof(int)); - for (int i = 0; i < kAtoms; i++) { - buffer_store_dwordx4(tB[i], dst_buffer.descriptor, dst_offset, 0, 0); - dst_offset += kAtomStride * sizeof(int32x4_t); + // Wait for the flags to be set. 
+ while (__atomic_load_n(self_flag_ptr, __ATOMIC_RELAXED) != flag_color) { + } } } }; @@ -789,11 +311,10 @@ struct AllReduceTwoshot { __device__ static void run( T const* __restrict__ input, T* __restrict__ output, - int const N, // number of elements - int const block, // block index - int const num_blocks, // number of blocks - int const world_size, // unused - only kept around for API consistency - int const rank, // rank index + int const N, // number of elements + int const block, // block index + int const num_blocks, // number of blocks + int const rank, // rank index uint8_t** __restrict__ buffer_list, // communication buffers long const data_offset, // offset to start of the data buffer int flag_color) { @@ -805,6 +326,7 @@ struct AllReduceTwoshot { // -------------------------------------------------------- // Read input into registers int32x4_t tA[kAtoms]; + int tA_validity[kAtoms]; BufferResource src_buffer(const_cast(input), N * sizeof(T)); int src_offset = block * kTileSize + thread * sizeof(int32x4_t); @@ -812,6 +334,7 @@ struct AllReduceTwoshot { for (int i = 0; i < kAtoms; i++) { tA[i] = buffer_load_dwordx4(src_buffer.descriptor, src_offset, 0, 0); + tA_validity[i] = src_offset < N * sizeof(T); src_offset += kAtomStride * sizeof(int32x4_t); } @@ -830,7 +353,8 @@ struct AllReduceTwoshot { int32x4_t* send_buffer = reinterpret_cast(buffer_list[r] + comm_data0_offset + rank * Codec::kRankTransmittedTileSize); - codec.send(send_buffer, &tA[r * Codec::kRankAtoms]); + codec.send(send_buffer, &tA[r * Codec::kRankAtoms], + &tA_validity[r * Codec::kRankAtoms]); } __syncthreads(); @@ -870,7 +394,7 @@ struct AllReduceTwoshot { int32x4_t* send_buffer = reinterpret_cast(buffer_list[r] + comm_data1_offset + rank * Codec::kRankTransmittedTileSize); - codec.send(send_buffer, tR); + codec.send(send_buffer, tR, &tA_validity[rank * Codec::kRankAtoms]); } __syncthreads(); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index f2ccdab0cf8f..0e8ab36ad98c 
100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -728,7 +728,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { #ifdef USE_ROCM // Quick Reduce all-reduce kernels custom_ar.def( - "qr_all_reduce(int fa, Tensor inp, Tensor out, int algo_int) -> ()"); + "qr_all_reduce(int fa, Tensor inp, Tensor out, bool quantized) -> ()"); custom_ar.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce); custom_ar.def("init_custom_qr", &init_custom_qr); @@ -741,10 +741,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { // Max input size in bytes custom_ar.def("qr_max_size", &qr_max_size); - // Minimal input size in bytes. - // For inputs size less than this value - // AllReduce collective is ineffective - custom_ar.def("qr_min_size", &qr_min_size); #endif } diff --git a/tests/distributed/test_quick_reduce.py b/tests/distributed/test_quick_reduce.py index f4c213981f4e..69763d026834 100644 --- a/tests/distributed/test_quick_reduce.py +++ b/tests/distributed/test_quick_reduce.py @@ -9,11 +9,8 @@ from vllm.distributed.communication_op import ( # noqa tensor_model_parallel_all_reduce) -from vllm.distributed.device_communicators.quick_all_reduce import ( - QuickReduceAlgo) from vllm.distributed.parallel_state import (get_tensor_model_parallel_group, - get_tp_group, graph_capture, - set_quick_reduce_algo) + get_tp_group, graph_capture) from ..utils import init_test_distributed_environment, multi_process_parallel @@ -37,7 +34,6 @@ def graph_allreduce( m.delenv("CUDA_VISIBLE_DEVICES", raising=False) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - set_quick_reduce_algo(QuickReduceAlgo.TwoShotFP) init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) @@ -99,7 +95,6 @@ def eager_quick_allreduce( ): with monkeypatch.context() as m: m.delenv("CUDA_VISIBLE_DEVICES", raising=False) - set_quick_reduce_algo(QuickReduceAlgo.TwoShotFP) device = torch.device(f"cuda:{rank}") 
torch.cuda.set_device(device) init_test_distributed_environment(tp_size, pp_size, rank, diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 8e44e306886c..e4965e9516fb 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1768,8 +1768,8 @@ def qr_destroy(fa: int) -> None: def qr_all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor, - algo_int: int) -> None: - torch.ops._C_custom_ar.qr_all_reduce(fa, inp, out, algo_int) + quantized: bool) -> None: + torch.ops._C_custom_ar.qr_all_reduce(fa, inp, out, quantized) def qr_get_handle(fa: int) -> torch.Tensor: @@ -1784,10 +1784,6 @@ def qr_max_size() -> int: return torch.ops._C_custom_ar.qr_max_size() -def qr_min_size() -> int: - return torch.ops._C_custom_ar.qr_min_size() - - def get_flash_mla_metadata( cache_seqlens: torch.Tensor, num_heads_per_head_k: int, diff --git a/vllm/config.py b/vllm/config.py index 49eb71ad7f91..d986ab6b0edb 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -33,8 +33,6 @@ import vllm.envs as envs from vllm import version from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass -from vllm.distributed.device_communicators.quick_all_reduce import ( - QuickReduceAlgo) from vllm.logger import init_logger from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS, QuantizationMethods, @@ -1772,13 +1770,6 @@ class ParallelConfig: disable_custom_all_reduce: bool = False """Disable the custom all-reduce kernel and fall back to NCCL.""" - quick_reduce_allreduce_algo: Optional[QuickReduceAlgo] = None - """Enable alternative to custom all-reduce. - Supports asymmetric and symmetric quantization (4 and 8 bits) - for 2 Shot algorithm. Only supported on AMD, - for bf16 and fp16 input data types. - """ - tokenizer_pool_config: Optional[TokenizerPoolConfig] = None """This parameter is deprecated and will be removed in a future release. 
Please remove it from your configs""" diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 55e9c68805c6..043370407e66 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -29,12 +29,10 @@ def __init__(self, if "tp" not in unique_name: # only tp uses custom allreduce use_custom_allreduce = False - quick_reduce_algo = None else: from vllm.distributed.parallel_state import ( - _ENABLE_CUSTOM_ALL_REDUCE, _QUICK_REDUCE_ALGO) + _ENABLE_CUSTOM_ALL_REDUCE) use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE - quick_reduce_algo = _QUICK_REDUCE_ALGO # ep does not use pynccl use_pynccl = "ep" not in unique_name @@ -60,25 +58,17 @@ def __init__(self, self.ca_comm: Optional[CustomAllreduce] = None if use_custom_allreduce and self.world_size > 1: # Initialize a custom fast all-reduce implementation. - self.ca_comm = CustomAllreduce( - group=self.cpu_group, - device=self.device, - ) - self.quick_reduce_comm_algo = quick_reduce_algo - if (self.quick_reduce_comm_algo is not None - and not current_platform.is_rocm()): - logger.warning( - "quick_reduce_comm_algo is not None," - " but QuickReduce is only supported on ROCm platform.") - self.qr_comm: Optional[QuickAllReduce] = None + self.ca_comm = CustomAllreduce(group=self.cpu_group, + device=self.device, + max_size=8192 * 1024 * 2) - if (self.quick_reduce_comm_algo is not None - and current_platform.is_rocm() and self.world_size > 1): - # Initialize a custom fast all-reduce implementation + self.qr_comm: Optional[QuickAllReduce] = None + if (use_custom_allreduce and current_platform.is_rocm() + and self.world_size > 1): + # Initialize a custom fast all-reduce implementation for AMD # based on quick reduce (https://github.com/mk1-project/quickreduce). 
self.qr_comm = QuickAllReduce(group=self.cpu_group, - device=self.device, - algo=self.quick_reduce_comm_algo) + device=self.device) if self.use_all2all: all2all_backend = envs.VLLM_ALL2ALL_BACKEND diff --git a/vllm/distributed/device_communicators/quick_all_reduce.py b/vllm/distributed/device_communicators/quick_all_reduce.py index 1f9314364595..b4ac54d7ae78 100644 --- a/vllm/distributed/device_communicators/quick_all_reduce.py +++ b/vllm/distributed/device_communicators/quick_all_reduce.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import logging -from enum import Enum +import os from typing import Union import torch @@ -20,23 +20,12 @@ ops_available = False -class QuickReduceAlgo(Enum): - OneShotFP = 0 - TwoShotFP = 1 - TwoShotQ8Symm = 2 - TwoShotQ4Symm = 3 - TwoShotQ8Asymm = 4 - TwoShotQ4Asymm = 5 - - class QuickAllReduce: _SUPPORTED_WORLD_SIZES = [2, 4, 8] _SUPPORTED_DTYPES = [torch.float16, torch.bfloat16] - def __init__(self, - group: ProcessGroup, - device: Union[int, str, torch.device], - algo: QuickReduceAlgo = QuickReduceAlgo.TwoShotFP) -> None: + def __init__(self, group: ProcessGroup, + device: Union[int, str, torch.device]) -> None: self.disabled = True if not ops_available: # disable because of missing quick reduce library @@ -46,14 +35,14 @@ def __init__(self, return self.max_size = ops.qr_max_size() - self.min_size = ops.qr_min_size() self.group = group - if isinstance(algo, str): - assert algo in QuickReduceAlgo.__members__, \ - "Algo {} is not supported by QuickReduce".format(algo) - self.algo = QuickReduceAlgo[algo] - else: - self.algo = algo + self.quantized = os.environ.get("VLLM_ROCM_CA_QUANTIZED", "0") == "1" + + # On RocM bfloat16 kernels are slower than fp16 + # due to slower match operations + # If environment is not set to 1 we convert input to fp16 + self.use_bf16_kernels = os.environ.get("VLLM_ROCM_CA_BF16_KERNELS", + "0") == "1" assert dist.get_backend(group) != dist.Backend.NCCL, ( "QuickReduce should be attached to a 
non-NCCL group.") @@ -67,7 +56,7 @@ def __init__(self, logger.warning( "QuickReduce allreduce is disabled due to an unsupported world" " size: %d. Supported world sizes: %s." - " To disable this warning set quick_reduce_allreduce_algo" + " To disable this warning set VLLM_ROCM_CA_BACKEND" " to None", world_size, str(QuickAllReduce._SUPPORTED_WORLD_SIZES)) return @@ -105,11 +94,14 @@ def all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None): if inp_size >= self.max_size: return None + inp_dtype = inp.dtype + if inp_dtype == torch.bfloat16 and not self.use_bf16_kernels: + inp = inp.to(torch.float16) if out is None: out = torch.empty_like(inp) - ops.qr_all_reduce(self._ptr, inp, out, self.algo.value) - return out + ops.qr_all_reduce(self._ptr, inp, out, self.quantized) + return out.to(inp_dtype) def close(self): if not self.disabled and getattr(self, "_ptr", None): @@ -127,4 +119,4 @@ def should_quick_allreduce(self, inp: torch.Tensor): if inp_size % 16 != 0: return False return inp.dtype in QuickAllReduce._SUPPORTED_DTYPES and \ - inp_size < self.max_size and inp_size >= self.min_size + inp_size < self.max_size diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index c9af2aad722b..126160b09553 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -40,8 +40,6 @@ import vllm.envs as envs from vllm.distributed.device_communicators.base_device_communicator import ( DeviceCommunicatorBase) -from vllm.distributed.device_communicators.quick_all_reduce import ( - QuickReduceAlgo) from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger from vllm.utils import (direct_register_custom_op, get_distributed_init_method, @@ -910,14 +908,6 @@ def set_custom_all_reduce(enable: bool): _ENABLE_CUSTOM_ALL_REDUCE = enable -_QUICK_REDUCE_ALGO: Optional[QuickReduceAlgo] = None - - -def set_quick_reduce_algo(algo: Optional[QuickReduceAlgo]): - global _QUICK_REDUCE_ALGO - 
_QUICK_REDUCE_ALGO = algo - - def init_distributed_environment( world_size: int = -1, rank: int = -1, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 0a0c699f7713..f599d7a3bb5e 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -31,8 +31,6 @@ SchedulerConfig, SchedulerPolicy, SpeculativeConfig, TaskOption, TokenizerMode, TokenizerPoolConfig, VllmConfig, get_attr_docs, get_field) -from vllm.distributed.device_communicators.quick_all_reduce import ( - QuickReduceAlgo) from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationMethods @@ -341,8 +339,6 @@ class EngineArgs: enforce_eager: bool = ModelConfig.enforce_eager max_seq_len_to_capture: int = ModelConfig.max_seq_len_to_capture disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce - quick_reduce_allreduce_algo: Optional[ - QuickReduceAlgo] = ParallelConfig.quick_reduce_allreduce_algo # The following three fields are deprecated and will be removed in a future # release. Setting them will have no effect. Please remove them from your # configurations. 
@@ -666,9 +662,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parallel_group.add_argument( "--disable-custom-all-reduce", **parallel_kwargs["disable_custom_all_reduce"]) - parallel_group.add_argument( - "--quick-reduce-allreduce-algo", - **parallel_kwargs["quick_reduce_allreduce_algo"]) parallel_group.add_argument("--worker-cls", **parallel_kwargs["worker_cls"]) parallel_group.add_argument("--worker-extension-cls", @@ -1130,7 +1123,6 @@ def create_engine_config( enable_expert_parallel=self.enable_expert_parallel, max_parallel_loading_workers=self.max_parallel_loading_workers, disable_custom_all_reduce=self.disable_custom_all_reduce, - quick_reduce_allreduce_algo=self.quick_reduce_allreduce_algo, ray_workers_use_nsight=self.ray_workers_use_nsight, placement_group=placement_group, distributed_executor_backend=self.distributed_executor_backend, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 9fe28cda4372..c11e627ee236 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -18,8 +18,6 @@ BeamSearchSequence, get_beam_search_score) from vllm.config import (CompilationConfig, ModelDType, TokenizerMode, is_init_field) -from vllm.distributed.device_communicators.quick_all_reduce import ( - QuickReduceAlgo) from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig, TaskOption) from vllm.engine.llm_engine import LLMEngine @@ -177,7 +175,6 @@ def __init__( enforce_eager: bool = False, max_seq_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, - quick_reduce_allreduce_algo: Optional[QuickReduceAlgo] = None, disable_async_output_proc: bool = False, hf_token: Optional[Union[bool, str]] = None, hf_overrides: Optional[HfOverrides] = None, @@ -252,7 +249,6 @@ def __init__( enforce_eager=enforce_eager, max_seq_len_to_capture=max_seq_len_to_capture, disable_custom_all_reduce=disable_custom_all_reduce, - quick_reduce_allreduce_algo=quick_reduce_allreduce_algo, 
disable_async_output_proc=disable_async_output_proc, hf_token=hf_token, hf_overrides=hf_overrides, diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 594c203803a6..58795e3fe292 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -14,7 +14,7 @@ from vllm.device_allocator.cumem import CuMemAllocator from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, - set_custom_all_reduce, set_quick_reduce_algo) + set_custom_all_reduce) from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.logger import init_logger @@ -369,8 +369,6 @@ def init_worker_distributed_environment( parallel_config = vllm_config.parallel_config set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) - set_quick_reduce_algo(parallel_config.quick_reduce_allreduce_algo) - init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank, backend) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 68bcc8fc9376..9a928632688a 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -13,7 +13,7 @@ from vllm.device_allocator.cumem import CuMemAllocator from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, - set_custom_all_reduce, set_quick_reduce_algo) + set_custom_all_reduce) from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -528,7 +528,6 @@ def init_worker_distributed_environment( """Initialize the distributed environment.""" parallel_config = vllm_config.parallel_config set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) - set_quick_reduce_algo(parallel_config.quick_reduce_allreduce_algo) init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank) From 
b47883f364515ae6665f6db034e93cc8154db154 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Tue, 3 Jun 2025 16:09:39 +0000 Subject: [PATCH 08/28] Some fixes Signed-off-by: ilmarkov --- csrc/quickreduce/quick_reduce.h | 11 +++++------ tests/distributed/test_quick_reduce.py | 3 +++ .../device_communicators/cuda_communicator.py | 5 +---- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/csrc/quickreduce/quick_reduce.h b/csrc/quickreduce/quick_reduce.h index afa94db23465..5620b1343d99 100644 --- a/csrc/quickreduce/quick_reduce.h +++ b/csrc/quickreduce/quick_reduce.h @@ -21,7 +21,7 @@ static_assert(sizeof(void*) == sizeof(fptr_t)); static constexpr int kOneShotAllreduceMaxElemsWorldSize2 = 8192 * 12; static constexpr int kOneShotAllreduceMaxElemsWorldSize4 = 8192 * 8; static constexpr int kOneShotAllreduceMaxElemsWorldSize8 = 8192 * 4; -static constexpr int kOneShotAllreduceMaxSize = +static constexpr long kOneShotAllreduceMaxSize = std::max(kOneShotAllreduceMaxElemsWorldSize2 * 2, std::max(kOneShotAllreduceMaxElemsWorldSize4 * 4, kOneShotAllreduceMaxElemsWorldSize8 * 8)) * @@ -97,8 +97,7 @@ struct DeviceComms { static long constexpr kTileSize = 256 * 16 * 8; // Max problem size is 8GB (in bytes) - static long long constexpr kMaxProblemSize = - static_cast(536870912) * 16; + static long constexpr kMaxProblemSize = 8589934592; static long constexpr kMaxTiles = kMaxProblemSize / kTileSize; // Max TP-8 @@ -126,9 +125,9 @@ struct DeviceComms { // Allocate buffer size for worst case: Twoshot FP16 2-stage buffer. 
long flags_buffer_size = 2 * world_size * kMaxTiles * sizeof(int); - long long data_buffer_size = max( - 2 * kMaxProblemSize, static_cast(kOneShotAllreduceMaxSize)); - long long total_buffer_size = flags_buffer_size + data_buffer_size; + static constexpr long data_buffer_size = + std::max(2 * kMaxProblemSize, kOneShotAllreduceMaxSize); + long total_buffer_size = flags_buffer_size + data_buffer_size; data_offset = flags_buffer_size; HIP_CHECK(hipExtMallocWithFlags((void**)&dbuffer, total_buffer_size, hipDeviceMallocUncached)); diff --git a/tests/distributed/test_quick_reduce.py b/tests/distributed/test_quick_reduce.py index 69763d026834..32731de9a648 100644 --- a/tests/distributed/test_quick_reduce.py +++ b/tests/distributed/test_quick_reduce.py @@ -11,6 +11,7 @@ tensor_model_parallel_all_reduce) from vllm.distributed.parallel_state import (get_tensor_model_parallel_group, get_tp_group, graph_capture) +from vllm.platforms import current_platform from ..utils import init_test_distributed_environment, multi_process_parallel @@ -111,6 +112,8 @@ def eager_quick_allreduce( torch.testing.assert_close(out, inp * (tp_size**num_communication)) +@pytest.mark.skipif(not current_platform.is_rocm(), + reason="Quick reduce is only supported on RocM.") @pytest.mark.parametrize("tp_size", [2]) @pytest.mark.parametrize("pipeline_parallel_size", [1, 2]) @pytest.mark.parametrize("test_target", diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 043370407e66..58191cac5677 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -1,17 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import logging from typing import Optional import torch from torch.distributed import ProcessGroup -import vllm.envs as envs -from vllm.platforms import current_platform - 
import vllm.envs as envs from vllm.logger import init_logger +from vllm.platforms import current_platform from .base_device_communicator import DeviceCommunicatorBase From 78a2b7cb620d2f83762e8e7d9a6c42ea0358ed24 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 10 Jun 2025 16:48:46 +0000 Subject: [PATCH 09/28] fix bfloat16 recv Signed-off-by: Haoyang Li --- csrc/quickreduce/quick_reduce_impl.cuh | 27 ++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/csrc/quickreduce/quick_reduce_impl.cuh b/csrc/quickreduce/quick_reduce_impl.cuh index d544f46a8361..38ba5367bf5b 100644 --- a/csrc/quickreduce/quick_reduce_impl.cuh +++ b/csrc/quickreduce/quick_reduce_impl.cuh @@ -175,16 +175,27 @@ struct CodecQ4Symm : public CodecBase { int32x4_t w; { static constexpr uint kMask000F = 0x000F000F; - // {1024.0, 1024.0}, f16x2_t - static constexpr uint kF162_1024 = - std::is_same::value ? 0x64006400 : 0x44804480; - // {-1032.0, -1032.0}, f16x2_t - static constexpr uint kF162_1032 = - std::is_same::value ? 
0xE408E408 : 0xC481C481; + static constexpr uint kHalf2_1024 = + 0x64006400; // {1024.0, 1024.0}, fp16x2_t + static uint constexpr kHalf2_1032 = + 0xE408E408; // {-1032.0, -1032.0}, fp16x2_t for (int i = 0; i < 4; i++) { - int32_t q4 = ((qw >> (i * 4)) & kMask000F) | kF162_1024; - w[i] = packed_add(q4, kF162_1032); + if constexpr (std::is_same::value) { + int32_t q4 = ((qw >> (i * 4)) & kMask000F) | kHalf2_1024; + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(w[i]) + : "v"(q4), "v"(kHalf2_1032)); + } else { + int32_t int16_2 = (qw >> (i * 4)) & kMask000F; + int16_t low = static_cast(int16_2 & 0xFFFF); + int16_t high = static_cast((int16_2 >> 16) & 0xFFFF); + nv_bfloat16 bf_low = __float2bfloat16(static_cast(low)); + nv_bfloat16 bf_high = __float2bfloat16(static_cast(high)); + nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high); + int32_t packed_bf16 = *reinterpret_cast(&bf2); + w[i] = packed_add(packed_bf16, kRangeMin); + } } } From 988e12475f0aadf508b8143b0493227b89dc462c Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Wed, 11 Jun 2025 04:05:30 +0000 Subject: [PATCH 10/28] fix env Signed-off-by: Haoyang Li --- .../device_communicators/quick_all_reduce.py | 7 +++---- vllm/envs.py | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/device_communicators/quick_all_reduce.py b/vllm/distributed/device_communicators/quick_all_reduce.py index b4ac54d7ae78..91af09b49cd9 100644 --- a/vllm/distributed/device_communicators/quick_all_reduce.py +++ b/vllm/distributed/device_communicators/quick_all_reduce.py @@ -1,12 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 import logging -import os from typing import Union import torch import torch.distributed as dist from torch.distributed import ProcessGroup +import vllm.envs as envs from vllm import _custom_ops as ops from vllm.platforms import current_platform @@ -36,13 +36,12 @@ def __init__(self, group: ProcessGroup, self.max_size = ops.qr_max_size() self.group = group - 
self.quantized = os.environ.get("VLLM_ROCM_CA_QUANTIZED", "0") == "1" + self.quantized = envs.VLLM_ROCM_CA_QUANTIZED # On RocM bfloat16 kernels are slower than fp16 # due to slower match operations # If environment is not set to 1 we convert input to fp16 - self.use_bf16_kernels = os.environ.get("VLLM_ROCM_CA_BF16_KERNELS", - "0") == "1" + self.use_bf16_kernels = envs.VLLM_ROCM_CA_QUANTIZED assert dist.get_backend(group) != dist.Backend.NCCL, ( "QuickReduce should be attached to a non-NCCL group.") diff --git a/vllm/envs.py b/vllm/envs.py index a4a1784f97f9..8dca6c47c0d3 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -129,6 +129,8 @@ VLLM_SLEEP_WHEN_IDLE: bool = False VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16 VLLM_KV_CACHE_LAYOUT: Optional[str] = None + VLLM_ROCM_CA_QUANTIZED: bool = False + VLLM_ROCM_CA_CAST_BF16_TO_FP16: bool = True def get_default_cache_root(): @@ -671,6 +673,22 @@ def get_vllm_port() -> Optional[int]: lambda: (os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in ("true", "1")), + # Custom quick allreduce kernel for MI3* cards. + # Whether to use quantization to speed up allreduce. + # Recommended for large models to get doubled allreduce + # speedup with less precision loss. + "VLLM_ROCM_CA_QUANTIZED": + lambda: (os.getenv("VLLM_ROCM_CA_QUANTIZED", "False").lower() in + ("true", "1")), + + # Custom quick allreduce kernel for MI3* cards + # Due to the lack of the bfloat16 asm instruction, bfloat16 + # kernels are slower than fp16, + # If environment is not set to 1, we convert input to fp16 + "VLLM_ROCM_CA_CAST_BF16_TO_FP16": + lambda: (os.getenv("VLLM_ROCM_CA_QUANTIZED", "True").lower() in + ("true", "1")), + # If set, when running in Quark emulation mode, do not dequantize the # weights at load time. Instead, dequantize weights on-the-fly during # kernel execution. 
From 3ee13ffc3c9716e16b408bc776920aaa3b137e7e Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Wed, 11 Jun 2025 06:39:39 +0000 Subject: [PATCH 11/28] fix log info Signed-off-by: Haoyang Li --- .../device_communicators/cuda_communicator.py | 2 +- .../device_communicators/quick_all_reduce.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 58191cac5677..f3f9e934c846 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -90,7 +90,7 @@ def __init__(self, def all_reduce(self, input_): # always try quick reduce first, then custom allreduce, - # and then pynccl. + # and then pynccl. (quick reduce just for ROCM MI3*) qr_comm = self.qr_comm if qr_comm is not None and not qr_comm.disabled and \ qr_comm.should_quick_allreduce(input_): diff --git a/vllm/distributed/device_communicators/quick_all_reduce.py b/vllm/distributed/device_communicators/quick_all_reduce.py index 91af09b49cd9..1770a8380a1e 100644 --- a/vllm/distributed/device_communicators/quick_all_reduce.py +++ b/vllm/distributed/device_communicators/quick_all_reduce.py @@ -30,8 +30,8 @@ def __init__(self, group: ProcessGroup, if not ops_available: # disable because of missing quick reduce library # e.g. 
in a non-cuda environment - logger.info("Custom allreduce is disabled because " - "of missing custom allreduce library") + logger.info("Custom quick allreduce is disabled because " + "of missing custom quick allreduce library") return self.max_size = ops.qr_max_size() @@ -41,8 +41,7 @@ def __init__(self, group: ProcessGroup, # On RocM bfloat16 kernels are slower than fp16 # due to slower match operations # If environment is not set to 1 we convert input to fp16 - self.use_bf16_kernels = envs.VLLM_ROCM_CA_QUANTIZED - + self.use_bf16_kernels = envs.VLLM_ROCM_CA_CAST_BF16_TO_FP16 assert dist.get_backend(group) != dist.Backend.NCCL, ( "QuickReduce should be attached to a non-NCCL group.") rank = dist.get_rank(group=self.group) @@ -53,7 +52,7 @@ def __init__(self, group: ProcessGroup, if world_size not in QuickAllReduce._SUPPORTED_WORLD_SIZES: logger.warning( - "QuickReduce allreduce is disabled due to an unsupported world" + "QuickReduce is disabled due to an unsupported world" " size: %d. Supported world sizes: %s." " To disable this warning set VLLM_ROCM_CA_BACKEND" " to None", world_size, @@ -87,7 +86,8 @@ def create_shared_buffer(self): ops.qr_open_handles(self._ptr, handles) def all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None): - """Performs an out-of-place all reduce. + """ + Performs an out-of-place all reduce. 
""" inp_size = inp.numel() * inp.element_size() if inp_size >= self.max_size: From 309016ef83c67cfcddfecae7be6b71808ec8c960 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Wed, 11 Jun 2025 07:29:45 +0000 Subject: [PATCH 12/28] for env Signed-off-by: Haoyang Li --- vllm/distributed/device_communicators/quick_all_reduce.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/device_communicators/quick_all_reduce.py b/vllm/distributed/device_communicators/quick_all_reduce.py index 1770a8380a1e..ff400fc8edc7 100644 --- a/vllm/distributed/device_communicators/quick_all_reduce.py +++ b/vllm/distributed/device_communicators/quick_all_reduce.py @@ -41,7 +41,7 @@ def __init__(self, group: ProcessGroup, # On RocM bfloat16 kernels are slower than fp16 # due to slower match operations # If environment is not set to 1 we convert input to fp16 - self.use_bf16_kernels = envs.VLLM_ROCM_CA_CAST_BF16_TO_FP16 + self.use_fp16_kernels = envs.VLLM_ROCM_CA_CAST_BF16_TO_FP16 assert dist.get_backend(group) != dist.Backend.NCCL, ( "QuickReduce should be attached to a non-NCCL group.") rank = dist.get_rank(group=self.group) @@ -94,7 +94,7 @@ def all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None): return None inp_dtype = inp.dtype - if inp_dtype == torch.bfloat16 and not self.use_bf16_kernels: + if inp_dtype == torch.bfloat16 and self.use_fp16_kernels: inp = inp.to(torch.float16) if out is None: out = torch.empty_like(inp) From 21bdda2b84a1479402c286523994366baf80accf Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Fri, 13 Jun 2025 12:15:53 +0000 Subject: [PATCH 13/28] Add int8 quantization. 
Remove changes to custom_allreduce Signed-off-by: ilmarkov --- csrc/custom_quickreduce.cu | 6 +- csrc/ops.h | 2 +- csrc/quickreduce/base.h | 18 ++ csrc/quickreduce/quick_reduce.h | 24 ++- csrc/quickreduce/quick_reduce_impl.cuh | 190 ++++++++++++++++-- csrc/torch_bindings.cpp | 2 +- vllm/_custom_ops.py | 4 +- .../device_communicators/cuda_communicator.py | 3 +- .../device_communicators/quick_all_reduce.py | 17 +- vllm/envs.py | 14 +- 10 files changed, 236 insertions(+), 44 deletions(-) diff --git a/csrc/custom_quickreduce.cu b/csrc/custom_quickreduce.cu index 27f018fa738d..79bc897c527b 100644 --- a/csrc/custom_quickreduce.cu +++ b/csrc/custom_quickreduce.cu @@ -52,7 +52,7 @@ void qr_open_handles(quickreduce::fptr_t _fa, } void qr_all_reduce(quickreduce::fptr_t _fa, torch::Tensor& inp, - torch::Tensor& out, bool quantized) { + torch::Tensor& out, int64_t quant_level) { auto fa = reinterpret_cast(_fa); const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); auto stream = at::cuda::getCurrentHIPStreamMasqueradingAsCUDA(); @@ -63,12 +63,12 @@ void qr_all_reduce(quickreduce::fptr_t _fa, torch::Tensor& inp, if (out.scalar_type() == at::ScalarType::Half) { fa->allreduce(reinterpret_cast(inp.data_ptr()), reinterpret_cast(out.data_ptr()), out.numel(), - quantized, stream); + quant_level, stream); } else if (out.scalar_type() == at::ScalarType::BFloat16) { fa->allreduce( reinterpret_cast(inp.data_ptr()), reinterpret_cast(out.data_ptr()), - out.numel(), quantized, stream); + out.numel(), quant_level, stream); } else { throw std::runtime_error( "quick allreduce only supports float16 and bfloat16"); diff --git a/csrc/ops.h b/csrc/ops.h index 8b591aae4264..85f86f9722d1 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -367,6 +367,6 @@ void qr_destroy(fptr_t _fa); torch::Tensor qr_get_handle(fptr_t _fa); void qr_open_handles(fptr_t _fa, const std::vector& handles); void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, - bool quantized); + int64_t quant_level); 
int64_t qr_max_size(); #endif \ No newline at end of file diff --git a/csrc/quickreduce/base.h b/csrc/quickreduce/base.h index d833700391eb..5eb0723e5863 100644 --- a/csrc/quickreduce/base.h +++ b/csrc/quickreduce/base.h @@ -13,6 +13,8 @@ namespace quickreduce { typedef __hip_bfloat16 nv_bfloat16; typedef __hip_bfloat162 nv_bfloat162; + +using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; // Setup acquire-release semantics for vector memory reads (mubuf instruction) @@ -87,6 +89,22 @@ __quickreduce_device_inline__ static void buffer_store_dwordx4( int32x4_t data, int32x4_t srsrc, int32_t voffset, int32_t soffset, int32_t aux) __asm("llvm.amdgcn.raw.buffer.store.v4i32"); +__quickreduce_device_inline__ static void set_fp16_ovfl(bool const value) { + // short size = 0b00001; // Specifies the bit size to modify + // const short offset = 0b10111; // Corrected offset to 23, which is the bit + // position of FP16_OVFL const short hwRegId = 0b000001; // HW register ID for + // MODE const short simm16 = (size << 11) | (offset << 6) | hwRegId; simm16 = + // 0xdc1 + +#if defined(__gfx942__) + if (value) { + asm volatile("s_setreg_imm32_b32 0xdc1, 1;" ::); + } else { + asm volatile("s_setreg_imm32_b32 0xdc1, 0;" ::); + } +#endif +} + template __quickreduce_device_inline__ void packed_assign_add(int32x4_t* A, int32x4_t* B); diff --git a/csrc/quickreduce/quick_reduce.h b/csrc/quickreduce/quick_reduce.h index 5620b1343d99..e45cc4f7758a 100644 --- a/csrc/quickreduce/quick_reduce.h +++ b/csrc/quickreduce/quick_reduce.h @@ -92,6 +92,12 @@ allreduce_prototype_twoshot(T const* A, T* B, int N, int num_blocks, int rank, flag_color); \ } +enum QuickReduceQuantLevel { + FP16 = 0, + INT8, + INT4, +}; + struct DeviceComms { // Workgroup scope = Tile = (256 threads x 16B x 8 atoms) static long constexpr kTileSize = 256 * 16 * 8; @@ -188,7 +194,7 @@ struct DeviceComms { } template - void allreduce(T 
const* A, T* B, int N, bool quantized, hipStream_t stream) { + void allreduce(T const* A, T* B, int N, int quant_level, hipStream_t stream) { if (world_size != 2 && world_size != 4 && world_size != 8) { throw std::runtime_error("All Reduce not supported for world_size = " + std::to_string(world_size)); @@ -208,11 +214,17 @@ struct DeviceComms { } else { unsigned long num_blocks = divceil(msg_size, kTileSize); unsigned long grid = min(kMaxNumBlocks, num_blocks); - - if (quantized) { - TWOSHOT_DISPATCH(CodecQ4Symm) - } else { - TWOSHOT_DISPATCH(CodecFP16) + auto quant_level_ = static_cast(quant_level); + switch (quant_level_) { + case QuickReduceQuantLevel::INT8: + TWOSHOT_DISPATCH(CodecQ8) + break; + case QuickReduceQuantLevel::INT4: + TWOSHOT_DISPATCH(CodecQ4) + break; + default: + TWOSHOT_DISPATCH(CodecFP) + break; } } HIP_CHECK(cudaGetLastError()); diff --git a/csrc/quickreduce/quick_reduce_impl.cuh b/csrc/quickreduce/quick_reduce_impl.cuh index 38ba5367bf5b..92722c56ddfd 100644 --- a/csrc/quickreduce/quick_reduce_impl.cuh +++ b/csrc/quickreduce/quick_reduce_impl.cuh @@ -17,7 +17,7 @@ struct CodecBase { // Default full precision codec. 
template -struct CodecFP16 : public CodecBase { +struct CodecFP : public CodecBase { static constexpr int kWorldSize = world_size; static constexpr int kRankAtoms = kAtoms / kWorldSize; @@ -32,12 +32,11 @@ struct CodecFP16 : public CodecBase { static constexpr int kTransmittedTileSize = kRankTransmittedTileSize * kWorldSize; - __quickreduce_device_inline__ CodecFP16(int thread, int rank) + __quickreduce_device_inline__ CodecFP(int thread, int rank) : CodecBase(thread, rank) {} __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, - const int32x4_t* __restrict__ data, - const int* validity) { + const int32x4_t* __restrict__ data) { for (int i = 0; i < kRankAtoms; i++) { __builtin_nontemporal_store(data[i], send_buffer + thread); send_buffer += kAtomStride; @@ -57,7 +56,166 @@ struct CodecFP16 : public CodecBase { // We quantize the FP16 data to block-scaled Int4 in blocks of 4 * // kThreadGroupSize. template -struct CodecQ4Symm : public CodecBase { +struct CodecQ8 : public CodecBase { + static constexpr int kWorldSize = world_size; + + // Codec tile size process by this workgroup. + // Each threads processes a fragment of f16x8_t (16B), + // into a int8x8_t (8B) and a f16 scale shared among 32 values. + static int constexpr kRankAtoms = kAtoms / kWorldSize; + static int constexpr kRankTileStride = 2176; + static int constexpr kRankTileScaleOffset = 2048; + static int constexpr kRankTransmittedTileSize = kRankTileStride * kRankAtoms; + static_assert(kRankTransmittedTileSize % 16 == 0, + "kRankTileSize must be 16B aligned."); + + static int constexpr kRankBufferTileStride = + kRankTileStride / sizeof(int32x4_t); + + // Total tile size for the collective communication. + static int constexpr kTransmittedTileSize = + kRankTransmittedTileSize * kWorldSize; + + // Constants configuration + + // {-1/128.0h, -1/128.0h}, f16x2_t + static int constexpr kScaleFactor = + std::is_same::value ? 
0xA000A000 : 0xBC00BC00; + + // {1e-7, 1e-7}, f16x2_t + static int constexpr kScaleEpsilon = + std::is_same::value ? 0x00010001 : 0x33D733D7; + + // {-128, -128}, f16x2_t + static int constexpr kRangeMin = + std::is_same::value ? 0xD800D800 : 0xC300C300; + // {+127, +127}, f16x2_t + static int constexpr kRangeMax = + std::is_same::value ? 0x57F057F0 : 0x42FE42FE; + + // {+128, +128}, int16x2_t + static int constexpr kRangeBias = 0x00800080; + + __quickreduce_device_inline__ CodecQ8(int thread, int rank) + : CodecBase(thread, rank) { + set_fp16_ovfl(true); + } + + __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, + int32x4_t const* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + int32x4_t const atom = data[k]; + // Compute the absolute maximum of the atom in the thread group + // In 2 blocks of values, upper/lower halves of the f16x2_t + int wblockmax = group_abs_max(atom); + + // Derive scales + int decoding_scale; + int encoding_scale; + decoding_scale = packed_mul(wblockmax, kScaleFactor); + encoding_scale = packed_add(decoding_scale, kScaleEpsilon); + encoding_scale = packed_rcp(encoding_scale); + + // Apply scales to get quantized values + int32x4_t w; + for (int i = 0; i < 4; i++) { + w[i] = packed_mul(atom[i], encoding_scale); + w[i] = packed_max(w[i], kRangeMin); + w[i] = packed_min(w[i], kRangeMax); + } + + // Convert from f16x2_t to uint16x2_t + int32x4_t q; + { + int16_t* qi = reinterpret_cast(&q); + T* wh = reinterpret_cast(&w); + for (int i = 0; i < 8; i++) + qi[i] = (int16_t)rintf(T2float_cast(wh[i])); + + for (int i = 0; i < 4; i++) { + q[i] = packed_add(q[i], kRangeBias); + } + } + + // Pack 8 x q8 into int32x2_t + int32x2_t qw; + qw[0] = q[0] | (q[1] << 8); + qw[1] = q[2] | (q[3] << 8); + + // Write quantized atom to send_buffer + // note: only the group leader stores the scale + uint8_t* atom_ptr = + reinterpret_cast(send_buffer + k * kRankBufferTileStride); + int32x2_t* qw_ptr = reinterpret_cast(atom_ptr) 
+ thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + + __builtin_nontemporal_store(qw, qw_ptr); + if (threadIdx.x == group_leader) { + __builtin_nontemporal_store(decoding_scale, qs_ptr); + } + } + } + + __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, + int32x4_t* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + // Directly read quantized atom from recv_buffer + uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); + int32x2_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + + int32x2_t qw = __builtin_nontemporal_load(qw_ptr); + int qs = __builtin_nontemporal_load(qs_ptr); + + *recv_buffer += kRankBufferTileStride; + + // Unpack q8 into fp16x8_t + int32x4_t w; + { + static uint constexpr kMask00FF = 0x00FF00FF; + + // {1024.0, 1024.0}, fp16x2_t + static uint constexpr kHalf2_1024 = 0x64006400; + + // {-1152.0, -1152.0}, fp16x2_t + static uint constexpr kHalf2_1152 = 0xE480E480; + +#pragma unroll + for (int i = 0; i < 4; i++) { + if constexpr (std::is_same::value) { + int32_t q8 = + ((qw[i / 2] >> ((i % 2) * 8)) & kMask00FF) | kHalf2_1024; + w[i] = packed_add(q8, kHalf2_1152); + } else { + int32_t int16_2 = (qw[i / 2] >> ((i % 2) * 8)) & kMask00FF; + int16_t low = static_cast(int16_2 & 0xFFFF); + int16_t high = static_cast((int16_2 >> 16) & 0xFFFF); + nv_bfloat16 bf_low = __float2bfloat16(static_cast(low)); + nv_bfloat16 bf_high = __float2bfloat16(static_cast(high)); + nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high); + int32_t packed_bf16 = *reinterpret_cast(&bf2); + w[i] = packed_add(packed_bf16, kRangeMin); + } + } + } + + // Apply decoding scales + for (int i = 0; i < 4; i++) { + w[i] = packed_mul(w[i], qs); + } + + data[k] = w; + } + } +}; + +// Int4 symmetric quantization codec. +// We quantize the FP16 data to block-scaled Int4 in blocks of 4 * +// kThreadGroupSize. 
+template +struct CodecQ4 : public CodecBase { static constexpr int kWorldSize = world_size; // Codec tile size process by this workgroup. @@ -98,12 +256,13 @@ struct CodecQ4Symm : public CodecBase { // {+8, +8}, int16x2_t static constexpr int kRangeBias = 0x00080008; - __quickreduce_device_inline__ CodecQ4Symm(int thread, int rank) - : CodecBase(thread, rank) {} + __quickreduce_device_inline__ CodecQ4(int thread, int rank) + : CodecBase(thread, rank) { + set_fp16_ovfl(true); + } __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, - const int32x4_t* __restrict__ data, - const int* validity) { + const int32x4_t* __restrict__ data) { for (int k = 0; k < kRankAtoms; k++) { int32x4_t const atom = data[k]; @@ -171,7 +330,7 @@ struct CodecQ4Symm : public CodecBase { *recv_buffer += kRankBufferTileStride; - // Unpack q4 into fp16x8_t + // Unpack q4 into f16x8_t int32x4_t w; { static constexpr uint kMask000F = 0x000F000F; @@ -183,9 +342,7 @@ struct CodecQ4Symm : public CodecBase { for (int i = 0; i < 4; i++) { if constexpr (std::is_same::value) { int32_t q4 = ((qw >> (i * 4)) & kMask000F) | kHalf2_1024; - asm volatile("v_pk_add_f16 %0, %1, %2" - : "=v"(w[i]) - : "v"(q4), "v"(kHalf2_1032)); + w[i] = packed_add(q4, kHalf2_1032); } else { int32_t int16_2 = (qw >> (i * 4)) & kMask000F; int16_t low = static_cast(int16_2 & 0xFFFF); @@ -337,7 +494,6 @@ struct AllReduceTwoshot { // -------------------------------------------------------- // Read input into registers int32x4_t tA[kAtoms]; - int tA_validity[kAtoms]; BufferResource src_buffer(const_cast(input), N * sizeof(T)); int src_offset = block * kTileSize + thread * sizeof(int32x4_t); @@ -345,7 +501,6 @@ struct AllReduceTwoshot { for (int i = 0; i < kAtoms; i++) { tA[i] = buffer_load_dwordx4(src_buffer.descriptor, src_offset, 0, 0); - tA_validity[i] = src_offset < N * sizeof(T); src_offset += kAtomStride * sizeof(int32x4_t); } @@ -364,8 +519,7 @@ struct AllReduceTwoshot { int32x4_t* send_buffer = 
reinterpret_cast(buffer_list[r] + comm_data0_offset + rank * Codec::kRankTransmittedTileSize); - codec.send(send_buffer, &tA[r * Codec::kRankAtoms], - &tA_validity[r * Codec::kRankAtoms]); + codec.send(send_buffer, &tA[r * Codec::kRankAtoms]); } __syncthreads(); @@ -405,7 +559,7 @@ struct AllReduceTwoshot { int32x4_t* send_buffer = reinterpret_cast(buffer_list[r] + comm_data1_offset + rank * Codec::kRankTransmittedTileSize); - codec.send(send_buffer, tR, &tA_validity[rank * Codec::kRankAtoms]); + codec.send(send_buffer, tR); } __syncthreads(); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 0e8ab36ad98c..d14d30536306 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -728,7 +728,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { #ifdef USE_ROCM // Quick Reduce all-reduce kernels custom_ar.def( - "qr_all_reduce(int fa, Tensor inp, Tensor out, bool quantized) -> ()"); + "qr_all_reduce(int fa, Tensor inp, Tensor out, int quant_level) -> ()"); custom_ar.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce); custom_ar.def("init_custom_qr", &init_custom_qr); diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index e4965e9516fb..dc557349dbad 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1768,8 +1768,8 @@ def qr_destroy(fa: int) -> None: def qr_all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor, - quantized: bool) -> None: - torch.ops._C_custom_ar.qr_all_reduce(fa, inp, out, quantized) + quant_level: int) -> None: + torch.ops._C_custom_ar.qr_all_reduce(fa, inp, out, quant_level) def qr_get_handle(fa: int) -> torch.Tensor: diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index f3f9e934c846..4d4ba4a8b4b6 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -56,8 +56,7 @@ def __init__(self, if use_custom_allreduce 
and self.world_size > 1: # Initialize a custom fast all-reduce implementation. self.ca_comm = CustomAllreduce(group=self.cpu_group, - device=self.device, - max_size=8192 * 1024 * 2) + device=self.device) self.qr_comm: Optional[QuickAllReduce] = None if (use_custom_allreduce and current_platform.is_rocm() diff --git a/vllm/distributed/device_communicators/quick_all_reduce.py b/vllm/distributed/device_communicators/quick_all_reduce.py index ff400fc8edc7..69cea1c4843d 100644 --- a/vllm/distributed/device_communicators/quick_all_reduce.py +++ b/vllm/distributed/device_communicators/quick_all_reduce.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import logging +from enum import Enum from typing import Union import torch @@ -20,6 +21,12 @@ ops_available = False +class QuickReduceQuantLevel(Enum): + FP = 0 + INT8 = 1 + INT4 = 2 + + class QuickAllReduce: _SUPPORTED_WORLD_SIZES = [2, 4, 8] _SUPPORTED_DTYPES = [torch.float16, torch.bfloat16] @@ -36,8 +43,12 @@ def __init__(self, group: ProcessGroup, self.max_size = ops.qr_max_size() self.group = group - self.quantized = envs.VLLM_ROCM_CA_QUANTIZED - + quant_level_str = envs.VLLM_ROCM_CA_QUANT_LEVEL + assert quant_level_str in QuickReduceQuantLevel.__members__, ( + f"Invalid quantization level: {quant_level_str}. 
" + "Supported levels: " + f"{list(QuickReduceQuantLevel.__members__.keys())}") + self.quant_level = QuickReduceQuantLevel[quant_level_str] # On RocM bfloat16 kernels are slower than fp16 # due to slower match operations # If environment is not set to 1 we convert input to fp16 @@ -99,7 +110,7 @@ def all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None): if out is None: out = torch.empty_like(inp) - ops.qr_all_reduce(self._ptr, inp, out, self.quantized) + ops.qr_all_reduce(self._ptr, inp, out, self.quant_level.value) return out.to(inp_dtype) def close(self): diff --git a/vllm/envs.py b/vllm/envs.py index 8dca6c47c0d3..b6195a1f649e 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -129,7 +129,7 @@ VLLM_SLEEP_WHEN_IDLE: bool = False VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16 VLLM_KV_CACHE_LAYOUT: Optional[str] = None - VLLM_ROCM_CA_QUANTIZED: bool = False + VLLM_ROCM_CA_QUANT_LEVEL: str = "FP" VLLM_ROCM_CA_CAST_BF16_TO_FP16: bool = True @@ -674,19 +674,17 @@ def get_vllm_port() -> Optional[int]: ("true", "1")), # Custom quick allreduce kernel for MI3* cards. - # Whether to use quantization to speed up allreduce. - # Recommended for large models to get doubled allreduce - # speedup with less precision loss. 
- "VLLM_ROCM_CA_QUANTIZED": - lambda: (os.getenv("VLLM_ROCM_CA_QUANTIZED", "False").lower() in - ("true", "1")), + # Choice of quantization level: FP16, INT8, INT4 + # Recommended for large models to get allreduce + "VLLM_ROCM_CA_QUANT_LEVEL": + lambda: os.getenv("VLLM_ROCM_CA_QUANT_LEVEL", "FP").upper(), # Custom quick allreduce kernel for MI3* cards # Due to the lack of the bfloat16 asm instruction, bfloat16 # kernels are slower than fp16, # If environment is not set to 1, we convert input to fp16 "VLLM_ROCM_CA_CAST_BF16_TO_FP16": - lambda: (os.getenv("VLLM_ROCM_CA_QUANTIZED", "True").lower() in + lambda: (os.getenv("VLLM_ROCM_CA_CAST_BF16_TO_FP16", "True").lower() in ("true", "1")), # If set, when running in Quark emulation mode, do not dequantize the From 280c0fc3c9e1f6365d4640934e8a78be8c148245 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Fri, 13 Jun 2025 14:04:53 +0000 Subject: [PATCH 14/28] Update after review comments Signed-off-by: ilmarkov --- csrc/custom_quickreduce.cu | 2 +- csrc/quickreduce/base.h | 6 ++- csrc/quickreduce/quick_reduce.h | 51 ++++++++++--------- csrc/quickreduce/quick_reduce_impl.cuh | 43 ++++++++-------- .../device_communicators/quick_all_reduce.py | 17 ++++--- vllm/envs.py | 8 +-- 6 files changed, 69 insertions(+), 58 deletions(-) diff --git a/csrc/custom_quickreduce.cu b/csrc/custom_quickreduce.cu index 79bc897c527b..97ebfbaa0a13 100644 --- a/csrc/custom_quickreduce.cu +++ b/csrc/custom_quickreduce.cu @@ -59,7 +59,7 @@ void qr_all_reduce(quickreduce::fptr_t _fa, torch::Tensor& inp, TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); TORCH_CHECK_EQ(inp.numel(), out.numel()); - + TORCH_CHECK_LE(out.numel(), quickreduce::DeviceComms::kMaxProblemSize); if (out.scalar_type() == at::ScalarType::Half) { fa->allreduce(reinterpret_cast(inp.data_ptr()), reinterpret_cast(out.data_ptr()), out.numel(), diff --git a/csrc/quickreduce/base.h b/csrc/quickreduce/base.h index 5eb0723e5863..7abbc18aea73 100644 --- a/csrc/quickreduce/base.h +++ 
b/csrc/quickreduce/base.h @@ -423,11 +423,13 @@ __quickreduce_device_inline__ void group_max_min(int32x4_t atom, int& wblockmax, wblockmin = __shfl(wmin, group_leader); } -__quickreduce_device_inline__ void set_sync_flag(int* flag_ptr, int flag) { +__quickreduce_device_inline__ void set_sync_flag(uint32_t* flag_ptr, + uint32_t flag) { __atomic_store_n(flag_ptr, flag, __ATOMIC_RELEASE); } -__quickreduce_device_inline__ void wait_sync_flag(int* flag_ptr, int flag) { +__quickreduce_device_inline__ void wait_sync_flag(uint32_t* flag_ptr, + uint32_t flag) { while (__atomic_load_n(flag_ptr, __ATOMIC_RELAXED) != flag) { } } diff --git a/csrc/quickreduce/quick_reduce.h b/csrc/quickreduce/quick_reduce.h index e45cc4f7758a..2cd5ecf25657 100644 --- a/csrc/quickreduce/quick_reduce.h +++ b/csrc/quickreduce/quick_reduce.h @@ -18,10 +18,10 @@ namespace quickreduce { using fptr_t = int64_t; static_assert(sizeof(void*) == sizeof(fptr_t)); -static constexpr int kOneShotAllreduceMaxElemsWorldSize2 = 8192 * 12; -static constexpr int kOneShotAllreduceMaxElemsWorldSize4 = 8192 * 8; -static constexpr int kOneShotAllreduceMaxElemsWorldSize8 = 8192 * 4; -static constexpr long kOneShotAllreduceMaxSize = +static constexpr unsigned int kOneShotAllreduceMaxElemsWorldSize2 = 8192 * 12; +static constexpr unsigned int kOneShotAllreduceMaxElemsWorldSize4 = 8192 * 8; +static constexpr unsigned int kOneShotAllreduceMaxElemsWorldSize8 = 8192 * 4; +static constexpr unsigned int kOneShotAllreduceMaxSize = std::max(kOneShotAllreduceMaxElemsWorldSize2 * 2, std::max(kOneShotAllreduceMaxElemsWorldSize4 * 4, kOneShotAllreduceMaxElemsWorldSize8 * 8)) * @@ -29,17 +29,17 @@ static constexpr long kOneShotAllreduceMaxSize = template __global__ __quickreduce_launch_bounds_one_shot__ static void -allreduce_prototype_oneshot(T const* A, T* B, int N, int rank, - uint8_t** dbuffer_list, long data_offset, - int flag_color) { +allreduce_prototype_oneshot(T const* A, T* B, uint32_t N, int rank, + uint8_t** dbuffer_list, 
uint32_t data_offset, + uint32_t flag_color) { AllReduceKernel::run(A, B, N, rank, dbuffer_list, data_offset, flag_color); } template __global__ __quickreduce_launch_bounds_two_shot__ static void -allreduce_prototype_twoshot(T const* A, T* B, int N, int num_blocks, int rank, - uint8_t** dbuffer_list, long data_offset, - int flag_color) { +allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, int num_blocks, + int rank, uint8_t** dbuffer_list, + uint32_t data_offset, uint32_t flag_color) { int block = blockIdx.x; int grid = gridDim.x; @@ -102,15 +102,15 @@ struct DeviceComms { // Workgroup scope = Tile = (256 threads x 16B x 8 atoms) static long constexpr kTileSize = 256 * 16 * 8; - // Max problem size is 8GB (in bytes) - static long constexpr kMaxProblemSize = 8589934592; - static long constexpr kMaxTiles = kMaxProblemSize / kTileSize; + // Max problem size is 2GB (in bytes) or half of uint32_t max value. + static int64_t constexpr kMaxProblemSize = 2147483647; + static int64_t constexpr kMaxTiles = kMaxProblemSize / kTileSize; // Max TP-8 static int constexpr kMaxWorldSize = 8; bool initialized = false; - int flag_color = 1; + uint32_t flag_color = 1; int world_size; int rank; @@ -119,7 +119,7 @@ struct DeviceComms { hipIpcMemHandle_t buffer_ipc_handle; std::vector all_buffer_ipc_handles; std::vector buffer_list; - long data_offset; + uint32_t data_offset; DeviceComms() : initialized(false), world_size(1), rank(0) {} ~DeviceComms() { destroy(); } @@ -130,10 +130,10 @@ struct DeviceComms { this->rank = rank; // Allocate buffer size for worst case: Twoshot FP16 2-stage buffer. 
- long flags_buffer_size = 2 * world_size * kMaxTiles * sizeof(int); - static constexpr long data_buffer_size = - std::max(2 * kMaxProblemSize, kOneShotAllreduceMaxSize); - long total_buffer_size = flags_buffer_size + data_buffer_size; + uint32_t flags_buffer_size = 2 * world_size * kMaxTiles * sizeof(int); + static constexpr int64_t data_buffer_size = std::max( + 2 * kMaxProblemSize, static_cast(kOneShotAllreduceMaxSize)); + int64_t total_buffer_size = flags_buffer_size + data_buffer_size; data_offset = flags_buffer_size; HIP_CHECK(hipExtMallocWithFlags((void**)&dbuffer, total_buffer_size, hipDeviceMallocUncached)); @@ -194,26 +194,27 @@ struct DeviceComms { } template - void allreduce(T const* A, T* B, int N, int quant_level, hipStream_t stream) { + void allreduce(T const* A, T* B, uint32_t N, int quant_level, + hipStream_t stream) { if (world_size != 2 && world_size != 4 && world_size != 8) { throw std::runtime_error("All Reduce not supported for world_size = " + std::to_string(world_size)); } // Configuration. 
- long msg_size = N * sizeof(T); + uint32_t msg_size = N * sizeof(T); bool use_one_shot_allreduce = (world_size == 2 and N <= kOneShotAllreduceMaxElemsWorldSize2) or (world_size == 4 and N <= kOneShotAllreduceMaxElemsWorldSize4) or (world_size == 8 and N <= kOneShotAllreduceMaxElemsWorldSize8); if (use_one_shot_allreduce) { // Each thread processes blocks out of 4 elements - unsigned long num_blocks = divceil(N, (4 * kThreadsOneShot)); - unsigned long grid = min(kMaxNumBlocks, num_blocks); + uint64_t num_blocks = divceil(N, (4 * kThreadsOneShot)); + uint64_t grid = min(kMaxNumBlocks, num_blocks); ONESHOT_DISPATCH() } else { - unsigned long num_blocks = divceil(msg_size, kTileSize); - unsigned long grid = min(kMaxNumBlocks, num_blocks); + uint64_t num_blocks = divceil(msg_size, kTileSize); + uint64_t grid = min(kMaxNumBlocks, num_blocks); auto quant_level_ = static_cast(quant_level); switch (quant_level_) { case QuickReduceQuantLevel::INT8: diff --git a/csrc/quickreduce/quick_reduce_impl.cuh b/csrc/quickreduce/quick_reduce_impl.cuh index 92722c56ddfd..813f9bd620fa 100644 --- a/csrc/quickreduce/quick_reduce_impl.cuh +++ b/csrc/quickreduce/quick_reduce_impl.cuh @@ -374,11 +374,11 @@ struct AllReduceOneshot { __device__ static void run( T const* __restrict__ A, // input T* __restrict__ B, // output - int const N, // number of elements - int const rank, // rank index + uint32_t const N, // number of elements + uint32_t const rank, // rank index uint8_t** __restrict__ buffer_list, // communication buffers long const data_offset, // offset to start of the data buffer - int flag_color) { + uint32_t flag_color) { BufferResource src_buffer(const_cast(A), N * sizeof(T)); BufferResource dst_buffer(B, N * sizeof(T)); @@ -387,7 +387,7 @@ struct AllReduceOneshot { const int block_size = blockDim.x; const int thread = threadIdx.x; const int block = blockIdx.x; - const int problem_size = (N + 3) / 4; + const uint32_t problem_size = (N + 3) / 4; int32x4_t tA, tB; long grid = 
gridDim.x; @@ -479,13 +479,13 @@ struct AllReduceTwoshot { __device__ static void run( T const* __restrict__ input, T* __restrict__ output, - int const N, // number of elements + uint32_t const N, // number of elements int const block, // block index int const num_blocks, // number of blocks int const rank, // rank index uint8_t** __restrict__ buffer_list, // communication buffers - long const data_offset, // offset to start of the data buffer - int flag_color) { + uint32_t const data_offset, // offset to start of the data buffer + uint32_t flag_color) { // Topology int thread = threadIdx.x + threadIdx.y * kWavefront; uint8_t* rank_buffer = buffer_list[rank]; @@ -496,7 +496,7 @@ struct AllReduceTwoshot { int32x4_t tA[kAtoms]; BufferResource src_buffer(const_cast(input), N * sizeof(T)); - int src_offset = block * kTileSize + thread * sizeof(int32x4_t); + uint32_t src_offset = block * kTileSize + thread * sizeof(int32x4_t); int32x4_t* src = reinterpret_cast(const_cast(input)); for (int i = 0; i < kAtoms; i++) { @@ -507,13 +507,14 @@ struct AllReduceTwoshot { // -------------------------------------------------------- // Phase-1A: Write segment data into the communication buffer of the target // rank responsible for this segment. 
- long comm_data0_offset = data_offset + block * Codec::kTransmittedTileSize; - long comm_data1_offset = + uint32_t comm_data0_offset = + data_offset + block * Codec::kTransmittedTileSize; + uint32_t comm_data1_offset = num_blocks * Codec::kTransmittedTileSize + comm_data0_offset; - long comm_flags0_offset = block * (kWorldSize * sizeof(int)); - long comm_flags1_offset = - num_blocks * (kWorldSize * sizeof(int)) + comm_flags0_offset; + uint32_t comm_flags0_offset = block * (kWorldSize * sizeof(uint32_t)); + uint32_t comm_flags1_offset = + num_blocks * (kWorldSize * sizeof(uint32_t)) + comm_flags0_offset; for (int r = 0; r < kWorldSize; r++) { int32x4_t* send_buffer = @@ -525,8 +526,8 @@ struct AllReduceTwoshot { __syncthreads(); if (thread < kWorldSize) { int r = thread; - int* flag_ptr = reinterpret_cast( - buffer_list[r] + comm_flags0_offset + rank * sizeof(int)); + uint32_t* flag_ptr = reinterpret_cast( + buffer_list[r] + comm_flags0_offset + rank * sizeof(uint32_t)); set_sync_flag(flag_ptr, flag_color); } // -------------------------------------------------------- @@ -536,7 +537,8 @@ struct AllReduceTwoshot { // Read the data from the communication buffer. int32x4_t* recv_buffer = reinterpret_cast(rank_buffer + comm_data0_offset); - int* flag_ptr = reinterpret_cast(rank_buffer + comm_flags0_offset); + uint32_t* flag_ptr = + reinterpret_cast(rank_buffer + comm_flags0_offset); for (int r = 0; r < kWorldSize; r++) { // Wait for the flags to be set. @@ -565,8 +567,8 @@ struct AllReduceTwoshot { __syncthreads(); if (thread < kWorldSize) { int r = thread; - int* flag_ptr = reinterpret_cast( - buffer_list[r] + comm_flags1_offset + rank * sizeof(int)); + uint32_t* flag_ptr = reinterpret_cast( + buffer_list[r] + comm_flags1_offset + rank * sizeof(uint32_t)); set_sync_flag(flag_ptr, flag_color); } @@ -575,7 +577,8 @@ struct AllReduceTwoshot { // Read the data from the communication buffer. 
int32x4_t* recv_buffer = reinterpret_cast(rank_buffer + comm_data1_offset); - int* flag_ptr = reinterpret_cast(rank_buffer + comm_flags1_offset); + uint32_t* flag_ptr = + reinterpret_cast(rank_buffer + comm_flags1_offset); for (int r = 0; r < kWorldSize; r++) { // Wait for the flags to be set. @@ -592,7 +595,7 @@ struct AllReduceTwoshot { // -------------------------------------------------------- // Write the result to output. BufferResource dst_buffer(output, N * sizeof(T)); - int dst_offset = block * kTileSize + thread * sizeof(int32x4_t); + uint32_t dst_offset = block * kTileSize + thread * sizeof(int32x4_t); int32x4_t* dst = reinterpret_cast(output); for (int i = 0; i < kAtoms; i++) { diff --git a/vllm/distributed/device_communicators/quick_all_reduce.py b/vllm/distributed/device_communicators/quick_all_reduce.py index 69cea1c4843d..e6b5debc0184 100644 --- a/vllm/distributed/device_communicators/quick_all_reduce.py +++ b/vllm/distributed/device_communicators/quick_all_reduce.py @@ -21,10 +21,11 @@ ops_available = False -class QuickReduceQuantLevel(Enum): +class QuickReduceRegime(Enum): FP = 0 INT8 = 1 INT4 = 2 + NONE = 3 class QuickAllReduce: @@ -43,12 +44,16 @@ def __init__(self, group: ProcessGroup, self.max_size = ops.qr_max_size() self.group = group - quant_level_str = envs.VLLM_ROCM_CA_QUANT_LEVEL - assert quant_level_str in QuickReduceQuantLevel.__members__, ( - f"Invalid quantization level: {quant_level_str}. " + regime_str = envs.VLLM_ROCM_CA_QUANT_REGIME + assert regime_str in QuickReduceRegime.__members__, ( + f"Invalid quantization level: {regime_str}. 
" "Supported levels: " - f"{list(QuickReduceQuantLevel.__members__.keys())}") - self.quant_level = QuickReduceQuantLevel[quant_level_str] + f"{list(QuickReduceRegime.__members__.keys())}") + if regime_str == "NONE": + logger.debug("Custom quickreduce is disabled based on " + "env variable VLLM_ROCM_CA_QUANT_REGIME") + return + self.quant_level = QuickReduceRegime[regime_str] # On RocM bfloat16 kernels are slower than fp16 # due to slower match operations # If environment is not set to 1 we convert input to fp16 diff --git a/vllm/envs.py b/vllm/envs.py index b6195a1f649e..1c5c842ed49b 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -129,7 +129,7 @@ VLLM_SLEEP_WHEN_IDLE: bool = False VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16 VLLM_KV_CACHE_LAYOUT: Optional[str] = None - VLLM_ROCM_CA_QUANT_LEVEL: str = "FP" + VLLM_ROCM_CA_QUANT_REGIME: str = "FP" VLLM_ROCM_CA_CAST_BF16_TO_FP16: bool = True @@ -674,10 +674,10 @@ def get_vllm_port() -> Optional[int]: ("true", "1")), # Custom quick allreduce kernel for MI3* cards. 
- # Choice of quantization level: FP16, INT8, INT4 + # Choice of quantization level: FP, INT8, INT4 or NONE # Recommended for large models to get allreduce - "VLLM_ROCM_CA_QUANT_LEVEL": - lambda: os.getenv("VLLM_ROCM_CA_QUANT_LEVEL", "FP").upper(), + "VLLM_ROCM_CA_QUANT_REGIME": + lambda: os.getenv("VLLM_ROCM_CA_QUANT_REGIME", "FP").upper(), # Custom quick allreduce kernel for MI3* cards # Due to the lack of the bfloat16 asm instruction, bfloat16 From c80ab57e20b4bb6d45a6bfee05fca80603862e9f Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 13 Jun 2025 16:07:13 +0000 Subject: [PATCH 15/28] add Q6 support Signed-off-by: Haoyang Li --- csrc/quickreduce/base.h | 252 +++++---- csrc/quickreduce/quick_reduce.h | 8 +- csrc/quickreduce/quick_reduce_impl.cuh | 511 ++++++++++++------ .../device_communicators/quick_all_reduce.py | 9 +- vllm/envs.py | 4 +- 5 files changed, 494 insertions(+), 290 deletions(-) diff --git a/csrc/quickreduce/base.h b/csrc/quickreduce/base.h index 7abbc18aea73..eeeae078a404 100644 --- a/csrc/quickreduce/base.h +++ b/csrc/quickreduce/base.h @@ -104,6 +104,10 @@ __quickreduce_device_inline__ static void set_fp16_ovfl(bool const value) { } #endif } +union bf162_int_union { + int i; + nv_bfloat162 bf2; +}; template __quickreduce_device_inline__ void packed_assign_add(int32x4_t* A, @@ -152,10 +156,11 @@ __quickreduce_device_inline__ int packed_max(int a, int b) { template <> __quickreduce_device_inline__ int packed_max(int a, int b) { - nv_bfloat162* tA = reinterpret_cast(&a); - nv_bfloat162* tB = reinterpret_cast(&b); - nv_bfloat162 tR = __hmax2(*tA, *tB); - return *(reinterpret_cast(&tR)); + bf162_int_union A, B, R; + A.i = a; + B.i = b; + R.bf2 = __hmax2(A.bf2, B.bf2); + return R.i; } template @@ -170,10 +175,11 @@ __quickreduce_device_inline__ int packed_min(int a, int b) { template <> __quickreduce_device_inline__ int packed_min(int a, int b) { - nv_bfloat162* tA = reinterpret_cast(&a); - nv_bfloat162* tB = reinterpret_cast(&b); - nv_bfloat162 
tR = __hmin2(*tA, *tB); - return *(reinterpret_cast(&tR)); + bf162_int_union A, B, R; + A.i = a; + B.i = b; + R.bf2 = __hmin2(A.bf2, B.bf2); + return R.i; } template @@ -194,15 +200,12 @@ __quickreduce_device_inline__ int packed_abs_max(int a, int b) { template <> __quickreduce_device_inline__ int packed_abs_max(int a, int b) { - nv_bfloat162 wmaxh2 = *(reinterpret_cast(&a)); - nv_bfloat162 wminh2 = *(reinterpret_cast(&b)); - nv_bfloat162 wblockmaxh2; - wblockmaxh2.x = - __hgt(__habs(wmaxh2.x), __habs(wminh2.x)) ? wmaxh2.x : wminh2.x; - wblockmaxh2.y = - __hgt(__habs(wmaxh2.y), __habs(wminh2.y)) ? wmaxh2.y : wminh2.y; - - return *(reinterpret_cast(&wblockmaxh2)); + bf162_int_union A, B, R; + A.i = a; + B.i = b; + R.bf2.x = __hgt(__habs(A.bf2.x), __habs(B.bf2.x)) ? A.bf2.x : B.bf2.x; + R.bf2.y = __hgt(__habs(A.bf2.y), __habs(B.bf2.y)) ? A.bf2.y : B.bf2.y; + return R.i; } template @@ -217,10 +220,11 @@ __quickreduce_device_inline__ int packed_add(int a, int b) { template <> __quickreduce_device_inline__ int packed_add(int a, int b) { - nv_bfloat162* tA = reinterpret_cast(&a); - nv_bfloat162* tB = reinterpret_cast(&b); - nv_bfloat162 tR = __hadd2(*tA, *tB); - return *(reinterpret_cast(&tR)); + bf162_int_union A, B, R; + A.i = a; + B.i = b; + R.bf2 = __hadd2(A.bf2, B.bf2); + return R.i; } template <> @@ -246,10 +250,11 @@ __quickreduce_device_inline__ int packed_sub(int a, int b) { template <> __quickreduce_device_inline__ int packed_sub(int a, int b) { - nv_bfloat162* tA = reinterpret_cast(&a); - nv_bfloat162* tB = reinterpret_cast(&b); - nv_bfloat162 tR = __hsub2(*tA, *tB); - return *(reinterpret_cast(&tR)); + bf162_int_union A, B, R; + A.i = a; + B.i = b; + R.bf2 = __hsub2(A.bf2, B.bf2); + return R.i; } template @@ -280,77 +285,89 @@ __quickreduce_device_inline__ int packed_rcp(int a) { template <> __quickreduce_device_inline__ int packed_rcp(int a) { - nv_bfloat162* tA = reinterpret_cast(&a); - nv_bfloat162 tR = h2rcp(*tA); - return 
*(reinterpret_cast(&tR)); + bf162_int_union A, R; + A.i = a; + R.bf2 = h2rcp(A.bf2); + return R.i; } -template -__quickreduce_device_inline__ T float2T_cast(float a); - -template <> -__quickreduce_device_inline__ half float2T_cast(float a) { - return __float2half(a); +// changes dtype +__quickreduce_device_inline__ float T2float_cast(half a) { + return __half2float(a); } -template <> -__quickreduce_device_inline__ nv_bfloat16 float2T_cast(float a) { - return __float2bfloat16(a); +__quickreduce_device_inline__ float T2float_cast(nv_bfloat16 a) { + return __bfloat162float(a); } -template -__quickreduce_device_inline__ float T2float_cast(T a); +// template +// __quickreduce_device_inline__ float T2float_cast(T a); -template <> -__quickreduce_device_inline__ float T2float_cast(half a) { - return __half2float(a); -} +// template <> +// __quickreduce_device_inline__ float T2float_cast(half a) { +// return __half2float(a); +// } -template <> -__quickreduce_device_inline__ float T2float_cast(nv_bfloat16 a) { - return __bfloat162float(a); -} +// template <> +// __quickreduce_device_inline__ float T2float_cast(nv_bfloat16 a) +// { +// return __bfloat162float(a); +// } -template -__quickreduce_device_inline__ unsigned char T2uchar_cast(T a); +// template +// __quickreduce_device_inline__ T float2T_cast(float a); -template <> -__quickreduce_device_inline__ unsigned char T2uchar_cast(half a) { - return static_cast(__half2ushort_rz(a)); -} +// template <> +// __quickreduce_device_inline__ half float2T_cast(float a) { +// return __float2half(a); +// } -template <> -__quickreduce_device_inline__ unsigned char T2uchar_cast( - nv_bfloat16 a) { - return static_cast(__bfloat16_as_ushort(a)); -} +// template <> +// __quickreduce_device_inline__ nv_bfloat16 float2T_cast(float a) +// { +// return __float2bfloat16(a); +// } -template -__quickreduce_device_inline__ T uchar2T_cast(unsigned char a); +// template +// __quickreduce_device_inline__ unsigned char T2uchar_cast(T a); -template <> 
-__quickreduce_device_inline__ half uchar2T_cast(unsigned char a) { - return __ushort2half_rz(static_cast(a)); -} +// template <> +// __quickreduce_device_inline__ unsigned char T2uchar_cast(half a) { +// return static_cast(__half2ushort_rz(a)); +// } -template <> -__quickreduce_device_inline__ nv_bfloat16 -uchar2T_cast(unsigned char a) { - return __ushort_as_bfloat16(static_cast(a)); -} +// template <> +// __quickreduce_device_inline__ unsigned char T2uchar_cast( +// nv_bfloat16 a) { +// return static_cast(__bfloat16_as_ushort(a)); +// } -template -__quickreduce_device_inline__ int T2int_cast(T a); +// template +// __quickreduce_device_inline__ T uchar2T_cast(unsigned char a); -template <> -__quickreduce_device_inline__ int T2int_cast(half a) { - return __half2int_rz(a); -} +// template <> +// __quickreduce_device_inline__ half uchar2T_cast(unsigned char a) { +// return __ushort2half_rz(static_cast(a)); +// } -template <> -__quickreduce_device_inline__ int T2int_cast(nv_bfloat16 a) { - return static_cast(__bfloat16_as_ushort(a)); -} +// template <> +// __quickreduce_device_inline__ nv_bfloat16 +// uchar2T_cast(unsigned char a) { +// return __ushort_as_bfloat16(static_cast(a)); +// } + +// template +// __quickreduce_device_inline__ int T2int_cast(T a); + +// template <> +// __quickreduce_device_inline__ int T2int_cast(half a) { +// return __half2int_rz(a); +// } + +// template <> +// __quickreduce_device_inline__ int T2int_cast(nv_bfloat16 a) { +// return static_cast(__bfloat16_as_ushort(a)); +// } template __quickreduce_device_inline__ int group_abs_max(int32x4_t atom) { @@ -384,44 +401,45 @@ __quickreduce_device_inline__ int group_abs_max(int32x4_t atom) { return wblockmax; } -template -__quickreduce_device_inline__ void group_max_min(int32x4_t atom, int& wblockmax, - int& wblockmin, - int valid_data) { - const int group_leader = (threadIdx.x / kThreadGroupSize) * kThreadGroupSize; - static constexpr int FP_MAX = - std::is_same::value ? 
0x7BFF7BFF : 0x7F7F7F7F; - static constexpr int FP_MIN = - std::is_same::value ? 0xFBFFFBFF : 0xFF7FFF7F; - - int wmax, wmin; - int a, b; - a = packed_max(atom[0], atom[1]); - b = packed_max(atom[2], atom[3]); - // In case the data was loaded out of range (and initialized to 0) - // we set max min values to sentinel values - // so that they do not spoil the group max min values - wmax = valid_data * packed_max(a, b) + (!valid_data) * FP_MIN; - - a = packed_min(atom[0], atom[1]); - b = packed_min(atom[2], atom[3]); - wmin = valid_data * packed_min(a, b) + (!valid_data) * FP_MAX; - - // Reduce the max and min among a group of threads - // Note: This is basically 2 blocks of values setup as the - // upper/lower halves of the f16x2_t - for (int i = 1; i < kThreadGroupSize; i <<= 1) { - int x = __shfl_down(wmax, i); - wmax = packed_max(wmax, x); - - int y = __shfl_down(wmin, i); - wmin = packed_min(wmin, y); - } - - // Share with the cohort - wblockmax = __shfl(wmax, group_leader); - wblockmin = __shfl(wmin, group_leader); -} +// template +// __quickreduce_device_inline__ void group_max_min(int32x4_t atom, int& +// wblockmax, +// int& wblockmin, +// int valid_data) { +// const int group_leader = (threadIdx.x / kThreadGroupSize) * +// kThreadGroupSize; static constexpr int FP_MAX = +// std::is_same::value ? 0x7BFF7BFF : 0x7F7F7F7F; +// static constexpr int FP_MIN = +// std::is_same::value ? 
0xFBFFFBFF : 0xFF7FFF7F; + +// int wmax, wmin; +// int a, b; +// a = packed_max(atom[0], atom[1]); +// b = packed_max(atom[2], atom[3]); +// // In case the data was loaded out of range (and initialized to 0) +// // we set max min values to sentinel values +// // so that they do not spoil the group max min values +// wmax = valid_data * packed_max(a, b) + (!valid_data) * FP_MIN; + +// a = packed_min(atom[0], atom[1]); +// b = packed_min(atom[2], atom[3]); +// wmin = valid_data * packed_min(a, b) + (!valid_data) * FP_MAX; + +// // Reduce the max and min among a group of threads +// // Note: This is basically 2 blocks of values setup as the +// // upper/lower halves of the f16x2_t +// for (int i = 1; i < kThreadGroupSize; i <<= 1) { +// int x = __shfl_down(wmax, i); +// wmax = packed_max(wmax, x); + +// int y = __shfl_down(wmin, i); +// wmin = packed_min(wmin, y); +// } + +// // Share with the cohort +// wblockmax = __shfl(wmax, group_leader); +// wblockmin = __shfl(wmin, group_leader); +// } __quickreduce_device_inline__ void set_sync_flag(uint32_t* flag_ptr, uint32_t flag) { diff --git a/csrc/quickreduce/quick_reduce.h b/csrc/quickreduce/quick_reduce.h index 2cd5ecf25657..b5c21b40ddb8 100644 --- a/csrc/quickreduce/quick_reduce.h +++ b/csrc/quickreduce/quick_reduce.h @@ -94,8 +94,9 @@ allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, int num_blocks, enum QuickReduceQuantLevel { FP16 = 0, - INT8, - INT4, + INT8 = 1, + int6 = 2, + INT4 = 3, }; struct DeviceComms { @@ -220,6 +221,9 @@ struct DeviceComms { case QuickReduceQuantLevel::INT8: TWOSHOT_DISPATCH(CodecQ8) break; + case QuickReduceQuantLevel::INT6: + TWOSHOT_DISPATCH(CodecQ6) + break; case QuickReduceQuantLevel::INT4: TWOSHOT_DISPATCH(CodecQ4) break; diff --git a/csrc/quickreduce/quick_reduce_impl.cuh b/csrc/quickreduce/quick_reduce_impl.cuh index 813f9bd620fa..0228806cf5bb 100644 --- a/csrc/quickreduce/quick_reduce_impl.cuh +++ b/csrc/quickreduce/quick_reduce_impl.cuh @@ -52,59 +52,62 @@ struct CodecFP 
: public CodecBase { } }; +// MARK: Q4 Line Codec // Int4 symmetric quantization codec. // We quantize the FP16 data to block-scaled Int4 in blocks of 4 * // kThreadGroupSize. template -struct CodecQ8 : public CodecBase { +struct CodecQ4 : public CodecBase { static constexpr int kWorldSize = world_size; // Codec tile size process by this workgroup. - // Each threads processes a fragment of f16x8_t (16B), - // into a int8x8_t (8B) and a f16 scale shared among 32 values. - static int constexpr kRankAtoms = kAtoms / kWorldSize; - static int constexpr kRankTileStride = 2176; - static int constexpr kRankTileScaleOffset = 2048; - static int constexpr kRankTransmittedTileSize = kRankTileStride * kRankAtoms; + // Each threads processes a fragment of fp16x8_t (16B), + // into a int4x8_t (4B) and a fp16 scale shared among 32 values. + static constexpr int kRankAtoms = kAtoms / kWorldSize; + static constexpr int kRankTileStride = 1152; + static constexpr int kRankTileScaleOffset = 1024; + static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms; static_assert(kRankTransmittedTileSize % 16 == 0, - "kRankTileSize must be 16B aligned."); + "kRankTransmittedTileSize must be 16B aligned."); - static int constexpr kRankBufferTileStride = + static constexpr int kRankBufferTileStride = kRankTileStride / sizeof(int32x4_t); // Total tile size for the collective communication. - static int constexpr kTransmittedTileSize = + static constexpr int kTransmittedTileSize = kRankTransmittedTileSize * kWorldSize; // Constants configuration - // {-1/128.0h, -1/128.0h}, f16x2_t - static int constexpr kScaleFactor = - std::is_same::value ? 0xA000A000 : 0xBC00BC00; + // {-1/8.0h, -1/8.0h}, f16x2_t + static constexpr int kScaleFactor = + std::is_same::value ? 0xB000B000 : 0xBE00BE00; // {1e-7, 1e-7}, f16x2_t - static int constexpr kScaleEpsilon = + static constexpr int kScaleEpsilon = std::is_same::value ? 
0x00010001 : 0x33D733D7; - // {-128, -128}, f16x2_t - static int constexpr kRangeMin = - std::is_same::value ? 0xD800D800 : 0xC300C300; - // {+127, +127}, f16x2_t - static int constexpr kRangeMax = - std::is_same::value ? 0x57F057F0 : 0x42FE42FE; + // {-8, -8}, f16x2_t + static constexpr int kRangeMin = + std::is_same::value ? 0xC800C800 : 0xC100C100; - // {+128, +128}, int16x2_t - static int constexpr kRangeBias = 0x00800080; + // {+7, +7}, f16x2_t + static constexpr int kRangeMax = + std::is_same::value ? 0x47004700 : 0x40E040E0; - __quickreduce_device_inline__ CodecQ8(int thread, int rank) + // {+8, +8}, int16x2_t + static constexpr int kRangeBias = 0x00080008; + + __quickreduce_device_inline__ CodecQ4(int thread, int rank) : CodecBase(thread, rank) { set_fp16_ovfl(true); } __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, - int32x4_t const* __restrict__ data) { + const int32x4_t* __restrict__ data) { for (int k = 0; k < kRankAtoms; k++) { int32x4_t const atom = data[k]; + // Compute the absolute maximum of the atom in the thread group // In 2 blocks of values, upper/lower halves of the f16x2_t int wblockmax = group_abs_max(atom); @@ -129,24 +132,21 @@ struct CodecQ8 : public CodecBase { { int16_t* qi = reinterpret_cast(&q); T* wh = reinterpret_cast(&w); - for (int i = 0; i < 8; i++) - qi[i] = (int16_t)rintf(T2float_cast(wh[i])); + for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(T2float_cast(wh[i])); for (int i = 0; i < 4; i++) { q[i] = packed_add(q[i], kRangeBias); } } - // Pack 8 x q8 into int32x2_t - int32x2_t qw; - qw[0] = q[0] | (q[1] << 8); - qw[1] = q[2] | (q[3] << 8); + // Pack 8 x q4 into int32_t + int qw = q[0] | (q[1] << 4) | (q[2] << 8) | (q[3] << 12); // Write quantized atom to send_buffer // note: only the group leader stores the scale uint8_t* atom_ptr = reinterpret_cast(send_buffer + k * kRankBufferTileStride); - int32x2_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int32_t* qw_ptr = reinterpret_cast(atom_ptr) + 
thread; int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + (thread / 8); @@ -162,34 +162,30 @@ struct CodecQ8 : public CodecBase { for (int k = 0; k < kRankAtoms; k++) { // Directly read quantized atom from recv_buffer uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); - int32x2_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int32_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + (thread / 8); - int32x2_t qw = __builtin_nontemporal_load(qw_ptr); + int32_t qw = __builtin_nontemporal_load(qw_ptr); int qs = __builtin_nontemporal_load(qs_ptr); *recv_buffer += kRankBufferTileStride; - // Unpack q8 into fp16x8_t + // Unpack q4 into f16x8_t int32x4_t w; { - static uint constexpr kMask00FF = 0x00FF00FF; - - // {1024.0, 1024.0}, fp16x2_t - static uint constexpr kHalf2_1024 = 0x64006400; - - // {-1152.0, -1152.0}, fp16x2_t - static uint constexpr kHalf2_1152 = 0xE480E480; + static constexpr uint kMask000F = 0x000F000F; + static constexpr uint kHalf2_1024 = + 0x64006400; // {1024.0, 1024.0}, fp16x2_t + static uint constexpr kHalf2_1032 = + 0xE408E408; // {-1032.0, -1032.0}, fp16x2_t -#pragma unroll for (int i = 0; i < 4; i++) { if constexpr (std::is_same::value) { - int32_t q8 = - ((qw[i / 2] >> ((i % 2) * 8)) & kMask00FF) | kHalf2_1024; - w[i] = packed_add(q8, kHalf2_1152); + int32_t q4 = ((qw >> (i * 4)) & kMask000F) | kHalf2_1024; + packed_add(w[i], kHalf2_1032); } else { - int32_t int16_2 = (qw[i / 2] >> ((i % 2) * 8)) & kMask00FF; + int32_t int16_2 = (qw >> (i * 4)) & kMask000F; int16_t low = static_cast(int16_2 & 0xFFFF); int16_t high = static_cast((int16_2 >> 16) & 0xFFFF); nv_bfloat16 bf_low = __float2bfloat16(static_cast(low)); @@ -211,19 +207,21 @@ struct CodecQ8 : public CodecBase { } }; -// Int4 symmetric quantization codec. -// We quantize the FP16 data to block-scaled Int4 in blocks of 4 * +// MARK: Q6 Line Codec +// Int6 symmetric quantization codec. 
+// We quantize the FP16 data to block-scaled Int6 in blocks of 4 * // kThreadGroupSize. template -struct CodecQ4 : public CodecBase { +struct CodecQ6 : public CodecBase { static constexpr int kWorldSize = world_size; // Codec tile size process by this workgroup. // Each threads processes a fragment of fp16x8_t (16B), - // into a int4x8_t (4B) and a fp16 scale shared among 32 values. + // into a int6x8_t (4B + 2B) and a fp16 scale shared among 32 values. static constexpr int kRankAtoms = kAtoms / kWorldSize; - static constexpr int kRankTileStride = 1152; - static constexpr int kRankTileScaleOffset = 1024; + static int constexpr kRankTileStride = 1664; + static int constexpr kRankTileQ2Offset = 1024; + static int constexpr kRankTileScaleOffset = 1536; static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms; static_assert(kRankTransmittedTileSize % 16 == 0, "kRankTransmittedTileSize must be 16B aligned."); @@ -237,29 +235,27 @@ struct CodecQ4 : public CodecBase { // Constants configuration - // {-1/8.0h, -1/8.0h}, f16x2_t + // {-1/32.0h, -1/32.0h}, fp16x2_t static constexpr int kScaleFactor = - std::is_same::value ? 0xB000B000 : 0xBE00BE00; + std::is_same::value ? 0xA800A800 : 0xBD00BD00; - // {1e-7, 1e-7}, f16x2_t + // {1e-7, 1e-7}, fp16x2_t static constexpr int kScaleEpsilon = std::is_same::value ? 0x00010001 : 0x33D733D7; - // {-8, -8}, f16x2_t + // {-32, -32}, fp16x2_t static constexpr int kRangeMin = - std::is_same::value ? 0xC800C800 : 0xC100C100; + std::is_same::value ? 0xD000D000 : 0xC200C200; - // {+7, +7}, f16x2_t + // {+31, +31}, fp16x2_t static constexpr int kRangeMax = - std::is_same::value ? 0x47004700 : 0x40E040E0; + std::is_same::value ? 
0x4FC04FC0 : 0x41F841F8; - // {+8, +8}, int16x2_t - static constexpr int kRangeBias = 0x00080008; + // {+32, +32}, int16x2_t + static int constexpr kRangeBias = 0x00200020; - __quickreduce_device_inline__ CodecQ4(int thread, int rank) - : CodecBase(thread, rank) { - set_fp16_ovfl(true); - } + __quickreduce_device_inline__ CodecQ6(int thread, int rank) + : CodecBase(thread, rank) {} __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, const int32x4_t* __restrict__ data) { @@ -290,26 +286,37 @@ struct CodecQ4 : public CodecBase { { int16_t* qi = reinterpret_cast(&q); T* wh = reinterpret_cast(&w); - for (int i = 0; i < 8; i++) - qi[i] = (int16_t)rintf(T2float_cast(wh[i])); + for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(T2float_cast(wh[i])); for (int i = 0; i < 4; i++) { q[i] = packed_add(q[i], kRangeBias); } } - // Pack 8 x q4 into int32_t - int qw = q[0] | (q[1] << 4) | (q[2] << 8) | (q[3] << 12); - + // Pack 8 x q6 into int32_t + int16_t + uint32_t q4w; + uint16_t q2w = 0; + q4w = (q[0] & 0x000F000F) | ((q[1] & 0x000F000F) << 4) | + ((q[2] & 0x000F000F) << 8) | ((q[3] & 0x000F000F) << 12); + { + int16_t* tw = reinterpret_cast(&q); +#pragma unroll + for (int i = 0; i < 8; i++) { + q2w |= (tw[i] >> 4) << (i * 2); + } + } // Write quantized atom to send_buffer // note: only the group leader stores the scale uint8_t* atom_ptr = reinterpret_cast(send_buffer + k * kRankBufferTileStride); - int32_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + uint32_t* q4w_ptr = reinterpret_cast(atom_ptr) + thread; + uint16_t* q2w_ptr = + reinterpret_cast(atom_ptr + kRankTileQ2Offset) + thread; int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + (thread / 8); - __builtin_nontemporal_store(qw, qw_ptr); + __builtin_nontemporal_store(q4w, q4w_ptr); + __builtin_nontemporal_store(q2w, q2w_ptr); if (threadIdx.x == group_leader) { __builtin_nontemporal_store(decoding_scale, qs_ptr); } @@ -321,155 +328,224 @@ struct CodecQ4 : public CodecBase { for 
(int k = 0; k < kRankAtoms; k++) { // Directly read quantized atom from recv_buffer uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); - int32_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + uint32_t* q4w_ptr = reinterpret_cast(atom_ptr) + thread; + uint16_t* q2w_ptr = + reinterpret_cast(atom_ptr + kRankTileQ2Offset) + thread; int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + (thread / 8); - int32_t qw = __builtin_nontemporal_load(qw_ptr); + uint32_t q4w = __builtin_nontemporal_load(q4w_ptr); + uint16_t q2w = __builtin_nontemporal_load(q2w_ptr); int qs = __builtin_nontemporal_load(qs_ptr); *recv_buffer += kRankBufferTileStride; - // Unpack q4 into f16x8_t + // Unpack q6 into fp16x8_t int32x4_t w; { - static constexpr uint kMask000F = 0x000F000F; - static constexpr uint kHalf2_1024 = + static uint constexpr kMask000F = 0x000F000F; + static uint constexpr kMask00FF = 0x00FF00FF; + static uint constexpr kHalf2_1024 = 0x64006400; // {1024.0, 1024.0}, fp16x2_t - static uint constexpr kHalf2_1032 = - 0xE408E408; // {-1032.0, -1032.0}, fp16x2_t + static uint constexpr kHalf2_1056 = + 0xE420E420; // {-1056.0, -1056.0}, fp16x2_t +#pragma unroll for (int i = 0; i < 4; i++) { + int32_t q4 = q4w & kMask000F; + int32_t q2 = (q2w & 0x3) | ((q2w & 0xC) << 14); + q4w >>= 4; + q2w >>= 4; if constexpr (std::is_same::value) { - int32_t q4 = ((qw >> (i * 4)) & kMask000F) | kHalf2_1024; - packed_add(w[i], kHalf2_1032); + int32_t q6 = q4 | (q2 << 4) | kHalf2_1024; + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(w[i]) + : "v"(q6), "v"(kHalf2_1056)); } else { - int32_t int16_2 = (qw >> (i * 4)) & kMask000F; + int32_t int16_2 = q4 | (q2 << 4); int16_t low = static_cast(int16_2 & 0xFFFF); int16_t high = static_cast((int16_2 >> 16) & 0xFFFF); + nv_bfloat16 bf_low = __float2bfloat16(static_cast(low)); nv_bfloat16 bf_high = __float2bfloat16(static_cast(high)); nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high); int32_t packed_bf16 = *reinterpret_cast(&bf2); - w[i] = 
packed_add(packed_bf16, kRangeMin); + w[i] = pk_add(packed_bf16, kRangeMin); } } } // Apply decoding scales for (int i = 0; i < 4; i++) { - w[i] = packed_mul(w[i], qs); + w[i] = pk_mul(w[i], qs); } + // That's pretty much it... data[k] = w; } } }; -// Oneshot AllReduce +// MARK: Q8 Line Codec +// Int8 symmetric quantization codec. +// We quantize the FP16 data to block-scaled Int8 in blocks of 4 * +// kThreadGroupSize. template -struct AllReduceOneshot { - static_assert(sizeof(T) == 2); +struct CodecQ8 : public CodecBase { + static constexpr int kWorldSize = world_size; - __device__ static void run( - T const* __restrict__ A, // input - T* __restrict__ B, // output - uint32_t const N, // number of elements - uint32_t const rank, // rank index - uint8_t** __restrict__ buffer_list, // communication buffers - long const data_offset, // offset to start of the data buffer - uint32_t flag_color) { - BufferResource src_buffer(const_cast(A), N * sizeof(T)); - BufferResource dst_buffer(B, N * sizeof(T)); + // Codec tile size process by this workgroup. + // Each threads processes a fragment of f16x8_t (16B), + // into a int8x8_t (8B) and a f16 scale shared among 32 values. + static int constexpr kRankAtoms = kAtoms / kWorldSize; + static int constexpr kRankTileStride = 2176; + static int constexpr kRankTileScaleOffset = 2048; + static int constexpr kRankTransmittedTileSize = kRankTileStride * kRankAtoms; + static_assert(kRankTransmittedTileSize % 16 == 0, + "kRankTileSize must be 16B aligned."); - uint8_t* rank_buffer = buffer_list[rank]; + static int constexpr kRankBufferTileStride = + kRankTileStride / sizeof(int32x4_t); - const int block_size = blockDim.x; - const int thread = threadIdx.x; - const int block = blockIdx.x; - const uint32_t problem_size = (N + 3) / 4; + // Total tile size for the collective communication. 
+ static int constexpr kTransmittedTileSize = + kRankTransmittedTileSize * kWorldSize; - int32x4_t tA, tB; - long grid = gridDim.x; - long data_stride = grid * block_size * sizeof(int32x4_t); - long comm_flags0_offset = block * (world_size * sizeof(int)); - long comm_flags1_offset = - comm_flags0_offset + grid * (world_size * sizeof(int)); + // Constants configuration - for (int idx = block * block_size + thread; idx < problem_size; - idx += grid * block_size) { - // load values - tA = buffer_load_dwordx4(src_buffer.descriptor, idx * sizeof(int32x4_t), - 0, 0); + // {-1/128.0h, -1/128.0h}, f16x2_t + static int constexpr kScaleFactor = + std::is_same::value ? 0xA000A000 : 0xBC00BC00; - // Write rank data into this rank segment of every rank's communication - // buffer. -#pragma unroll - for (int r = 0; r < world_size; r++) { - int32x4_t* send_buffer = reinterpret_cast( - buffer_list[r] + data_offset + rank * data_stride + - idx * sizeof(int32x4_t)); - __builtin_nontemporal_store(tA, send_buffer); - } - } + // {1e-7, 1e-7}, f16x2_t + static int constexpr kScaleEpsilon = + std::is_same::value ? 0x00010001 : 0x33D733D7; - __syncthreads(); - if (thread < world_size) { - int r = thread; - int* peer_flag_ptr = reinterpret_cast( - buffer_list[r] + comm_flags0_offset + rank * sizeof(int)); - __atomic_store_n(peer_flag_ptr, flag_color, __ATOMIC_RELEASE); - int* self_flag_ptr = reinterpret_cast( - rank_buffer + comm_flags0_offset + r * sizeof(int)); + // {-128, -128}, f16x2_t + static int constexpr kRangeMin = + std::is_same::value ? 0xD800D800 : 0xC300C300; + // {+127, +127}, f16x2_t + static int constexpr kRangeMax = + std::is_same::value ? 0x57F057F0 : 0x42FE42FE; - // Wait for the flags to be set. 
- while (__atomic_load_n(self_flag_ptr, __ATOMIC_ACQUIRE) != flag_color) { + // {+128, +128}, int16x2_t + static int constexpr kRangeBias = 0x00800080; + + __quickreduce_device_inline__ CodecQ8(int thread, int rank) + : CodecBase(thread, rank) { + set_fp16_ovfl(true); + } + + __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, + int32x4_t const* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + int32x4_t const atom = data[k]; + // Compute the absolute maximum of the atom in the thread group + // In 2 blocks of values, upper/lower halves of the f16x2_t + int wblockmax = group_abs_max(atom); + + // Derive scales + int decoding_scale; + int encoding_scale; + decoding_scale = packed_mul(wblockmax, kScaleFactor); + encoding_scale = packed_add(decoding_scale, kScaleEpsilon); + encoding_scale = packed_rcp(encoding_scale); + + // Apply scales to get quantized values + int32x4_t w; + for (int i = 0; i < 4; i++) { + w[i] = packed_mul(atom[i], encoding_scale); + w[i] = packed_max(w[i], kRangeMin); + w[i] = packed_min(w[i], kRangeMax); } - } - __syncthreads(); - for (int idx = block * block_size + thread; idx < problem_size; - idx += grid * block_size) { + // Convert from f16x2_t to uint16x2_t + int32x4_t q; { - int r = 0; - // Read posted data from the rank's communication buffer. - int32x4_t* recv_buffer = reinterpret_cast( - rank_buffer + data_offset + r * data_stride + - idx * sizeof(int32x4_t)); - tA = __builtin_nontemporal_load(recv_buffer); - } -#pragma unroll - for (int r = 1; r < world_size; r++) { - // Read posted data from the rank's communication buffer. 
- int32x4_t* recv_buffer = reinterpret_cast( - rank_buffer + data_offset + r * data_stride + - idx * sizeof(int32x4_t)); - tB = __builtin_nontemporal_load(recv_buffer); + int16_t* qi = reinterpret_cast(&q); + T* wh = reinterpret_cast(&w); + for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(T2float_cast(wh[i])); - // Reduce the local data with the read data - packed_assign_add(&tA, &tB); + for (int i = 0; i < 4; i++) { + q[i] = packed_add(q[i], kRangeBias); + } } - buffer_store_dwordx4(tA, dst_buffer.descriptor, idx * sizeof(int32x4_t), - 0, 0); + // Pack 8 x q8 into int32x2_t + int32x2_t qw; + qw[0] = q[0] | (q[1] << 8); + qw[1] = q[2] | (q[3] << 8); + + // Write quantized atom to send_buffer + // note: only the group leader stores the scale + uint8_t* atom_ptr = + reinterpret_cast(send_buffer + k * kRankBufferTileStride); + int32x2_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + + __builtin_nontemporal_store(qw, qw_ptr); + if (threadIdx.x == group_leader) { + __builtin_nontemporal_store(decoding_scale, qs_ptr); + } } + } - __syncthreads(); - if (thread < world_size) { - int r = thread; - int* peer_flag_ptr = reinterpret_cast( - buffer_list[r] + comm_flags1_offset + rank * sizeof(int)); - __atomic_store_n(peer_flag_ptr, flag_color, __ATOMIC_RELAXED); - int* self_flag_ptr = reinterpret_cast( - rank_buffer + comm_flags1_offset + r * sizeof(int)); + __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, + int32x4_t* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + // Directly read quantized atom from recv_buffer + uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); + int32x2_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); - // Wait for the flags to be set. 
- while (__atomic_load_n(self_flag_ptr, __ATOMIC_RELAXED) != flag_color) { + int32x2_t qw = __builtin_nontemporal_load(qw_ptr); + int qs = __builtin_nontemporal_load(qs_ptr); + + *recv_buffer += kRankBufferTileStride; + + // Unpack q8 into fp16x8_t + int32x4_t w; + { + static uint constexpr kMask00FF = 0x00FF00FF; + + // {1024.0, 1024.0}, fp16x2_t + static uint constexpr kHalf2_1024 = 0x64006400; + + // {-1152.0, -1152.0}, fp16x2_t + static uint constexpr kHalf2_1152 = 0xE480E480; + +#pragma unroll + for (int i = 0; i < 4; i++) { + if constexpr (std::is_same::value) { + int32_t q8 = + ((qw[i / 2] >> ((i % 2) * 8)) & kMask00FF) | kHalf2_1024; + w[i] = packed_add(q8, kHalf2_1152); + } else { + int32_t int16_2 = (qw[i / 2] >> ((i % 2) * 8)) & kMask00FF; + int16_t low = static_cast(int16_2 & 0xFFFF); + int16_t high = static_cast((int16_2 >> 16) & 0xFFFF); + nv_bfloat16 bf_low = __float2bfloat16(static_cast(low)); + nv_bfloat16 bf_high = __float2bfloat16(static_cast(high)); + nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high); + int32_t packed_bf16 = *reinterpret_cast(&bf2); + w[i] = packed_add(packed_bf16, kRangeMin); + } + } } + + // Apply decoding scales + for (int i = 0; i < 4; i++) { + w[i] = packed_mul(w[i], qs); + } + + data[k] = w; } } }; +// MARK: Twoshot All Reduce // Twoshot All Reduce template struct AllReduceTwoshot { @@ -605,4 +681,109 @@ struct AllReduceTwoshot { } }; +// MARK: Oneshot All Reduce +// Oneshot AllReduce +template +struct AllReduceOneshot { + static_assert(sizeof(T) == 2); + + __device__ static void run( + T const* __restrict__ A, // input + T* __restrict__ B, // output + uint32_t const N, // number of elements + uint32_t const rank, // rank index + uint8_t** __restrict__ buffer_list, // communication buffers + long const data_offset, // offset to start of the data buffer + uint32_t flag_color) { + BufferResource src_buffer(const_cast(A), N * sizeof(T)); + BufferResource dst_buffer(B, N * sizeof(T)); + + uint8_t* rank_buffer = 
buffer_list[rank]; + + const int block_size = blockDim.x; + const int thread = threadIdx.x; + const int block = blockIdx.x; + const uint32_t problem_size = (N + 3) / 4; + + int32x4_t tA, tB; + long grid = gridDim.x; + long data_stride = grid * block_size * sizeof(int32x4_t); + long comm_flags0_offset = block * (world_size * sizeof(int)); + long comm_flags1_offset = + comm_flags0_offset + grid * (world_size * sizeof(int)); + + for (int idx = block * block_size + thread; idx < problem_size; + idx += grid * block_size) { + // load values + tA = buffer_load_dwordx4(src_buffer.descriptor, idx * sizeof(int32x4_t), + 0, 0); + + // Write rank data into this rank segment of every rank's communication + // buffer. +#pragma unroll + for (int r = 0; r < world_size; r++) { + int32x4_t* send_buffer = reinterpret_cast( + buffer_list[r] + data_offset + rank * data_stride + + idx * sizeof(int32x4_t)); + __builtin_nontemporal_store(tA, send_buffer); + } + } + + __syncthreads(); + if (thread < world_size) { + int r = thread; + int* peer_flag_ptr = reinterpret_cast( + buffer_list[r] + comm_flags0_offset + rank * sizeof(int)); + __atomic_store_n(peer_flag_ptr, flag_color, __ATOMIC_RELEASE); + int* self_flag_ptr = reinterpret_cast( + rank_buffer + comm_flags0_offset + r * sizeof(int)); + + // Wait for the flags to be set. + while (__atomic_load_n(self_flag_ptr, __ATOMIC_ACQUIRE) != flag_color) { + } + } + __syncthreads(); + + for (int idx = block * block_size + thread; idx < problem_size; + idx += grid * block_size) { + { + int r = 0; + // Read posted data from the rank's communication buffer. + int32x4_t* recv_buffer = reinterpret_cast( + rank_buffer + data_offset + r * data_stride + + idx * sizeof(int32x4_t)); + tA = __builtin_nontemporal_load(recv_buffer); + } +#pragma unroll + for (int r = 1; r < world_size; r++) { + // Read posted data from the rank's communication buffer. 
+ int32x4_t* recv_buffer = reinterpret_cast( + rank_buffer + data_offset + r * data_stride + + idx * sizeof(int32x4_t)); + tB = __builtin_nontemporal_load(recv_buffer); + + // Reduce the local data with the read data + packed_assign_add(&tA, &tB); + } + + buffer_store_dwordx4(tA, dst_buffer.descriptor, idx * sizeof(int32x4_t), + 0, 0); + } + + __syncthreads(); + if (thread < world_size) { + int r = thread; + int* peer_flag_ptr = reinterpret_cast( + buffer_list[r] + comm_flags1_offset + rank * sizeof(int)); + __atomic_store_n(peer_flag_ptr, flag_color, __ATOMIC_RELAXED); + int* self_flag_ptr = reinterpret_cast( + rank_buffer + comm_flags1_offset + r * sizeof(int)); + + // Wait for the flags to be set. + while (__atomic_load_n(self_flag_ptr, __ATOMIC_RELAXED) != flag_color) { + } + } + } +}; + } // namespace quickreduce \ No newline at end of file diff --git a/vllm/distributed/device_communicators/quick_all_reduce.py b/vllm/distributed/device_communicators/quick_all_reduce.py index e6b5debc0184..5f0887956ec2 100644 --- a/vllm/distributed/device_communicators/quick_all_reduce.py +++ b/vllm/distributed/device_communicators/quick_all_reduce.py @@ -24,8 +24,9 @@ class QuickReduceRegime(Enum): FP = 0 INT8 = 1 - INT4 = 2 - NONE = 3 + INT6 = 2 + INT4 = 3 + NONE = 4 class QuickAllReduce: @@ -50,8 +51,8 @@ def __init__(self, group: ProcessGroup, "Supported levels: " f"{list(QuickReduceRegime.__members__.keys())}") if regime_str == "NONE": - logger.debug("Custom quickreduce is disabled based on " - "env variable VLLM_ROCM_CA_QUANT_REGIME") + logger.debug("Custom quick allreduce is disabled based " + "on env variable VLLM_ROCM_CA_QUANT_REGIME") return self.quant_level = QuickReduceRegime[regime_str] # On RocM bfloat16 kernels are slower than fp16 diff --git a/vllm/envs.py b/vllm/envs.py index 1c5c842ed49b..07c76ecf9063 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -673,8 +673,8 @@ def get_vllm_port() -> Optional[int]: lambda: (os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", 
"True").lower() in ("true", "1")), - # Custom quick allreduce kernel for MI3* cards. - # Choice of quantization level: FP, INT8, INT4 or NONE + # Custom quick allreduce kernel for MI3* cards + # Choice of quantization level: FP, INT8, INT6, INT4 or NONE # Recommended for large models to get allreduce "VLLM_ROCM_CA_QUANT_REGIME": lambda: os.getenv("VLLM_ROCM_CA_QUANT_REGIME", "FP").upper(), From 93c33f00448d2ad569f90809f7b148691deda5db Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 13 Jun 2025 16:17:05 +0000 Subject: [PATCH 16/28] Adjusted to static constexpr int Signed-off-by: Haoyang Li --- csrc/quickreduce/quick_reduce_impl.cuh | 30 +++++++++++++------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/csrc/quickreduce/quick_reduce_impl.cuh b/csrc/quickreduce/quick_reduce_impl.cuh index 0228806cf5bb..679795447b6c 100644 --- a/csrc/quickreduce/quick_reduce_impl.cuh +++ b/csrc/quickreduce/quick_reduce_impl.cuh @@ -219,9 +219,9 @@ struct CodecQ6 : public CodecBase { // Each threads processes a fragment of fp16x8_t (16B), // into a int6x8_t (4B + 2B) and a fp16 scale shared among 32 values. static constexpr int kRankAtoms = kAtoms / kWorldSize; - static int constexpr kRankTileStride = 1664; - static int constexpr kRankTileQ2Offset = 1024; - static int constexpr kRankTileScaleOffset = 1536; + static constexpr int kRankTileStride = 1664; + static constexpr int kRankTileQ2Offset = 1024; + static constexpr int kRankTileScaleOffset = 1536; static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms; static_assert(kRankTransmittedTileSize % 16 == 0, "kRankTransmittedTileSize must be 16B aligned."); @@ -252,7 +252,7 @@ struct CodecQ6 : public CodecBase { std::is_same::value ? 
0x4FC04FC0 : 0x41F841F8; // {+32, +32}, int16x2_t - static int constexpr kRangeBias = 0x00200020; + static constexpr int kRangeBias = 0x00200020; __quickreduce_device_inline__ CodecQ6(int thread, int rank) : CodecBase(thread, rank) {} @@ -397,39 +397,39 @@ struct CodecQ8 : public CodecBase { // Codec tile size process by this workgroup. // Each threads processes a fragment of f16x8_t (16B), // into a int8x8_t (8B) and a f16 scale shared among 32 values. - static int constexpr kRankAtoms = kAtoms / kWorldSize; - static int constexpr kRankTileStride = 2176; - static int constexpr kRankTileScaleOffset = 2048; - static int constexpr kRankTransmittedTileSize = kRankTileStride * kRankAtoms; + static constexpr int kRankAtoms = kAtoms / kWorldSize; + static constexpr int kRankTileStride = 2176; + static constexpr int kRankTileScaleOffset = 2048; + static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms; static_assert(kRankTransmittedTileSize % 16 == 0, "kRankTileSize must be 16B aligned."); - static int constexpr kRankBufferTileStride = + static constexpr int kRankBufferTileStride = kRankTileStride / sizeof(int32x4_t); // Total tile size for the collective communication. - static int constexpr kTransmittedTileSize = + static constexpr int kTransmittedTileSize = kRankTransmittedTileSize * kWorldSize; // Constants configuration // {-1/128.0h, -1/128.0h}, f16x2_t - static int constexpr kScaleFactor = + static constexpr int kScaleFactor = std::is_same::value ? 0xA000A000 : 0xBC00BC00; // {1e-7, 1e-7}, f16x2_t - static int constexpr kScaleEpsilon = + static constexpr int kScaleEpsilon = std::is_same::value ? 0x00010001 : 0x33D733D7; // {-128, -128}, f16x2_t - static int constexpr kRangeMin = + static constexpr int kRangeMin = std::is_same::value ? 0xD800D800 : 0xC300C300; // {+127, +127}, f16x2_t - static int constexpr kRangeMax = + static constexpr int kRangeMax = std::is_same::value ? 
0x57F057F0 : 0x42FE42FE; // {+128, +128}, int16x2_t - static int constexpr kRangeBias = 0x00800080; + static constexpr int kRangeBias = 0x00800080; __quickreduce_device_inline__ CodecQ8(int thread, int rank) : CodecBase(thread, rank) { From 2b765cc824ee10b26e13c6dd69808000b55ec49c Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 13 Jun 2025 17:21:03 +0000 Subject: [PATCH 17/28] Remove useless functions Signed-off-by: Haoyang Li --- csrc/quickreduce/base.h | 113 +------------------------ csrc/quickreduce/quick_reduce.h | 2 +- csrc/quickreduce/quick_reduce_impl.cuh | 4 +- 3 files changed, 5 insertions(+), 114 deletions(-) diff --git a/csrc/quickreduce/base.h b/csrc/quickreduce/base.h index eeeae078a404..8138eda82875 100644 --- a/csrc/quickreduce/base.h +++ b/csrc/quickreduce/base.h @@ -203,8 +203,8 @@ __quickreduce_device_inline__ int packed_abs_max(int a, int b) { bf162_int_union A, B, R; A.i = a; B.i = b; - R.bf2.x = __hgt(__habs(wmaxh2.x), __habs(wminh2.x)) ? wmaxh2.x : wminh2.x; - R.bf2.y = __hgt(__habs(wmaxh2.y), __habs(wminh2.y)) ? wmaxh2.y : wminh2.y; + R.bf2.x = __hgt(__habs(A.bf2.x), __habs(B.bf2.x)) ? A.bf2.x : B.bf2.x; + R.bf2.y = __hgt(__habs(A.bf2.y), __habs(B.bf2.y)) ? 
A.bf2.y : B.bf2.y; return R.i; } @@ -300,75 +300,6 @@ __quickreduce_device_inline__ float T2float_cast(nv_bfloat16 a) { return __bfloat162float(a); } -// template -// __quickreduce_device_inline__ float T2float_cast(T a); - -// template <> -// __quickreduce_device_inline__ float T2float_cast(half a) { -// return __half2float(a); -// } - -// template <> -// __quickreduce_device_inline__ float T2float_cast(nv_bfloat16 a) -// { -// return __bfloat162float(a); -// } - -// template -// __quickreduce_device_inline__ T float2T_cast(float a); - -// template <> -// __quickreduce_device_inline__ half float2T_cast(float a) { -// return __float2half(a); -// } - -// template <> -// __quickreduce_device_inline__ nv_bfloat16 float2T_cast(float a) -// { -// return __float2bfloat16(a); -// } - -// template -// __quickreduce_device_inline__ unsigned char T2uchar_cast(T a); - -// template <> -// __quickreduce_device_inline__ unsigned char T2uchar_cast(half a) { -// return static_cast(__half2ushort_rz(a)); -// } - -// template <> -// __quickreduce_device_inline__ unsigned char T2uchar_cast( -// nv_bfloat16 a) { -// return static_cast(__bfloat16_as_ushort(a)); -// } - -// template -// __quickreduce_device_inline__ T uchar2T_cast(unsigned char a); - -// template <> -// __quickreduce_device_inline__ half uchar2T_cast(unsigned char a) { -// return __ushort2half_rz(static_cast(a)); -// } - -// template <> -// __quickreduce_device_inline__ nv_bfloat16 -// uchar2T_cast(unsigned char a) { -// return __ushort_as_bfloat16(static_cast(a)); -// } - -// template -// __quickreduce_device_inline__ int T2int_cast(T a); - -// template <> -// __quickreduce_device_inline__ int T2int_cast(half a) { -// return __half2int_rz(a); -// } - -// template <> -// __quickreduce_device_inline__ int T2int_cast(nv_bfloat16 a) { -// return static_cast(__bfloat16_as_ushort(a)); -// } - template __quickreduce_device_inline__ int group_abs_max(int32x4_t atom) { const int group_leader = (threadIdx.x / kThreadGroupSize) * 
kThreadGroupSize; @@ -401,46 +332,6 @@ __quickreduce_device_inline__ int group_abs_max(int32x4_t atom) { return wblockmax; } -// template -// __quickreduce_device_inline__ void group_max_min(int32x4_t atom, int& -// wblockmax, -// int& wblockmin, -// int valid_data) { -// const int group_leader = (threadIdx.x / kThreadGroupSize) * -// kThreadGroupSize; static constexpr int FP_MAX = -// std::is_same::value ? 0x7BFF7BFF : 0x7F7F7F7F; -// static constexpr int FP_MIN = -// std::is_same::value ? 0xFBFFFBFF : 0xFF7FFF7F; - -// int wmax, wmin; -// int a, b; -// a = packed_max(atom[0], atom[1]); -// b = packed_max(atom[2], atom[3]); -// // In case the data was loaded out of range (and initialized to 0) -// // we set max min values to sentinel values -// // so that they do not spoil the group max min values -// wmax = valid_data * packed_max(a, b) + (!valid_data) * FP_MIN; - -// a = packed_min(atom[0], atom[1]); -// b = packed_min(atom[2], atom[3]); -// wmin = valid_data * packed_min(a, b) + (!valid_data) * FP_MAX; - -// // Reduce the max and min among a group of threads -// // Note: This is basically 2 blocks of values setup as the -// // upper/lower halves of the f16x2_t -// for (int i = 1; i < kThreadGroupSize; i <<= 1) { -// int x = __shfl_down(wmax, i); -// wmax = packed_max(wmax, x); - -// int y = __shfl_down(wmin, i); -// wmin = packed_min(wmin, y); -// } - -// // Share with the cohort -// wblockmax = __shfl(wmax, group_leader); -// wblockmin = __shfl(wmin, group_leader); -// } - __quickreduce_device_inline__ void set_sync_flag(uint32_t* flag_ptr, uint32_t flag) { __atomic_store_n(flag_ptr, flag, __ATOMIC_RELEASE); diff --git a/csrc/quickreduce/quick_reduce.h b/csrc/quickreduce/quick_reduce.h index b5c21b40ddb8..16b13eb8d97b 100644 --- a/csrc/quickreduce/quick_reduce.h +++ b/csrc/quickreduce/quick_reduce.h @@ -95,7 +95,7 @@ allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, int num_blocks, enum QuickReduceQuantLevel { FP16 = 0, INT8 = 1, - int6 = 2, + INT6 = 
2, INT4 = 3, }; diff --git a/csrc/quickreduce/quick_reduce_impl.cuh b/csrc/quickreduce/quick_reduce_impl.cuh index 679795447b6c..fcd224ea8ae1 100644 --- a/csrc/quickreduce/quick_reduce_impl.cuh +++ b/csrc/quickreduce/quick_reduce_impl.cuh @@ -370,14 +370,14 @@ struct CodecQ6 : public CodecBase { nv_bfloat16 bf_high = __float2bfloat16(static_cast(high)); nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high); int32_t packed_bf16 = *reinterpret_cast(&bf2); - w[i] = pk_add(packed_bf16, kRangeMin); + w[i] = packed_add(packed_bf16, kRangeMin); } } } // Apply decoding scales for (int i = 0; i < 4; i++) { - w[i] = pk_mul(w[i], qs); + w[i] = packed_mul(w[i], qs); } // That's pretty much it... From 700d7b23307ae906f1144b9349ea0b952c65a82e Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Mon, 16 Jun 2025 04:21:26 +0000 Subject: [PATCH 18/28] fix max size err Signed-off-by: Haoyang Li --- csrc/custom_quickreduce.cu | 10 +++++++--- csrc/quickreduce/quick_reduce.h | 2 +- .../device_communicators/quick_all_reduce.py | 2 +- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/csrc/custom_quickreduce.cu b/csrc/custom_quickreduce.cu index 97ebfbaa0a13..91b8abf1a162 100644 --- a/csrc/custom_quickreduce.cu +++ b/csrc/custom_quickreduce.cu @@ -10,6 +10,8 @@ quickreduce::fptr_t init_custom_qr(int64_t rank, int64_t world_size) { if (world_size > 8) throw std::invalid_argument("world size > 8 is not supported"); + if (world_size == 6) + throw std::invalid_argument("world size == 6 is not supported"); if (world_size % 2 != 0) throw std::invalid_argument("Odd num gpus is not supported for now"); if (rank < 0 || rank >= world_size) @@ -20,9 +22,11 @@ quickreduce::fptr_t init_custom_qr(int64_t rank, int64_t world_size) { } void qr_destroy(quickreduce::fptr_t _fa) { - auto fa = reinterpret_cast(_fa); - fa->destroy(); - delete fa; + if (_fa) { + auto fa = reinterpret_cast(_fa); + fa->destroy(); + delete fa; + } } torch::Tensor qr_get_handle(quickreduce::fptr_t _fa) { diff --git 
a/csrc/quickreduce/quick_reduce.h b/csrc/quickreduce/quick_reduce.h index 16b13eb8d97b..0a345772bd3c 100644 --- a/csrc/quickreduce/quick_reduce.h +++ b/csrc/quickreduce/quick_reduce.h @@ -104,7 +104,7 @@ struct DeviceComms { static long constexpr kTileSize = 256 * 16 * 8; // Max problem size is 2GB (in bytes) or half of uint32_t max value. - static int64_t constexpr kMaxProblemSize = 2147483647; + static int64_t constexpr kMaxProblemSize = 2147483648; static int64_t constexpr kMaxTiles = kMaxProblemSize / kTileSize; // Max TP-8 diff --git a/vllm/distributed/device_communicators/quick_all_reduce.py b/vllm/distributed/device_communicators/quick_all_reduce.py index 5f0887956ec2..322633c220a4 100644 --- a/vllm/distributed/device_communicators/quick_all_reduce.py +++ b/vllm/distributed/device_communicators/quick_all_reduce.py @@ -107,7 +107,7 @@ def all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None): Performs an out-of-place all reduce. """ inp_size = inp.numel() * inp.element_size() - if inp_size >= self.max_size: + if inp_size > self.max_size: return None inp_dtype = inp.dtype From 3eea342b9d0437791f1b426031dcba64a0ed2814 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Mon, 16 Jun 2025 08:22:17 +0000 Subject: [PATCH 19/28] adjust for comments Signed-off-by: Haoyang Li --- csrc/quickreduce/quick_reduce_impl.cuh | 5 ----- 1 file changed, 5 deletions(-) diff --git a/csrc/quickreduce/quick_reduce_impl.cuh b/csrc/quickreduce/quick_reduce_impl.cuh index fcd224ea8ae1..89a07629d713 100644 --- a/csrc/quickreduce/quick_reduce_impl.cuh +++ b/csrc/quickreduce/quick_reduce_impl.cuh @@ -52,7 +52,6 @@ struct CodecFP : public CodecBase { } }; -// MARK: Q4 Line Codec // Int4 symmetric quantization codec. // We quantize the FP16 data to block-scaled Int4 in blocks of 4 * // kThreadGroupSize. @@ -207,7 +206,6 @@ struct CodecQ4 : public CodecBase { } }; -// MARK: Q6 Line Codec // Int6 symmetric quantization codec. 
// We quantize the FP16 data to block-scaled Int6 in blocks of 4 * // kThreadGroupSize. @@ -386,7 +384,6 @@ struct CodecQ6 : public CodecBase { } }; -// MARK: Q8 Line Codec // Int8 symmetric quantization codec. // We quantize the FP16 data to block-scaled Int8 in blocks of 4 * // kThreadGroupSize. @@ -545,7 +542,6 @@ struct CodecQ8 : public CodecBase { } }; -// MARK: Twoshot All Reduce // Twoshot All Reduce template struct AllReduceTwoshot { @@ -681,7 +677,6 @@ struct AllReduceTwoshot { } }; -// MARK: Oneshot All Reduce // Oneshot AllReduce template struct AllReduceOneshot { From d6bc3e20e86d4943709d83c606ffe5945d6aeb0d Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Mon, 16 Jun 2025 10:55:38 +0000 Subject: [PATCH 20/28] integrate_qr2cr Signed-off-by: Haoyang Li --- csrc/quickreduce/base.h | 3 - csrc/quickreduce/quick_reduce.h | 82 +---- csrc/quickreduce/quick_reduce_impl.cuh | 104 ------- tests/distributed/test_quick_reduce.py | 127 -------- .../device_communicators/cuda_communicator.py | 28 +- .../device_communicators/custom_all_reduce.py | 293 ++++++++++++++---- .../device_communicators/quick_all_reduce.py | 138 --------- vllm/envs.py | 12 +- 8 files changed, 270 insertions(+), 517 deletions(-) delete mode 100644 tests/distributed/test_quick_reduce.py delete mode 100644 vllm/distributed/device_communicators/quick_all_reduce.py diff --git a/csrc/quickreduce/base.h b/csrc/quickreduce/base.h index 8138eda82875..cf48cc770837 100644 --- a/csrc/quickreduce/base.h +++ b/csrc/quickreduce/base.h @@ -51,9 +51,6 @@ static constexpr int kWavefront = 64; // 256 thread, 4 wavefronts. 
static dim3 constexpr kBlockTwoShot = {kWavefront, kBlockSize / kWavefront, 1}; -static constexpr int kThreadsOneShot = 512; -static dim3 constexpr kBlockOneShot = {kThreadsOneShot, 1, 1}; - // Number of threads in a group for quantization // It corresponds to 32 F16 elements in quantization block static constexpr int kThreadGroupSize = 8; diff --git a/csrc/quickreduce/quick_reduce.h b/csrc/quickreduce/quick_reduce.h index 0a345772bd3c..4356bd628e97 100644 --- a/csrc/quickreduce/quick_reduce.h +++ b/csrc/quickreduce/quick_reduce.h @@ -18,23 +18,6 @@ namespace quickreduce { using fptr_t = int64_t; static_assert(sizeof(void*) == sizeof(fptr_t)); -static constexpr unsigned int kOneShotAllreduceMaxElemsWorldSize2 = 8192 * 12; -static constexpr unsigned int kOneShotAllreduceMaxElemsWorldSize4 = 8192 * 8; -static constexpr unsigned int kOneShotAllreduceMaxElemsWorldSize8 = 8192 * 4; -static constexpr unsigned int kOneShotAllreduceMaxSize = - std::max(kOneShotAllreduceMaxElemsWorldSize2 * 2, - std::max(kOneShotAllreduceMaxElemsWorldSize4 * 4, - kOneShotAllreduceMaxElemsWorldSize8 * 8)) * - sizeof(half); - -template -__global__ __quickreduce_launch_bounds_one_shot__ static void -allreduce_prototype_oneshot(T const* A, T* B, uint32_t N, int rank, - uint8_t** dbuffer_list, uint32_t data_offset, - uint32_t flag_color) { - AllReduceKernel::run(A, B, N, rank, dbuffer_list, data_offset, flag_color); -} - template __global__ __quickreduce_launch_bounds_two_shot__ static void allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, int num_blocks, @@ -50,24 +33,6 @@ allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, int num_blocks, } } -#define ONESHOT_DISPATCH() \ - if (world_size == 2) { \ - using AllReduceKernel = AllReduceOneshot; \ - hipLaunchKernelGGL((allreduce_prototype_oneshot), \ - dim3(grid), dim3(kBlockOneShot), 0, stream, A, B, N, \ - rank, dbuffer_list, data_offset, flag_color); \ - } else if (world_size == 4) { \ - using AllReduceKernel = AllReduceOneshot; \ 
- hipLaunchKernelGGL((allreduce_prototype_oneshot), \ - dim3(grid), dim3(kBlockOneShot), 0, stream, A, B, N, \ - rank, dbuffer_list, data_offset, flag_color); \ - } else if (world_size == 8) { \ - using AllReduceKernel = AllReduceOneshot; \ - hipLaunchKernelGGL((allreduce_prototype_oneshot), \ - dim3(grid), dim3(kBlockOneShot), 0, stream, A, B, N, \ - rank, dbuffer_list, data_offset, flag_color); \ - } - #define TWOSHOT_DISPATCH(__codec) \ if (world_size == 2) { \ using LineCodec = __codec; \ @@ -132,8 +97,7 @@ struct DeviceComms { // Allocate buffer size for worst case: Twoshot FP16 2-stage buffer. uint32_t flags_buffer_size = 2 * world_size * kMaxTiles * sizeof(int); - static constexpr int64_t data_buffer_size = std::max( - 2 * kMaxProblemSize, static_cast(kOneShotAllreduceMaxSize)); + static constexpr int64_t data_buffer_size = 2 * kMaxProblemSize; int64_t total_buffer_size = flags_buffer_size + data_buffer_size; data_offset = flags_buffer_size; HIP_CHECK(hipExtMallocWithFlags((void**)&dbuffer, total_buffer_size, @@ -203,34 +167,22 @@ struct DeviceComms { } // Configuration. 
- uint32_t msg_size = N * sizeof(T); - bool use_one_shot_allreduce = - (world_size == 2 and N <= kOneShotAllreduceMaxElemsWorldSize2) or - (world_size == 4 and N <= kOneShotAllreduceMaxElemsWorldSize4) or - (world_size == 8 and N <= kOneShotAllreduceMaxElemsWorldSize8); - if (use_one_shot_allreduce) { - // Each thread processes blocks out of 4 elements - uint64_t num_blocks = divceil(N, (4 * kThreadsOneShot)); - uint64_t grid = min(kMaxNumBlocks, num_blocks); - ONESHOT_DISPATCH() - } else { - uint64_t num_blocks = divceil(msg_size, kTileSize); - uint64_t grid = min(kMaxNumBlocks, num_blocks); - auto quant_level_ = static_cast(quant_level); - switch (quant_level_) { - case QuickReduceQuantLevel::INT8: - TWOSHOT_DISPATCH(CodecQ8) - break; - case QuickReduceQuantLevel::INT6: - TWOSHOT_DISPATCH(CodecQ6) - break; - case QuickReduceQuantLevel::INT4: - TWOSHOT_DISPATCH(CodecQ4) - break; - default: - TWOSHOT_DISPATCH(CodecFP) - break; - } + uint64_t num_blocks = divceil(msg_size, kTileSize); + uint64_t grid = min(kMaxNumBlocks, num_blocks); + auto quant_level_ = static_cast(quant_level); + switch (quant_level_) { + case QuickReduceQuantLevel::INT8: + TWOSHOT_DISPATCH(CodecQ8) + break; + case QuickReduceQuantLevel::INT6: + TWOSHOT_DISPATCH(CodecQ6) + break; + case QuickReduceQuantLevel::INT4: + TWOSHOT_DISPATCH(CodecQ4) + break; + default: + TWOSHOT_DISPATCH(CodecFP) + break; } HIP_CHECK(cudaGetLastError()); // Rotate the flag color. 
diff --git a/csrc/quickreduce/quick_reduce_impl.cuh b/csrc/quickreduce/quick_reduce_impl.cuh index 89a07629d713..92be8ab8f127 100644 --- a/csrc/quickreduce/quick_reduce_impl.cuh +++ b/csrc/quickreduce/quick_reduce_impl.cuh @@ -677,108 +677,4 @@ struct AllReduceTwoshot { } }; -// Oneshot AllReduce -template -struct AllReduceOneshot { - static_assert(sizeof(T) == 2); - - __device__ static void run( - T const* __restrict__ A, // input - T* __restrict__ B, // output - uint32_t const N, // number of elements - uint32_t const rank, // rank index - uint8_t** __restrict__ buffer_list, // communication buffers - long const data_offset, // offset to start of the data buffer - uint32_t flag_color) { - BufferResource src_buffer(const_cast(A), N * sizeof(T)); - BufferResource dst_buffer(B, N * sizeof(T)); - - uint8_t* rank_buffer = buffer_list[rank]; - - const int block_size = blockDim.x; - const int thread = threadIdx.x; - const int block = blockIdx.x; - const uint32_t problem_size = (N + 3) / 4; - - int32x4_t tA, tB; - long grid = gridDim.x; - long data_stride = grid * block_size * sizeof(int32x4_t); - long comm_flags0_offset = block * (world_size * sizeof(int)); - long comm_flags1_offset = - comm_flags0_offset + grid * (world_size * sizeof(int)); - - for (int idx = block * block_size + thread; idx < problem_size; - idx += grid * block_size) { - // load values - tA = buffer_load_dwordx4(src_buffer.descriptor, idx * sizeof(int32x4_t), - 0, 0); - - // Write rank data into this rank segment of every rank's communication - // buffer. 
-#pragma unroll - for (int r = 0; r < world_size; r++) { - int32x4_t* send_buffer = reinterpret_cast( - buffer_list[r] + data_offset + rank * data_stride + - idx * sizeof(int32x4_t)); - __builtin_nontemporal_store(tA, send_buffer); - } - } - - __syncthreads(); - if (thread < world_size) { - int r = thread; - int* peer_flag_ptr = reinterpret_cast( - buffer_list[r] + comm_flags0_offset + rank * sizeof(int)); - __atomic_store_n(peer_flag_ptr, flag_color, __ATOMIC_RELEASE); - int* self_flag_ptr = reinterpret_cast( - rank_buffer + comm_flags0_offset + r * sizeof(int)); - - // Wait for the flags to be set. - while (__atomic_load_n(self_flag_ptr, __ATOMIC_ACQUIRE) != flag_color) { - } - } - __syncthreads(); - - for (int idx = block * block_size + thread; idx < problem_size; - idx += grid * block_size) { - { - int r = 0; - // Read posted data from the rank's communication buffer. - int32x4_t* recv_buffer = reinterpret_cast( - rank_buffer + data_offset + r * data_stride + - idx * sizeof(int32x4_t)); - tA = __builtin_nontemporal_load(recv_buffer); - } -#pragma unroll - for (int r = 1; r < world_size; r++) { - // Read posted data from the rank's communication buffer. - int32x4_t* recv_buffer = reinterpret_cast( - rank_buffer + data_offset + r * data_stride + - idx * sizeof(int32x4_t)); - tB = __builtin_nontemporal_load(recv_buffer); - - // Reduce the local data with the read data - packed_assign_add(&tA, &tB); - } - - buffer_store_dwordx4(tA, dst_buffer.descriptor, idx * sizeof(int32x4_t), - 0, 0); - } - - __syncthreads(); - if (thread < world_size) { - int r = thread; - int* peer_flag_ptr = reinterpret_cast( - buffer_list[r] + comm_flags1_offset + rank * sizeof(int)); - __atomic_store_n(peer_flag_ptr, flag_color, __ATOMIC_RELAXED); - int* self_flag_ptr = reinterpret_cast( - rank_buffer + comm_flags1_offset + r * sizeof(int)); - - // Wait for the flags to be set. 
- while (__atomic_load_n(self_flag_ptr, __ATOMIC_RELAXED) != flag_color) { - } - } - } -}; - } // namespace quickreduce \ No newline at end of file diff --git a/tests/distributed/test_quick_reduce.py b/tests/distributed/test_quick_reduce.py deleted file mode 100644 index 32731de9a648..000000000000 --- a/tests/distributed/test_quick_reduce.py +++ /dev/null @@ -1,127 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import random - -import pytest -import ray -import torch -import torch.distributed as dist - -from vllm.distributed.communication_op import ( # noqa - tensor_model_parallel_all_reduce) -from vllm.distributed.parallel_state import (get_tensor_model_parallel_group, - get_tp_group, graph_capture) -from vllm.platforms import current_platform - -from ..utils import init_test_distributed_environment, multi_process_parallel - -random.seed(42) -test_sizes = [random.randint(256 * 8 * 4, 2048 * 1024) for _ in range(8)] -for i, v in enumerate(test_sizes): - test_sizes[i] -= v % 8 - - -# Same as in custom all-reduce -# Only enable QuickReduce -@ray.remote(num_gpus=1, max_calls=1) -def graph_allreduce( - monkeypatch: pytest.MonkeyPatch, - tp_size, - pp_size, - rank, - distributed_init_port, -): - with monkeypatch.context() as m: - m.delenv("CUDA_VISIBLE_DEVICES", raising=False) - device = torch.device(f"cuda:{rank}") - torch.cuda.set_device(device) - init_test_distributed_environment(tp_size, pp_size, rank, - distributed_init_port) - - group = get_tensor_model_parallel_group().device_group - - # A small all_reduce for warmup. - # this is needed because device communicators might be created lazily - # (e.g. NCCL). This will ensure that the communicator is initialized - # before any communication happens, so that this group can be used for - # graph capture immediately. 
- data = torch.zeros(1) - data = data.to(device=device) - torch.distributed.all_reduce(data, group=group) - torch.cuda.synchronize() - del data - - # we use the first group to communicate once - # and the second group to communicate twice - # and so on - # this is used to demonstrate that each group can - # communicate independently - num_communication = rank // tp_size + 1 - - for sz in test_sizes: - for dtype in [torch.float16, torch.bfloat16]: - with graph_capture(device=device) as graph_capture_context: - # use integers so result matches NCCL exactly - inp1 = torch.randint(1, - 16, (sz, ), - dtype=dtype, - device=torch.cuda.current_device()) - inp2 = torch.randint(1, - 16, (sz, ), - dtype=dtype, - device=torch.cuda.current_device()) - torch.cuda.synchronize() - graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph, - stream=graph_capture_context.stream): - for i in range(num_communication): - out1 = tensor_model_parallel_all_reduce(inp1) - # the input buffer is immediately modified to test - # synchronization - dist.all_reduce(inp1, group=group) - out2 = tensor_model_parallel_all_reduce(inp2) - dist.all_reduce(inp2, group=group) - graph.replay() - torch.testing.assert_close(out1, inp1) - torch.testing.assert_close(out2, inp2) - - -@ray.remote(num_gpus=1, max_calls=1) -def eager_quick_allreduce( - monkeypatch: pytest.MonkeyPatch, - tp_size, - pp_size, - rank, - distributed_init_port, -): - with monkeypatch.context() as m: - m.delenv("CUDA_VISIBLE_DEVICES", raising=False) - device = torch.device(f"cuda:{rank}") - torch.cuda.set_device(device) - init_test_distributed_environment(tp_size, pp_size, rank, - distributed_init_port) - for dtype in [torch.float16, torch.bfloat16]: - - num_communication = rank // tp_size + 1 - sz = 256 * 8 * 8 - qr_comm = get_tp_group().device_communicator.qr_comm - inp = torch.ones(sz, dtype=dtype, device=device) - out = inp - for _ in range(num_communication): - out = qr_comm.all_reduce(out) - torch.testing.assert_close(out, inp 
* (tp_size**num_communication)) - - -@pytest.mark.skipif(not current_platform.is_rocm(), - reason="Quick reduce is only supported on RocM.") -@pytest.mark.parametrize("tp_size", [2]) -@pytest.mark.parametrize("pipeline_parallel_size", [1, 2]) -@pytest.mark.parametrize("test_target", - [eager_quick_allreduce, graph_allreduce]) -def test_quick_reduce_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size, - pipeline_parallel_size, test_target): - world_size = tp_size * pipeline_parallel_size - if world_size > torch.cuda.device_count(): - pytest.skip("Not enough GPUs to run the test.") - multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, - test_target) diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 4d4ba4a8b4b6..055d91690e67 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -8,7 +8,6 @@ import vllm.envs as envs from vllm.logger import init_logger -from vllm.platforms import current_platform from .base_device_communicator import DeviceCommunicatorBase @@ -42,8 +41,6 @@ def __init__(self, CustomAllreduce) from vllm.distributed.device_communicators.pynccl import ( PyNcclCommunicator) - from vllm.distributed.device_communicators.quick_all_reduce import ( - QuickAllReduce) self.pynccl_comm: Optional[PyNcclCommunicator] = None if use_pynccl and self.world_size > 1: @@ -55,16 +52,10 @@ def __init__(self, self.ca_comm: Optional[CustomAllreduce] = None if use_custom_allreduce and self.world_size > 1: # Initialize a custom fast all-reduce implementation. 
- self.ca_comm = CustomAllreduce(group=self.cpu_group, - device=self.device) - - self.qr_comm: Optional[QuickAllReduce] = None - if (use_custom_allreduce and current_platform.is_rocm() - and self.world_size > 1): - # Initialize a custom fast all-reduce implementation for AMD - # based on quick reduce (https://github.com/mk1-project/quickreduce). - self.qr_comm = QuickAllReduce(group=self.cpu_group, - device=self.device) + self.ca_comm = CustomAllreduce( + group=self.cpu_group, + device=self.device, + ) if self.use_all2all: all2all_backend = envs.VLLM_ALL2ALL_BACKEND @@ -88,15 +79,8 @@ def __init__(self, raise ValueError(f"Unknown all2all backend: {all2all_backend}") def all_reduce(self, input_): - # always try quick reduce first, then custom allreduce, - # and then pynccl. (quick reduce just for ROCM MI3*) - qr_comm = self.qr_comm - if qr_comm is not None and not qr_comm.disabled and \ - qr_comm.should_quick_allreduce(input_): - out = qr_comm.all_reduce(input_) - assert out is not None - return out - + # always try custom allreduce first, + # and then pynccl. 
ca_comm = self.ca_comm if ca_comm is not None and not ca_comm.disabled and \ ca_comm.should_custom_ar(input_): diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 7dd104a4fcc4..40adaf891171 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from contextlib import contextmanager +from enum import Enum from typing import Optional, Union import torch @@ -10,6 +11,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops +from vllm.config import get_current_vllm_config from vllm.distributed.device_communicators.custom_all_reduce_utils import ( gpu_p2p_access_check) from vllm.distributed.parallel_state import in_the_same_node_as @@ -23,10 +25,24 @@ except Exception: # For CPUs custom_ar = False +try: + ops.qr_max_size() + quick_ar = True +except Exception: + # For CPUs + quick_ar = False logger = init_logger(__name__) +class QuickReduceRegime(Enum): + FP = 0 + INT8 = 1 + INT6 = 2 + INT4 = 3 + NONE = 4 + + def _can_p2p(rank: int, world_size: int) -> bool: for i in range(world_size): if i == rank: @@ -49,32 +65,58 @@ def is_weak_contiguous(inp: torch.Tensor): class CustomAllreduce: _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] + _QR_SUPPORTED_WORLD_SIZES = [2, 4, 8] # max_size: max supported allreduce size def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device], - max_size=8192 * 1024) -> None: + cr_max_size=8192 * 1024, + qr_max_size=512 * 1024 * 1024, + qr_min_size=2 * 1024 * 1024) -> None: """ + Custom allredcue (cr) is non-destructive acceleration, which is + available for cuda and rocm MI300 series. + Custom quick allreduce (qr) is accelerated by quantization, + currently supports fp16, Q8, Q6, Q4 quantization. 
+ We view qr as complementary to cr, the condition for qr is + even more demanding; qr is initialized, then cr must also + be initialized. If the conditions of cr are not met, qr is + naturally not initialized. + Due to instruction set limitations, only rocm MI300 series + is supported for the time being. Args: group: the process group to work on. If None, it will use the default process group. device: the device to bind the CustomAllreduce to. If None, it will be bind to f"cuda:{local_rank}". + cr_max_size: max supported size of cr. + qr_max_size: max supported size of qr. + qr_min_size: min supported size of qr. Less than this size, + cr is better. It is the caller's responsibility to make sure each communicator is bind to a unique device, and all communicators in this group are in the same node. """ + self._QR_SHOULD_INIT = True self._IS_CAPTURING = False self.disabled = True + self.cr_max_size = cr_max_size + self.qr_max_size = qr_max_size + self.qr_min_size = qr_min_size if not custom_ar: # disable because of missing custom allreduce library # e.g. in a non-GPU environment logger.info("Custom allreduce is disabled because " "of missing custom allreduce library") - return + if not quick_ar: + logger.info("Custom quick allreduce is disabled because " + "of missing quick allreduce library") + self._QR_SHOULD_INIT = False + if not quick_ar and not custom_ar: + return self.group = group assert dist.get_backend(group) != dist.Backend.NCCL, ( @@ -88,10 +130,12 @@ def __init__(self, return rank = dist.get_rank(group=self.group) - self.rank = rank world_size = dist.get_world_size(group=self.group) + self.rank = rank + self.world_size = world_size if world_size == 1: - # No need to initialize custom allreduce for single GPU case. + # No need to initialize custom allreduce or custom quick + # allreduce for single GPU case. 
return if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES: @@ -102,6 +146,13 @@ def __init__(self, world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES)) return + if self._QR_SHOULD_INIT and \ + world_size not in CustomAllreduce._QR_SUPPORTED_WORLD_SIZES: + self._QR_SHOULD_INIT = False + logger.warning( + "Custom quick allreduce is disabled due to an unsupported " + "world size: %d.", world_size) + if isinstance(device, int): device = torch.device(f"cuda:{device}") elif isinstance(device, str): @@ -131,9 +182,9 @@ def __init__(self, # where custom allreduce is not supported # this checks hardware and driver support for NVLink assert current_platform.is_cuda_alike() - fully_connected = current_platform.is_fully_connected( + self.fully_connected = current_platform.is_fully_connected( physical_device_ids) - if world_size > 2 and not fully_connected: + if world_size > 2 and not self.fully_connected: logger.warning( "Custom allreduce is disabled because it's not supported on" " more than two PCIe-only GPUs. To silence this warning, " @@ -143,23 +194,36 @@ def __init__(self, # this is expensive to compute at the first time # then we cache the result # On AMD GPU, p2p is always enabled between XGMI connected GPUs - if not current_platform.is_rocm() and not _can_p2p(rank, world_size): - logger.warning( - "Custom allreduce is disabled because your platform lacks " - "GPU P2P capability or P2P test failed. To silence this " - "warning, specify disable_custom_all_reduce=True explicitly.") - return - + if not current_platform.is_rocm(): + # First, we only enable custom allreduce for MI300 series, + # If it's rocm then it must be MI300 series, qr must be available. + self._QR_SHOULD_INIT = False + if not _can_p2p(rank, world_size): + logger.warning( + "Custom allreduce is disabled because your platform lacks " + "GPU P2P capability or P2P test failed. 
To silence this " + "warning, specify disable_custom_all_reduce=True " + "explicitly.") + return self.disabled = False + self.init_custom_allreduce() + self.init_custom_quick_allreduce() + + def init_custom_allreduce(self): + """ + Initialize custom allreduce + """ # Buffers memory are owned by this Python class and passed to C++. # Meta data composes of two parts: meta data for synchronization and a # temporary buffer for storing intermediate allreduce results. - self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size, - group=group, + self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + + self.cr_max_size, + group=self.group, uncached=True) # This is a pre-registered IPC buffer. In eager mode, input tensors # are first copied into this buffer before allreduce is performed - self.buffer_ptrs = self.create_shared_buffer(max_size, group=group) + self.buffer_ptrs = self.create_shared_buffer(self.cr_max_size, + group=self.group) # This is a buffer for storing the tuples of pointers pointing to # IPC buffers from all ranks. Each registered tuple has size of # 8*world_size bytes where world_size is at most 8. Allocating 8MB @@ -168,13 +232,62 @@ def __init__(self, self.rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=self.device) - self.max_size = max_size - self.rank = rank - self.world_size = world_size - self.fully_connected = fully_connected - self._ptr = ops.init_custom_ar(self.meta_ptrs, self.rank_data, rank, - self.fully_connected) - ops.register_buffer(self._ptr, self.buffer_ptrs) + self.cr_max_size = self.cr_max_size + + self._cr_ptr = ops.init_custom_ar(self.meta_ptrs, self.rank_data, + self.rank, self.fully_connected) + ops.register_buffer(self._cr_ptr, self.buffer_ptrs) + + def init_custom_quick_allreduce(self): + """ + Initialize a custom quick allreduce implementation for AMD + based on quick reduce (https://github.com/mk1-project/quickreduce). 
+ """ + vllm_config = get_current_vllm_config() + dtype = vllm_config.model_config.dtype + if dtype not in [torch.float16, torch.bfloat16]: + self._QR_SHOULD_INIT = False + + # On RocM bfloat16 kernels are slower than fp16 + # due to slower match operations + # If environment is not set to 1 we convert input to fp16 + self.use_fp16_kernels: bool = envs.VLLM_ROCM_QR_CAST_BF16_TO_FP16 + regime_str = envs.VLLM_ROCM_QR_QUANT_REGIME + + if self._QR_SHOULD_INIT: + if regime_str not in QuickReduceRegime.__members__: + logger.warning( + "Custom quick allreduce:", + f"Invalid quantization level: {regime_str}. " + "Supported levels: " + f"{list(QuickReduceRegime.__members__.keys())}") + return + + if regime_str == "NONE": + logger.debug("Custom quick allreduce is disabled based " + "on env variable VLLM_ROCM_QR_QUANT_REGIME") + return + + self.qr_quant_level = QuickReduceRegime[regime_str] + # These numbers are based on kernel tests. + # TODO: We need the full kernel test to guide the + # size adjustment here + if self.world_size == 2: + self.qr_min_size = 1 * 1024 * 1024 + else: + self.qr_min_size = 2 * 1024 * 1024 + self._qr_ptr = ops.init_custom_qr(self.rank, self.world_size) + self.create_qr_shared_buffer() + if dtype == torch.bfloat16 and self.use_fp16_kernels: + logger.info( + "Custom quick allreduce: due to the lack of bf16 assembly " + "instruction set, the performance gain of bf16 is " + "limited. We convert bfloat16 to float16 to speed " + "up quick allreduce. 
You can set " + "envs.VLLM_ROCM_QR_CAST_BF16_TO_FP16=0 to turn " + "this conversion off.") + # There is no case where qr is initialized and + # cr is not initialized @contextmanager def capture(self): @@ -192,7 +305,7 @@ def capture(self): self.register_graph_buffers() def register_graph_buffers(self): - handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr) + handle, offset = ops.get_graph_buffer_ipc_meta(self._cr_ptr) logger.info("Registering %d cuda graph addresses", len(offset)) # We cannot directly use `dist.all_gather_object` here # because it is incompatible with `gloo` backend under inference mode. @@ -209,9 +322,37 @@ def register_graph_buffers(self): # Unpack list of tuples to tuple of lists. handles = [d[0] for d in all_data] # type: ignore offsets = [d[1] for d in all_data] # type: ignore - ops.register_graph_buffers(self._ptr, handles, offsets) + ops.register_graph_buffers(self._cr_ptr, handles, offsets) - def should_custom_ar(self, inp: torch.Tensor): + def should_quick_allreduce(self, inp: torch.Tensor): + """ + Check if quickreduce is available + """ + if self.disabled and not self._QR_SHOULD_INIT: + return False + inp_size = inp.numel() * inp.element_size() + # custom quick allreduce requires input byte size to be + # multiples of 16 + if inp_size % 16 != 0: + return False + if not is_weak_contiguous(inp): + return False + # custom quick allreduce requires input byte size to be multiples of 16 + if inp.dtype == torch.float16: + return inp_size <= self.qr_max_size and inp_size >= self.qr_min_size + elif inp.dtype == torch.bfloat16: + if self.use_fp16_kernels: + # cast2half, so the same condition + return inp_size <= self.qr_max_size and \ + inp_size >= self.qr_min_size + else: + # TODO: check bf16 condition for mi300 + return (inp_size <= self.qr_max_size + and inp_size > 1024 * 1024 * 16 + and self.world_size == 2) + return False + + def should_custom_allreduce(self, inp: torch.Tensor): if self.disabled: return False inp_size = inp.numel() * 
inp.element_size() @@ -223,15 +364,20 @@ def should_custom_ar(self, inp: torch.Tensor): # for 4 or more non NVLink-capable GPUs, custom allreduce provides # little performance improvement over NCCL. if self.world_size == 2 or self.fully_connected: - return inp_size < self.max_size + return inp_size < self.cr_max_size return False - def all_reduce(self, - inp: torch.Tensor, - *, - out: torch.Tensor = None, - registered: bool = False): - """Performs an out-of-place all reduce. + def should_custom_ar(self, inp: torch.Tensor): + # Determine whether to use qr, or cr or quit + return self.should_quick_allreduce( + inp) or self.should_custom_allreduce(inp) + + def cr_all_reduce(self, + inp: torch.Tensor, + *, + out: torch.Tensor = None, + registered: bool = False): + """Performs an out-of-place custom all reduce. If registered is True, this assumes inp's pointer is already IPC-registered. Otherwise, inp is first copied into a pre-registered @@ -240,37 +386,69 @@ def all_reduce(self, if out is None: out = torch.empty_like(inp) if registered: - ops.all_reduce(self._ptr, inp, out, 0, 0) + ops.all_reduce(self._cr_ptr, inp, out, 0, 0) else: - ops.all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank], - self.max_size) + ops.all_reduce(self._cr_ptr, inp, out, self.buffer_ptrs[self.rank], + self.cr_max_size) return out + def qr_all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None): + """Performs an out-of-place custom quick all reduce.""" + inp_dtype = inp.dtype + if inp_dtype == torch.bfloat16 and self.use_fp16_kernels: + inp = inp.to(torch.float16) + if out is None: + out = torch.empty_like(inp) + ops.qr_all_reduce(self._qr_ptr, inp, out, self.qr_quant_level.value) + return out.to(inp_dtype) + def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: """The main allreduce API that provides support for cuda graph.""" # When custom allreduce is disabled, this will be None. 
- if self.disabled or not self.should_custom_ar(input): + if self.disabled: return None - if self._IS_CAPTURING: - if torch.cuda.is_current_stream_capturing(): - return self.all_reduce(input, registered=True) - else: + # try custom quick allreduce first, then custom allreduce + if self.should_quick_allreduce(input): + # We don't need the context of quick allreduce to do graph capture + # because the ipc access is already collected in init() and + # we can capture the quick allreduce directly. + if self._IS_CAPTURING and \ + not torch.cuda.is_current_stream_capturing(): # If warm up, mimic the allocation pattern since custom # allreduce is out-of-place. return torch.empty_like(input) - else: - # Note: outside of cuda graph context, custom allreduce incurs a - # cost of cudaMemcpy, which should be small (<=1% of overall - # latency) compared to the performance gain of using custom kernels - return self.all_reduce(input, registered=False) + else: + return self.qr_all_reduce(input) + + if self.should_custom_allreduce(input): + if self._IS_CAPTURING: + if torch.cuda.is_current_stream_capturing(): + return self.cr_all_reduce(input, registered=True) + else: + # If warm up, mimic the allocation pattern since custom + # allreduce is out-of-place. 
+ return torch.empty_like(input) + else: + # Note: outside of cuda graph context, custom allreduce + # incurs a cost of cudaMemcpy, which should be small + # (<=1% of overall latency) compared to the performance + # gain of using custom kernels + return self.cr_all_reduce(input, registered=False) + + return None def close(self): - if not self.disabled and self._ptr: - if ops is not None: - ops.dispose(self._ptr) - self._ptr = 0 - self.free_shared_buffer(self.meta_ptrs, rank=self.rank) - self.free_shared_buffer(self.buffer_ptrs, rank=self.rank) + if not self.disabled: + if self._cr_ptr: + if ops is not None: + ops.dispose(self._cr_ptr) + self._cr_ptr = 0 + self.free_shared_buffer(self.meta_ptrs, rank=self.rank) + self.free_shared_buffer(self.buffer_ptrs, rank=self.rank) + if self._qr_ptr: + if ops is not None: + ops.qr_destroy(self._qr_ptr) + self._qr_ptr = 0 def __del__(self): self.close() @@ -294,6 +472,17 @@ def create_shared_buffer(size_in_bytes: int, pointers.append(ops.open_mem_handle(h)) return pointers + def create_qr_shared_buffer(self): + """ + Creates a shared buffer for quickreduce. 
+ Has to be called after qr_init_device_collectives + """ + handle = ops.qr_get_handle(self._qr_ptr) + world_size = dist.get_world_size(group=self.group) + handles = [None] * world_size + dist.all_gather_object(handles, handle, group=self.group) + ops.qr_open_handles(self._qr_ptr, handles) + @staticmethod def free_shared_buffer(pointers: list[int], group: Optional[ProcessGroup] = None, diff --git a/vllm/distributed/device_communicators/quick_all_reduce.py b/vllm/distributed/device_communicators/quick_all_reduce.py deleted file mode 100644 index 322633c220a4..000000000000 --- a/vllm/distributed/device_communicators/quick_all_reduce.py +++ /dev/null @@ -1,138 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -import logging -from enum import Enum -from typing import Union - -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup - -import vllm.envs as envs -from vllm import _custom_ops as ops -from vllm.platforms import current_platform - -logger = logging.getLogger(__name__) - -try: - ops.qr_max_size() - ops_available = True -except Exception: - # For CPUs - ops_available = False - - -class QuickReduceRegime(Enum): - FP = 0 - INT8 = 1 - INT6 = 2 - INT4 = 3 - NONE = 4 - - -class QuickAllReduce: - _SUPPORTED_WORLD_SIZES = [2, 4, 8] - _SUPPORTED_DTYPES = [torch.float16, torch.bfloat16] - - def __init__(self, group: ProcessGroup, - device: Union[int, str, torch.device]) -> None: - self.disabled = True - if not ops_available: - # disable because of missing quick reduce library - # e.g. in a non-cuda environment - logger.info("Custom quick allreduce is disabled because " - "of missing custom quick allreduce library") - return - - self.max_size = ops.qr_max_size() - self.group = group - regime_str = envs.VLLM_ROCM_CA_QUANT_REGIME - assert regime_str in QuickReduceRegime.__members__, ( - f"Invalid quantization level: {regime_str}. 
" - "Supported levels: " - f"{list(QuickReduceRegime.__members__.keys())}") - if regime_str == "NONE": - logger.debug("Custom quick allreduce is disabled based " - "on env variable VLLM_ROCM_CA_QUANT_REGIME") - return - self.quant_level = QuickReduceRegime[regime_str] - # On RocM bfloat16 kernels are slower than fp16 - # due to slower match operations - # If environment is not set to 1 we convert input to fp16 - self.use_fp16_kernels = envs.VLLM_ROCM_CA_CAST_BF16_TO_FP16 - assert dist.get_backend(group) != dist.Backend.NCCL, ( - "QuickReduce should be attached to a non-NCCL group.") - rank = dist.get_rank(group=self.group) - world_size = dist.get_world_size(group=self.group) - if world_size == 1: - # No need to initialize QuickReduce for single GPU case. - return - - if world_size not in QuickAllReduce._SUPPORTED_WORLD_SIZES: - logger.warning( - "QuickReduce is disabled due to an unsupported world" - " size: %d. Supported world sizes: %s." - " To disable this warning set VLLM_ROCM_CA_BACKEND" - " to None", world_size, - str(QuickAllReduce._SUPPORTED_WORLD_SIZES)) - return - - assert current_platform.is_rocm(), ( - "QuickReduce is only supported on ROCm platform.") - if isinstance(device, int): - device = torch.device(f"cuda:{device}") - elif isinstance(device, str): - device = torch.device(device) - # now `device` is a `torch.device` object - assert isinstance(device, torch.device) - self.device = device - torch.cuda.set_device(self.device) - - self._ptr = ops.init_custom_qr(rank, world_size) - self.create_shared_buffer() - self.disabled = False - - def create_shared_buffer(self): - """ - Creates a shared buffer for quickreduce. 
- Has to be called after qr_init_device_collectives - """ - handle = ops.qr_get_handle(self._ptr) - world_size = dist.get_world_size(group=self.group) - handles = [None] * world_size - dist.all_gather_object(handles, handle, group=self.group) - ops.qr_open_handles(self._ptr, handles) - - def all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None): - """ - Performs an out-of-place all reduce. - """ - inp_size = inp.numel() * inp.element_size() - if inp_size > self.max_size: - return None - - inp_dtype = inp.dtype - if inp_dtype == torch.bfloat16 and self.use_fp16_kernels: - inp = inp.to(torch.float16) - if out is None: - out = torch.empty_like(inp) - - ops.qr_all_reduce(self._ptr, inp, out, self.quant_level.value) - return out.to(inp_dtype) - - def close(self): - if not self.disabled and getattr(self, "_ptr", None): - ops.qr_destroy(self._ptr) - self._ptr = 0 - - def __del__(self): - self.close() - - def should_quick_allreduce(self, inp: torch.Tensor): - if self.disabled: - return False - inp_size = inp.numel() * inp.element_size() - # QuickReduce requires input byte size to be multiples of 16 - if inp_size % 16 != 0: - return False - return inp.dtype in QuickAllReduce._SUPPORTED_DTYPES and \ - inp_size < self.max_size diff --git a/vllm/envs.py b/vllm/envs.py index 07c76ecf9063..0094df580fa7 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -129,8 +129,8 @@ VLLM_SLEEP_WHEN_IDLE: bool = False VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16 VLLM_KV_CACHE_LAYOUT: Optional[str] = None - VLLM_ROCM_CA_QUANT_REGIME: str = "FP" - VLLM_ROCM_CA_CAST_BF16_TO_FP16: bool = True + VLLM_ROCM_QR_QUANT_REGIME: str = "FP" + VLLM_ROCM_QR_CAST_BF16_TO_FP16: bool = True def get_default_cache_root(): @@ -676,15 +676,15 @@ def get_vllm_port() -> Optional[int]: # Custom quick allreduce kernel for MI3* cards # Choice of quantization level: FP, INT8, INT6, INT4 or NONE # Recommended for large models to get allreduce - "VLLM_ROCM_CA_QUANT_REGIME": - lambda: os.getenv("VLLM_ROCM_CA_QUANT_REGIME", 
"FP").upper(), + "VLLM_ROCM_QR_QUANT_REGIME": + lambda: os.getenv("VLLM_ROCM_QR_QUANT_REGIME", "FP").upper(), # Custom quick allreduce kernel for MI3* cards # Due to the lack of the bfloat16 asm instruction, bfloat16 # kernels are slower than fp16, # If environment is not set to 1, we convert input to fp16 - "VLLM_ROCM_CA_CAST_BF16_TO_FP16": - lambda: (os.getenv("VLLM_ROCM_CA_CAST_BF16_TO_FP16", "True").lower() in + "VLLM_ROCM_QR_CAST_BF16_TO_FP16": + lambda: (os.getenv("VLLM_ROCM_QR_CAST_BF16_TO_FP16", "True").lower() in ("true", "1")), # If set, when running in Quark emulation mode, do not dequantize the From c50415e39d8c0227c629f10a6fadd84c45367eb8 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Mon, 16 Jun 2025 11:12:28 +0000 Subject: [PATCH 21/28] fix message size Signed-off-by: Haoyang Li --- csrc/quickreduce/quick_reduce.h | 1 + vllm/distributed/device_communicators/cuda_communicator.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/quickreduce/quick_reduce.h b/csrc/quickreduce/quick_reduce.h index 4356bd628e97..1e674f07f0e0 100644 --- a/csrc/quickreduce/quick_reduce.h +++ b/csrc/quickreduce/quick_reduce.h @@ -167,6 +167,7 @@ struct DeviceComms { } // Configuration. 
+ uint32_t msg_size = N * sizeof(T); uint64_t num_blocks = divceil(msg_size, kTileSize); uint64_t grid = min(kMaxNumBlocks, num_blocks); auto quant_level_ = static_cast(quant_level); diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 055d91690e67..5b724b33d2ec 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -173,4 +173,4 @@ def dispatch( def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: assert self.all2all_manager is not None hidden_states = self.all2all_manager.combine(hidden_states) - return hidden_states + return hidden_states \ No newline at end of file From b7cc31c74db30f6e84efe01302b4f603909876f9 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Mon, 16 Jun 2025 21:44:48 +0000 Subject: [PATCH 22/28] Fix fp 2GB bug Add min sizes for QR Cleanup Signed-off-by: ilmarkov --- csrc/quickreduce/base.h | 1 + csrc/quickreduce/quick_reduce.h | 79 ++++++++++-- csrc/quickreduce/quick_reduce_impl.cuh | 12 +- .../device_communicators/custom_all_reduce.py | 115 +++++++----------- vllm/envs.py | 6 +- 5 files changed, 122 insertions(+), 91 deletions(-) diff --git a/csrc/quickreduce/base.h b/csrc/quickreduce/base.h index cf48cc770837..c670ff0fdec0 100644 --- a/csrc/quickreduce/base.h +++ b/csrc/quickreduce/base.h @@ -40,6 +40,7 @@ static constexpr int kAtomStride = kBlockSize; // Size and atom stride of source/destination data that the block will // process. +// Workgroup scope = Tile = (256 threads x 8 atoms x 16B) static constexpr int kTileSize = kBlockSize * kAtoms * sizeof(int32x4_t); // Max number of blocks. 
304 CUs on MI300 diff --git a/csrc/quickreduce/quick_reduce.h b/csrc/quickreduce/quick_reduce.h index 1e674f07f0e0..e5371db19ae0 100644 --- a/csrc/quickreduce/quick_reduce.h +++ b/csrc/quickreduce/quick_reduce.h @@ -20,16 +20,17 @@ static_assert(sizeof(void*) == sizeof(fptr_t)); template __global__ __quickreduce_launch_bounds_two_shot__ static void -allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, int num_blocks, +allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, uint32_t num_blocks, int rank, uint8_t** dbuffer_list, uint32_t data_offset, uint32_t flag_color) { int block = blockIdx.x; int grid = gridDim.x; while (block < num_blocks) { - AllReduceKernel::run(A, B, N, block, num_blocks, rank, dbuffer_list, - data_offset, flag_color); + AllReduceKernel::run(A, B, N, block, rank, dbuffer_list, data_offset, + flag_color); block += grid; + flag_color++; } } @@ -58,16 +59,13 @@ allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, int num_blocks, } enum QuickReduceQuantLevel { - FP16 = 0, + F16 = 0, INT8 = 1, INT6 = 2, INT4 = 3, }; struct DeviceComms { - // Workgroup scope = Tile = (256 threads x 16B x 8 atoms) - static long constexpr kTileSize = 256 * 16 * 8; - // Max problem size is 2GB (in bytes) or half of uint32_t max value. static int64_t constexpr kMaxProblemSize = 2147483648; static int64_t constexpr kMaxTiles = kMaxProblemSize / kTileSize; @@ -95,8 +93,9 @@ struct DeviceComms { this->world_size = world_size; this->rank = rank; - // Allocate buffer size for worst case: Twoshot FP16 2-stage buffer. - uint32_t flags_buffer_size = 2 * world_size * kMaxTiles * sizeof(int); + // Allocate buffer size for worst case: F16 2-stage buffer. 
+ uint32_t flags_buffer_size = + 2 * world_size * kMaxNumBlocks * sizeof(uint32_t); static constexpr int64_t data_buffer_size = 2 * kMaxProblemSize; int64_t total_buffer_size = flags_buffer_size + data_buffer_size; data_offset = flags_buffer_size; @@ -136,6 +135,59 @@ struct DeviceComms { } } + template + bool use_fp_kernel(QuickReduceQuantLevel quant_level, uint32_t msg_size) { + if constexpr (std::is_same::value) { + if (world_size == 2) { + return (quant_level == QuickReduceQuantLevel::INT8 and + msg_size < 1024 * 1024) or + (quant_level == QuickReduceQuantLevel::INT6 and + msg_size < 512 * 1024) or + (quant_level == QuickReduceQuantLevel::INT4 and + msg_size < 512 * 1024); + } else if (world_size == 4) { + return (quant_level == QuickReduceQuantLevel::INT8 and + msg_size < 8 * 1024 * 1024) or + (quant_level == QuickReduceQuantLevel::INT6 and + msg_size < 2 * 1024 * 1024) or + (quant_level == QuickReduceQuantLevel::INT4 and + msg_size < 1024 * 1024); + } else if (world_size == 8) { + // TODO need to do kernel benchmarking for TP8 + return (quant_level == QuickReduceQuantLevel::INT8 and + msg_size < 1024 * 1024) or + (quant_level == QuickReduceQuantLevel::INT6 and + msg_size < 512 * 1024) or + (quant_level == QuickReduceQuantLevel::INT4 and + msg_size < 512 * 1024); + } + } else { + if (world_size == 2) { + return (quant_level == QuickReduceQuantLevel::INT8 and + msg_size < 4 * 1024 * 1024) or + (quant_level == QuickReduceQuantLevel::INT6 and + msg_size < 4 * 1024 * 1024) or + (quant_level == QuickReduceQuantLevel::INT4 and + msg_size < 4 * 1024 * 1024); + } else if (world_size == 4) { + return (quant_level == QuickReduceQuantLevel::INT8 and + msg_size < 32 * 1024 * 1024) or + (quant_level == QuickReduceQuantLevel::INT6 and + msg_size < 32 * 1024 * 1024) or + (quant_level == QuickReduceQuantLevel::INT4 and + msg_size < 32 * 1024 * 1024); + } else if (world_size == 8) { + // TODO need to do kernel benchmarking for TP8 + return (quant_level == 
QuickReduceQuantLevel::INT8 and + msg_size < 32 * 1024 * 1024) or + (quant_level == QuickReduceQuantLevel::INT6 and + msg_size < 32 * 1024 * 1024) or + (quant_level == QuickReduceQuantLevel::INT4 and + msg_size < 32 * 1024 * 1024); + } + } + } + void open_ipc_handles(std::vector const& ipc_handles) { assert(ipc_handles.size() == all_buffer_ipc_handles.size()); for (int i = 0; i < world_size; i++) { @@ -168,9 +220,12 @@ struct DeviceComms { // Configuration. uint32_t msg_size = N * sizeof(T); - uint64_t num_blocks = divceil(msg_size, kTileSize); - uint64_t grid = min(kMaxNumBlocks, num_blocks); + uint32_t num_blocks = divceil(msg_size, kTileSize); + uint32_t grid = min(kMaxNumBlocks, num_blocks); auto quant_level_ = static_cast(quant_level); + if (use_fp_kernel(quant_level_, msg_size)) { + quant_level_ = QuickReduceQuantLevel::F16; + } switch (quant_level_) { case QuickReduceQuantLevel::INT8: TWOSHOT_DISPATCH(CodecQ8) @@ -187,7 +242,7 @@ struct DeviceComms { } HIP_CHECK(cudaGetLastError()); // Rotate the flag color. 
- flag_color++; + flag_color += divceil(N, grid); } }; diff --git a/csrc/quickreduce/quick_reduce_impl.cuh b/csrc/quickreduce/quick_reduce_impl.cuh index 92be8ab8f127..247aa1e4511c 100644 --- a/csrc/quickreduce/quick_reduce_impl.cuh +++ b/csrc/quickreduce/quick_reduce_impl.cuh @@ -553,7 +553,6 @@ struct AllReduceTwoshot { T const* __restrict__ input, T* __restrict__ output, uint32_t const N, // number of elements int const block, // block index - int const num_blocks, // number of blocks int const rank, // rank index uint8_t** __restrict__ buffer_list, // communication buffers uint32_t const data_offset, // offset to start of the data buffer @@ -562,7 +561,8 @@ struct AllReduceTwoshot { int thread = threadIdx.x + threadIdx.y * kWavefront; uint8_t* rank_buffer = buffer_list[rank]; Codec codec(thread, rank); - + int block_id = blockIdx.x; + int grid_size = gridDim.x; // -------------------------------------------------------- // Read input into registers int32x4_t tA[kAtoms]; @@ -580,13 +580,13 @@ struct AllReduceTwoshot { // Phase-1A: Write segment data into the communication buffer of the target // rank responsible for this segment. 
uint32_t comm_data0_offset = - data_offset + block * Codec::kTransmittedTileSize; + data_offset + block_id * Codec::kTransmittedTileSize; uint32_t comm_data1_offset = - num_blocks * Codec::kTransmittedTileSize + comm_data0_offset; + grid_size * Codec::kTransmittedTileSize + comm_data0_offset; - uint32_t comm_flags0_offset = block * (kWorldSize * sizeof(uint32_t)); + uint32_t comm_flags0_offset = block_id * (kWorldSize * sizeof(uint32_t)); uint32_t comm_flags1_offset = - num_blocks * (kWorldSize * sizeof(uint32_t)) + comm_flags0_offset; + grid_size * (kWorldSize * sizeof(uint32_t)) + comm_flags0_offset; for (int r = 0; r < kWorldSize; r++) { int32x4_t* send_buffer = diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 40adaf891171..601caedc3a81 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -11,7 +11,6 @@ import vllm.envs as envs from vllm import _custom_ops as ops -from vllm.config import get_current_vllm_config from vllm.distributed.device_communicators.custom_all_reduce_utils import ( gpu_p2p_access_check) from vllm.distributed.parallel_state import in_the_same_node_as @@ -66,14 +65,23 @@ class CustomAllreduce: _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] _QR_SUPPORTED_WORLD_SIZES = [2, 4, 8] + _QR_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16] + + # TODO need to do kernel benchmarking for TP8 to get min sizes + _QR_MIN_SIZE = { + (torch.float16, 2): 512 * 1024, # 512KB + (torch.float16, 4): 512 * 1024, # 512KB + (torch.float16, 8): 512 * 1024, # 512KB + (torch.bfloat16, 2): 1024 * 1024, # 1MB + (torch.bfloat16, 4): 4 * 1024 * 1024, # 4MB + (torch.bfloat16, 8): 4 * 1024 * 1024, # 4KB + } # max_size: max supported allreduce size def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device], - cr_max_size=8192 * 1024, - qr_max_size=512 * 1024 * 1024, - qr_min_size=2 * 1024 * 
1024) -> None: + cr_max_size=8192 * 1024) -> None: """ Custom allredcue (cr) is non-destructive acceleration, which is available for cuda and rocm MI300 series. @@ -91,9 +99,6 @@ def __init__(self, device: the device to bind the CustomAllreduce to. If None, it will be bind to f"cuda:{local_rank}". cr_max_size: max supported size of cr. - qr_max_size: max supported size of qr. - qr_min_size: min supported size of qr. Less than this size, - cr is better. It is the caller's responsibility to make sure each communicator is bind to a unique device, and all communicators in this group are in the same node. @@ -101,9 +106,9 @@ def __init__(self, self._QR_SHOULD_INIT = True self._IS_CAPTURING = False self.disabled = True + self.qr_disabled = True self.cr_max_size = cr_max_size - self.qr_max_size = qr_max_size - self.qr_min_size = qr_min_size + self.qr_max_size = ops.qr_max_size() if not custom_ar: # disable because of missing custom allreduce library @@ -190,14 +195,15 @@ def __init__(self, " more than two PCIe-only GPUs. To silence this warning, " "specify disable_custom_all_reduce=True explicitly.") return - # test P2P capability, this checks software/cudaruntime support - # this is expensive to compute at the first time - # then we cache the result - # On AMD GPU, p2p is always enabled between XGMI connected GPUs if not current_platform.is_rocm(): - # First, we only enable custom allreduce for MI300 series, + # First, we only enable quickreduce for MI300 series, # If it's rocm then it must be MI300 series, qr must be available. 
self._QR_SHOULD_INIT = False + + # test P2P capability, this checks software/cudaruntime support + # this is expensive to compute at the first time + # then we cache the result + # On AMD GPU, p2p is always enabled between XGMI connected GPUs if not _can_p2p(rank, world_size): logger.warning( "Custom allreduce is disabled because your platform lacks " @@ -205,8 +211,10 @@ def __init__(self, "warning, specify disable_custom_all_reduce=True " "explicitly.") return - self.disabled = False + self.init_custom_allreduce() + self.disabled = False + self.init_custom_quick_allreduce() def init_custom_allreduce(self): @@ -243,17 +251,11 @@ def init_custom_quick_allreduce(self): Initialize a custom quick allreduce implementation for AMD based on quick reduce (https://github.com/mk1-project/quickreduce). """ - vllm_config = get_current_vllm_config() - dtype = vllm_config.model_config.dtype - if dtype not in [torch.float16, torch.bfloat16]: - self._QR_SHOULD_INIT = False - # On RocM bfloat16 kernels are slower than fp16 # due to slower match operations # If environment is not set to 1 we convert input to fp16 self.use_fp16_kernels: bool = envs.VLLM_ROCM_QR_CAST_BF16_TO_FP16 regime_str = envs.VLLM_ROCM_QR_QUANT_REGIME - if self._QR_SHOULD_INIT: if regime_str not in QuickReduceRegime.__members__: logger.warning( @@ -272,22 +274,9 @@ def init_custom_quick_allreduce(self): # These numbers are based on kernel tests. # TODO: We need the full kernel test to guide the # size adjustment here - if self.world_size == 2: - self.qr_min_size = 1 * 1024 * 1024 - else: - self.qr_min_size = 2 * 1024 * 1024 self._qr_ptr = ops.init_custom_qr(self.rank, self.world_size) self.create_qr_shared_buffer() - if dtype == torch.bfloat16 and self.use_fp16_kernels: - logger.info( - "Custom quick allreduce: due to the lack of bf16 assembly " - "instruction set, the performance gain of bf16 is " - "limited. We convert bfloat16 to float16 to speed " - "up quick allreduce. 
You can set " - "envs.VLLM_ROCM_QR_CAST_BF16_TO_FP16=0 to turn " - "this conversion off.") - # There is no case where qr is initialized and - # cr is not initialized + self.qr_disabled = False @contextmanager def capture(self): @@ -328,29 +317,23 @@ def should_quick_allreduce(self, inp: torch.Tensor): """ Check if quickreduce is available """ - if self.disabled and not self._QR_SHOULD_INIT: + if self.qr_disabled: return False inp_size = inp.numel() * inp.element_size() # custom quick allreduce requires input byte size to be # multiples of 16 + if inp.dtype not in self._QR_SUPPORTED_DTYPES: + return False if inp_size % 16 != 0: return False if not is_weak_contiguous(inp): return False # custom quick allreduce requires input byte size to be multiples of 16 - if inp.dtype == torch.float16: - return inp_size <= self.qr_max_size and inp_size >= self.qr_min_size - elif inp.dtype == torch.bfloat16: - if self.use_fp16_kernels: - # cast2half, so the same condition - return inp_size <= self.qr_max_size and \ - inp_size >= self.qr_min_size - else: - # TODO: check bf16 condition for mi300 - return (inp_size <= self.qr_max_size - and inp_size > 1024 * 1024 * 16 - and self.world_size == 2) - return False + dtype = inp.dtype + if self.use_fp16_kernels: + dtype = torch.float16 + return inp_size > self._QR_MIN_SIZE[(dtype, self.world_size)] and \ + inp_size <= self.qr_max_size def should_custom_allreduce(self, inp: torch.Tensor): if self.disabled: @@ -404,21 +387,12 @@ def qr_all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None): def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: """The main allreduce API that provides support for cuda graph.""" - # When custom allreduce is disabled, this will be None. 
- if self.disabled: - return None # try custom quick allreduce first, then custom allreduce if self.should_quick_allreduce(input): # We don't need the context of quick allreduce to do graph capture # because the ipc access is already collected in init() and # we can capture the quick allreduce directly. - if self._IS_CAPTURING and \ - not torch.cuda.is_current_stream_capturing(): - # If warm up, mimic the allocation pattern since custom - # allreduce is out-of-place. - return torch.empty_like(input) - else: - return self.qr_all_reduce(input) + return self.qr_all_reduce(input) if self.should_custom_allreduce(input): if self._IS_CAPTURING: @@ -438,17 +412,18 @@ def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: return None def close(self): - if not self.disabled: - if self._cr_ptr: - if ops is not None: - ops.dispose(self._cr_ptr) - self._cr_ptr = 0 - self.free_shared_buffer(self.meta_ptrs, rank=self.rank) - self.free_shared_buffer(self.buffer_ptrs, rank=self.rank) - if self._qr_ptr: - if ops is not None: - ops.qr_destroy(self._qr_ptr) - self._qr_ptr = 0 + if not self.cr_disabled and self._cr_ptr: + if ops is not None: + ops.dispose(self._cr_ptr) + self._cr_ptr = 0 + self.free_shared_buffer(self.meta_ptrs, rank=self.rank) + self.free_shared_buffer(self.buffer_ptrs, rank=self.rank) + self.cr_disabled = True + if not self.qr_disabled and self._qr_ptr: + if ops is not None: + ops.qr_destroy(self._qr_ptr) + self._qr_ptr = 0 + self.qr_disabled = True def __del__(self): self.close() diff --git a/vllm/envs.py b/vllm/envs.py index 0094df580fa7..816322a9bf34 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -130,7 +130,7 @@ VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16 VLLM_KV_CACHE_LAYOUT: Optional[str] = None VLLM_ROCM_QR_QUANT_REGIME: str = "FP" - VLLM_ROCM_QR_CAST_BF16_TO_FP16: bool = True + VLLM_ROCM_QR_CAST_BF16_TO_FP16: bool = False def get_default_cache_root(): @@ -682,9 +682,9 @@ def get_vllm_port() -> Optional[int]: # Custom quick allreduce kernel for 
MI3* cards # Due to the lack of the bfloat16 asm instruction, bfloat16 # kernels are slower than fp16, - # If environment is not set to 1, we convert input to fp16 + # If environment variable is not set to 1, we convert input to fp16 "VLLM_ROCM_QR_CAST_BF16_TO_FP16": - lambda: (os.getenv("VLLM_ROCM_QR_CAST_BF16_TO_FP16", "True").lower() in + lambda: (os.getenv("VLLM_ROCM_QR_CAST_BF16_TO_FP16", "False").lower() in ("true", "1")), # If set, when running in Quark emulation mode, do not dequantize the From 214eeefa3165fd651ccf777cd82c1ba519d5817c Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 17 Jun 2025 10:11:17 +0000 Subject: [PATCH 23/28] adjust condition Signed-off-by: Haoyang Li --- csrc/quickreduce/quick_reduce.h | 56 ------------------- csrc/quickreduce/quick_reduce_impl.cuh | 4 +- .../device_communicators/custom_all_reduce.py | 46 +++++++++------ 3 files changed, 31 insertions(+), 75 deletions(-) diff --git a/csrc/quickreduce/quick_reduce.h b/csrc/quickreduce/quick_reduce.h index e5371db19ae0..4f6e1c13978f 100644 --- a/csrc/quickreduce/quick_reduce.h +++ b/csrc/quickreduce/quick_reduce.h @@ -135,59 +135,6 @@ struct DeviceComms { } } - template - bool use_fp_kernel(QuickReduceQuantLevel quant_level, uint32_t msg_size) { - if constexpr (std::is_same::value) { - if (world_size == 2) { - return (quant_level == QuickReduceQuantLevel::INT8 and - msg_size < 1024 * 1024) or - (quant_level == QuickReduceQuantLevel::INT6 and - msg_size < 512 * 1024) or - (quant_level == QuickReduceQuantLevel::INT4 and - msg_size < 512 * 1024); - } else if (world_size == 4) { - return (quant_level == QuickReduceQuantLevel::INT8 and - msg_size < 8 * 1024 * 1024) or - (quant_level == QuickReduceQuantLevel::INT6 and - msg_size < 2 * 1024 * 1024) or - (quant_level == QuickReduceQuantLevel::INT4 and - msg_size < 1024 * 1024); - } else if (world_size == 8) { - // TODO need to do kernel benchmarking for TP8 - return (quant_level == QuickReduceQuantLevel::INT8 and - msg_size < 1024 * 1024) 
or - (quant_level == QuickReduceQuantLevel::INT6 and - msg_size < 512 * 1024) or - (quant_level == QuickReduceQuantLevel::INT4 and - msg_size < 512 * 1024); - } - } else { - if (world_size == 2) { - return (quant_level == QuickReduceQuantLevel::INT8 and - msg_size < 4 * 1024 * 1024) or - (quant_level == QuickReduceQuantLevel::INT6 and - msg_size < 4 * 1024 * 1024) or - (quant_level == QuickReduceQuantLevel::INT4 and - msg_size < 4 * 1024 * 1024); - } else if (world_size == 4) { - return (quant_level == QuickReduceQuantLevel::INT8 and - msg_size < 32 * 1024 * 1024) or - (quant_level == QuickReduceQuantLevel::INT6 and - msg_size < 32 * 1024 * 1024) or - (quant_level == QuickReduceQuantLevel::INT4 and - msg_size < 32 * 1024 * 1024); - } else if (world_size == 8) { - // TODO need to do kernel benchmarking for TP8 - return (quant_level == QuickReduceQuantLevel::INT8 and - msg_size < 32 * 1024 * 1024) or - (quant_level == QuickReduceQuantLevel::INT6 and - msg_size < 32 * 1024 * 1024) or - (quant_level == QuickReduceQuantLevel::INT4 and - msg_size < 32 * 1024 * 1024); - } - } - } - void open_ipc_handles(std::vector const& ipc_handles) { assert(ipc_handles.size() == all_buffer_ipc_handles.size()); for (int i = 0; i < world_size; i++) { @@ -223,9 +170,6 @@ struct DeviceComms { uint32_t num_blocks = divceil(msg_size, kTileSize); uint32_t grid = min(kMaxNumBlocks, num_blocks); auto quant_level_ = static_cast(quant_level); - if (use_fp_kernel(quant_level_, msg_size)) { - quant_level_ = QuickReduceQuantLevel::F16; - } switch (quant_level_) { case QuickReduceQuantLevel::INT8: TWOSHOT_DISPATCH(CodecQ8) diff --git a/csrc/quickreduce/quick_reduce_impl.cuh b/csrc/quickreduce/quick_reduce_impl.cuh index 247aa1e4511c..176b270cc852 100644 --- a/csrc/quickreduce/quick_reduce_impl.cuh +++ b/csrc/quickreduce/quick_reduce_impl.cuh @@ -253,7 +253,9 @@ struct CodecQ6 : public CodecBase { static constexpr int kRangeBias = 0x00200020; __quickreduce_device_inline__ CodecQ6(int thread, int rank) 
- : CodecBase(thread, rank) {} + : CodecBase(thread, rank) { + set_fp16_ovfl(true); + } __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, const int32x4_t* __restrict__ data) { diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 601caedc3a81..c191ae4266bf 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -11,6 +11,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops +from vllm.config import get_current_vllm_config from vllm.distributed.device_communicators.custom_all_reduce_utils import ( gpu_p2p_access_check) from vllm.distributed.parallel_state import in_the_same_node_as @@ -67,14 +68,15 @@ class CustomAllreduce: _QR_SUPPORTED_WORLD_SIZES = [2, 4, 8] _QR_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16] - # TODO need to do kernel benchmarking for TP8 to get min sizes + # TODO: We should set a reasonable range for FP. + MB = 1024 * 1024 _QR_MIN_SIZE = { - (torch.float16, 2): 512 * 1024, # 512KB - (torch.float16, 4): 512 * 1024, # 512KB - (torch.float16, 8): 512 * 1024, # 512KB - (torch.bfloat16, 2): 1024 * 1024, # 1MB - (torch.bfloat16, 4): 4 * 1024 * 1024, # 4MB - (torch.bfloat16, 8): 4 * 1024 * 1024, # 4KB + (torch.float16, 2): [16 * MB, 2 * MB, 2 * MB, 1 * MB], + (torch.float16, 4): [16 * MB, 64 * MB, 4 * MB, 2 * MB], + (torch.float16, 8): [16 * MB, 4 * MB, 4 * MB, 2 * MB], + (torch.bfloat16, 2): [16 * MB, 8 * MB, 8 * MB, 8 * MB], + (torch.bfloat16, 4): [16 * MB, 128 * MB, 128 * MB, 16 * MB], + (torch.bfloat16, 8): [16 * MB, 2048 * MB, 2048 * MB, 2048 * MB], } # max_size: max supported allreduce size @@ -197,7 +199,8 @@ def __init__(self, return if not current_platform.is_rocm(): # First, we only enable quickreduce for MI300 series, - # If it's rocm then it must be MI300 series, qr must be available. 
+ # If it's rocm then it must be MI300 series because cr is only + # available on the mi300 series, qr must be available. self._QR_SHOULD_INIT = False # test P2P capability, this checks software/cudaruntime support @@ -213,8 +216,10 @@ def __init__(self, return self.init_custom_allreduce() + # self.disabled is used to indicate cr, if the condition + # of cr is not satisfied, qr must not be satisfied, + # This boolean serves as a uniform identifier for external. self.disabled = False - self.init_custom_quick_allreduce() def init_custom_allreduce(self): @@ -251,6 +256,10 @@ def init_custom_quick_allreduce(self): Initialize a custom quick allreduce implementation for AMD based on quick reduce (https://github.com/mk1-project/quickreduce). """ + vllm_config = get_current_vllm_config() + dtype = vllm_config.model_config.dtype + if dtype not in [torch.float16, torch.bfloat16]: + self._QR_SHOULD_INIT = False # On RocM bfloat16 kernels are slower than fp16 # due to slower match operations # If environment is not set to 1 we convert input to fp16 @@ -271,11 +280,13 @@ def init_custom_quick_allreduce(self): return self.qr_quant_level = QuickReduceRegime[regime_str] - # These numbers are based on kernel tests. 
- # TODO: We need the full kernel test to guide the - # size adjustment here self._qr_ptr = ops.init_custom_qr(self.rank, self.world_size) self.create_qr_shared_buffer() + if dtype == torch.bfloat16 and not self.use_fp16_kernels: + logger.info( + "Custom quick allreduce: converting bf16 to fp16 " + "can speed up qr, " + "set envs.VLLM_ROCM_QR_CAST_BF16_TO_FP16=1 to turn on.") self.qr_disabled = False @contextmanager @@ -322,8 +333,6 @@ def should_quick_allreduce(self, inp: torch.Tensor): inp_size = inp.numel() * inp.element_size() # custom quick allreduce requires input byte size to be # multiples of 16 - if inp.dtype not in self._QR_SUPPORTED_DTYPES: - return False if inp_size % 16 != 0: return False if not is_weak_contiguous(inp): @@ -332,8 +341,9 @@ def should_quick_allreduce(self, inp: torch.Tensor): dtype = inp.dtype if self.use_fp16_kernels: dtype = torch.float16 - return inp_size > self._QR_MIN_SIZE[(dtype, self.world_size)] and \ - inp_size <= self.qr_max_size + return inp_size <= self.qr_max_size and \ + inp_size > self._QR_MIN_SIZE[(dtype, self.world_size)]\ + [self.qr_quant_level.value] def should_custom_allreduce(self, inp: torch.Tensor): if self.disabled: @@ -412,13 +422,13 @@ def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: return None def close(self): - if not self.cr_disabled and self._cr_ptr: + if not self.disabled and self._cr_ptr: if ops is not None: ops.dispose(self._cr_ptr) self._cr_ptr = 0 self.free_shared_buffer(self.meta_ptrs, rank=self.rank) self.free_shared_buffer(self.buffer_ptrs, rank=self.rank) - self.cr_disabled = True + self.disabled = True if not self.qr_disabled and self._qr_ptr: if ops is not None: ops.qr_destroy(self._qr_ptr) From 4e2cfc41627e0758343d7263c6b30fef1a4ca3b3 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 17 Jun 2025 10:36:26 +0000 Subject: [PATCH 24/28] fix vll_config Signed-off-by: Haoyang Li --- .../device_communicators/cuda_communicator.py | 2 +- 
.../device_communicators/custom_all_reduce.py | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 5b724b33d2ec..055d91690e67 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -173,4 +173,4 @@ def dispatch( def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: assert self.all2all_manager is not None hidden_states = self.all2all_manager.combine(hidden_states) - return hidden_states \ No newline at end of file + return hidden_states diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index c191ae4266bf..1b70cce70400 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -257,9 +257,11 @@ def init_custom_quick_allreduce(self): based on quick reduce (https://github.com/mk1-project/quickreduce). 
""" vllm_config = get_current_vllm_config() - dtype = vllm_config.model_config.dtype - if dtype not in [torch.float16, torch.bfloat16]: - self._QR_SHOULD_INIT = False + # for test mode + if vllm_config is not None and hasattr(vllm_config, "model_config"): + dtype = vllm_config.model_config.dtype + if dtype not in [torch.float16, torch.bfloat16]: + self._QR_SHOULD_INIT = False # On RocM bfloat16 kernels are slower than fp16 # due to slower match operations # If environment is not set to 1 we convert input to fp16 @@ -330,6 +332,8 @@ def should_quick_allreduce(self, inp: torch.Tensor): """ if self.qr_disabled: return False + if inp.dtype not in self._QR_SUPPORTED_DTYPES: + return False inp_size = inp.numel() * inp.element_size() # custom quick allreduce requires input byte size to be # multiples of 16 From f8bf2e9166f90fe9048b93f6461346840194d34f Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 17 Jun 2025 10:41:28 +0000 Subject: [PATCH 25/28] change comment Signed-off-by: Haoyang Li --- vllm/distributed/device_communicators/custom_all_reduce.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 1b70cce70400..00680da73f9d 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -286,8 +286,8 @@ def init_custom_quick_allreduce(self): self.create_qr_shared_buffer() if dtype == torch.bfloat16 and not self.use_fp16_kernels: logger.info( - "Custom quick allreduce: converting bf16 to fp16 " - "can speed up qr, " + "Custom quick allreduce: converting bf16 inputs to " + "fp16 can improve performance" "set envs.VLLM_ROCM_QR_CAST_BF16_TO_FP16=1 to turn on.") self.qr_disabled = False From 247348eae35ef5c452fa51b7a50dce694c92cb1d Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Tue, 17 Jun 2025 13:25:26 +0000 Subject: [PATCH 26/28] Update test. 
Disable QR by default. Set fp16 ovfl flag. Signed-off-by: ilmarkov --- csrc/quickreduce/quick_reduce_impl.cuh | 16 ++--- tests/distributed/test_custom_all_reduce.py | 41 +++++++++-- .../device_communicators/custom_all_reduce.py | 69 ++++++++++--------- vllm/envs.py | 4 +- 4 files changed, 81 insertions(+), 49 deletions(-) diff --git a/csrc/quickreduce/quick_reduce_impl.cuh b/csrc/quickreduce/quick_reduce_impl.cuh index 176b270cc852..f4d8eebefd79 100644 --- a/csrc/quickreduce/quick_reduce_impl.cuh +++ b/csrc/quickreduce/quick_reduce_impl.cuh @@ -12,7 +12,9 @@ struct CodecBase { __quickreduce_device_inline__ CodecBase(int thread, int rank) : thread(thread), rank(rank), - group_leader((threadIdx.x / kThreadGroupSize) * kThreadGroupSize) {} + group_leader((threadIdx.x / kThreadGroupSize) * kThreadGroupSize) { + set_fp16_ovfl(true); + } }; // Default full precision codec. @@ -98,9 +100,7 @@ struct CodecQ4 : public CodecBase { static constexpr int kRangeBias = 0x00080008; __quickreduce_device_inline__ CodecQ4(int thread, int rank) - : CodecBase(thread, rank) { - set_fp16_ovfl(true); - } + : CodecBase(thread, rank) {} __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, const int32x4_t* __restrict__ data) { @@ -253,9 +253,7 @@ struct CodecQ6 : public CodecBase { static constexpr int kRangeBias = 0x00200020; __quickreduce_device_inline__ CodecQ6(int thread, int rank) - : CodecBase(thread, rank) { - set_fp16_ovfl(true); - } + : CodecBase(thread, rank) {} __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, const int32x4_t* __restrict__ data) { @@ -431,9 +429,7 @@ struct CodecQ8 : public CodecBase { static constexpr int kRangeBias = 0x00800080; __quickreduce_device_inline__ CodecQ8(int thread, int rank) - : CodecBase(thread, rank) { - set_fp16_ovfl(true); - } + : CodecBase(thread, rank) {} __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, int32x4_t const* __restrict__ data) { diff --git 
a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index fae49c41d5f8..a96027005262 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os import random import pytest @@ -86,7 +87,7 @@ def graph_allreduce( @ray.remote(num_gpus=1, max_calls=1) -def eager_allreduce( +def eager_custom_allreduce( monkeypatch: pytest.MonkeyPatch, tp_size, pp_size, @@ -111,19 +112,51 @@ def eager_allreduce( inp = torch.ones(sz, dtype=torch.float32, device=device) out = inp for _ in range(num_communication): - out = fa.all_reduce(out, registered=False) + out = fa.ca_all_reduce(out, registered=False) torch.testing.assert_close(out, inp * (tp_size**num_communication)) inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device) out = inp for _ in range(num_communication): - out = fa.all_reduce(out, registered=False) + out = fa.ca_all_reduce(out, registered=False) torch.testing.assert_close(out, inp * (tp_size**num_communication)) +@ray.remote(num_gpus=1, max_calls=1) +def eager_quickreduce( + monkeypatch: pytest.MonkeyPatch, + tp_size, + pp_size, + rank, + distributed_init_port, +): + with monkeypatch.context() as m: + m.delenv("CUDA_VISIBLE_DEVICES", raising=False) + os.environ["VLLM_ROCM_QR_QUANT_REGIME"] = "FP" + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + init_test_distributed_environment(tp_size, pp_size, rank, + distributed_init_port) + + sz = 1024 * 1024 + fa = get_tp_group().device_communicator.ca_comm + inp = torch.ones(sz, dtype=torch.float16, device=device) + out = inp + out = fa.qr_all_reduce(out) + torch.testing.assert_close(out, inp * tp_size) + + sz = 1024 * 1024 + inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device) + out = inp + out = fa.qr_all_reduce(out) + torch.testing.assert_close(out, inp * tp_size) + + 
@pytest.mark.parametrize("tp_size", [2]) @pytest.mark.parametrize("pipeline_parallel_size", [1, 2]) -@pytest.mark.parametrize("test_target", [eager_allreduce, graph_allreduce]) +@pytest.mark.parametrize( + "test_target", + [eager_custom_allreduce, graph_allreduce, eager_quickreduce]) def test_custom_allreduce( monkeypatch: pytest.MonkeyPatch, tp_size, diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 00680da73f9d..da9bcc525e90 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -71,11 +71,11 @@ class CustomAllreduce: # TODO: We should set a reasonable range for FP. MB = 1024 * 1024 _QR_MIN_SIZE = { - (torch.float16, 2): [16 * MB, 2 * MB, 2 * MB, 1 * MB], - (torch.float16, 4): [16 * MB, 64 * MB, 4 * MB, 2 * MB], + (torch.float16, 2): [1 * MB, 2 * MB, 2 * MB, 1 * MB], + (torch.float16, 4): [1 * MB, 64 * MB, 4 * MB, 2 * MB], (torch.float16, 8): [16 * MB, 4 * MB, 4 * MB, 2 * MB], - (torch.bfloat16, 2): [16 * MB, 8 * MB, 8 * MB, 8 * MB], - (torch.bfloat16, 4): [16 * MB, 128 * MB, 128 * MB, 16 * MB], + (torch.bfloat16, 2): [2 * MB, 8 * MB, 8 * MB, 8 * MB], + (torch.bfloat16, 4): [8 * MB, 128 * MB, 128 * MB, 16 * MB], (torch.bfloat16, 8): [16 * MB, 2048 * MB, 2048 * MB, 2048 * MB], } @@ -256,40 +256,43 @@ def init_custom_quick_allreduce(self): Initialize a custom quick allreduce implementation for AMD based on quick reduce (https://github.com/mk1-project/quickreduce). """ + if not self._QR_SHOULD_INIT: + return + self.use_fp16_kernels = envs.VLLM_ROCM_QR_CAST_BF16_TO_FP16 + regime_str = envs.VLLM_ROCM_QR_QUANT_REGIME + if regime_str not in QuickReduceRegime.__members__: + logger.warning( + "Custom quick allreduce:", + f"Invalid quantization level: {regime_str}. 
" + "Supported levels: " + f"{list(QuickReduceRegime.__members__.keys())}") + return + + if regime_str == "NONE": + logger.debug("Custom quick allreduce is disabled based " + "on env variable VLLM_ROCM_QR_QUANT_REGIME") + return + vllm_config = get_current_vllm_config() - # for test mode - if vllm_config is not None and hasattr(vllm_config, "model_config"): + if vllm_config is not None and \ + hasattr(vllm_config, "model_config") and \ + hasattr(vllm_config.model_config, "dtype"): dtype = vllm_config.model_config.dtype if dtype not in [torch.float16, torch.bfloat16]: self._QR_SHOULD_INIT = False - # On RocM bfloat16 kernels are slower than fp16 - # due to slower match operations - # If environment is not set to 1 we convert input to fp16 - self.use_fp16_kernels: bool = envs.VLLM_ROCM_QR_CAST_BF16_TO_FP16 - regime_str = envs.VLLM_ROCM_QR_QUANT_REGIME - if self._QR_SHOULD_INIT: - if regime_str not in QuickReduceRegime.__members__: - logger.warning( - "Custom quick allreduce:", - f"Invalid quantization level: {regime_str}. 
" - "Supported levels: " - f"{list(QuickReduceRegime.__members__.keys())}") - return - - if regime_str == "NONE": - logger.debug("Custom quick allreduce is disabled based " - "on env variable VLLM_ROCM_QR_QUANT_REGIME") - return - - self.qr_quant_level = QuickReduceRegime[regime_str] - self._qr_ptr = ops.init_custom_qr(self.rank, self.world_size) - self.create_qr_shared_buffer() + # On RocM bfloat16 kernels are slower than fp16 + # due to slower match operations + # If environment variable is not set to 1 we convert input to fp16 if dtype == torch.bfloat16 and not self.use_fp16_kernels: logger.info( "Custom quick allreduce: converting bf16 inputs to " "fp16 can improve performance" "set envs.VLLM_ROCM_QR_CAST_BF16_TO_FP16=1 to turn on.") - self.qr_disabled = False + + self.qr_quant_level = QuickReduceRegime[regime_str] + self._qr_ptr = ops.init_custom_qr(self.rank, self.world_size) + self.create_qr_shared_buffer() + self.qr_disabled = False @contextmanager def capture(self): @@ -346,7 +349,7 @@ def should_quick_allreduce(self, inp: torch.Tensor): if self.use_fp16_kernels: dtype = torch.float16 return inp_size <= self.qr_max_size and \ - inp_size > self._QR_MIN_SIZE[(dtype, self.world_size)]\ + inp_size >= self._QR_MIN_SIZE[(dtype, self.world_size)]\ [self.qr_quant_level.value] def should_custom_allreduce(self, inp: torch.Tensor): @@ -369,7 +372,7 @@ def should_custom_ar(self, inp: torch.Tensor): return self.should_quick_allreduce( inp) or self.should_custom_allreduce(inp) - def cr_all_reduce(self, + def ca_all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None, @@ -411,7 +414,7 @@ def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: if self.should_custom_allreduce(input): if self._IS_CAPTURING: if torch.cuda.is_current_stream_capturing(): - return self.cr_all_reduce(input, registered=True) + return self.ca_all_reduce(input, registered=True) else: # If warm up, mimic the allocation pattern since custom # allreduce is out-of-place. 
@@ -421,7 +424,7 @@ def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: # incurs a cost of cudaMemcpy, which should be small # (<=1% of overall latency) compared to the performance # gain of using custom kernels - return self.cr_all_reduce(input, registered=False) + return self.ca_all_reduce(input, registered=False) return None diff --git a/vllm/envs.py b/vllm/envs.py index 816322a9bf34..d9515e5def09 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -129,7 +129,7 @@ VLLM_SLEEP_WHEN_IDLE: bool = False VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16 VLLM_KV_CACHE_LAYOUT: Optional[str] = None - VLLM_ROCM_QR_QUANT_REGIME: str = "FP" + VLLM_ROCM_QR_QUANT_REGIME: str = "NONE" VLLM_ROCM_QR_CAST_BF16_TO_FP16: bool = False @@ -677,7 +677,7 @@ def get_vllm_port() -> Optional[int]: # Choice of quantization level: FP, INT8, INT6, INT4 or NONE # Recommended for large models to get allreduce "VLLM_ROCM_QR_QUANT_REGIME": - lambda: os.getenv("VLLM_ROCM_QR_QUANT_REGIME", "FP").upper(), + lambda: os.getenv("VLLM_ROCM_QR_QUANT_REGIME", "NONE").upper(), # Custom quick allreduce kernel for MI3* cards # Due to the lack of the bfloat16 asm instruction, bfloat16 From 75d3df16c3f4cf5398e04bce9d559ea2f16b7a76 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Tue, 17 Jun 2025 14:54:18 +0000 Subject: [PATCH 27/28] Fix CodecQ4 Signed-off-by: ilmarkov --- csrc/quickreduce/quick_reduce_impl.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/quickreduce/quick_reduce_impl.cuh b/csrc/quickreduce/quick_reduce_impl.cuh index f4d8eebefd79..f2a13842c4c0 100644 --- a/csrc/quickreduce/quick_reduce_impl.cuh +++ b/csrc/quickreduce/quick_reduce_impl.cuh @@ -182,7 +182,7 @@ struct CodecQ4 : public CodecBase { for (int i = 0; i < 4; i++) { if constexpr (std::is_same::value) { int32_t q4 = ((qw >> (i * 4)) & kMask000F) | kHalf2_1024; - packed_add(w[i], kHalf2_1032); + w[i] = packed_add(w[i], kHalf2_1032); } else { int32_t int16_2 = (qw >> (i * 4)) & kMask000F; int16_t low = 
static_cast(int16_2 & 0xFFFF); From f314fe4d76e2274012bac5cd53cf3b57b5a6e985 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Tue, 17 Jun 2025 16:08:26 +0000 Subject: [PATCH 28/28] Update min sizes Signed-off-by: ilmarkov --- vllm/distributed/device_communicators/custom_all_reduce.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index da9bcc525e90..1b8907bcd301 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -72,10 +72,10 @@ class CustomAllreduce: MB = 1024 * 1024 _QR_MIN_SIZE = { (torch.float16, 2): [1 * MB, 2 * MB, 2 * MB, 1 * MB], - (torch.float16, 4): [1 * MB, 64 * MB, 4 * MB, 2 * MB], + (torch.float16, 4): [1 * MB, 16 * MB, 4 * MB, 2 * MB], (torch.float16, 8): [16 * MB, 4 * MB, 4 * MB, 2 * MB], (torch.bfloat16, 2): [2 * MB, 8 * MB, 8 * MB, 8 * MB], - (torch.bfloat16, 4): [8 * MB, 128 * MB, 128 * MB, 16 * MB], + (torch.bfloat16, 4): [8 * MB, 64 * MB, 64 * MB, 16 * MB], (torch.bfloat16, 8): [16 * MB, 2048 * MB, 2048 * MB, 2048 * MB], }