diff --git a/CMakeLists.txt b/CMakeLists.txt index d75f0d321247..64092f004f78 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -638,6 +638,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # if CUDA endif endif() +if (VLLM_GPU_LANG STREQUAL "HIP") + # Add QuickReduce kernels + list(APPEND VLLM_EXT_SRC + "csrc/custom_quickreduce.cu" + ) +# if ROCM endif +endif() + message(STATUS "Enabling C extension.") define_gpu_extension_target( _C diff --git a/csrc/custom_quickreduce.cu b/csrc/custom_quickreduce.cu new file mode 100644 index 000000000000..91b8abf1a162 --- /dev/null +++ b/csrc/custom_quickreduce.cu @@ -0,0 +1,86 @@ +#include +#include +#include +#include + +#ifdef USE_ROCM + + #include "quickreduce/quick_reduce.h" + +quickreduce::fptr_t init_custom_qr(int64_t rank, int64_t world_size) { + if (world_size > 8) + throw std::invalid_argument("world size > 8 is not supported"); + if (world_size == 6) + throw std::invalid_argument("world size == 6 is not supported"); + if (world_size % 2 != 0) + throw std::invalid_argument("Odd num gpus is not supported for now"); + if (rank < 0 || rank >= world_size) + throw std::invalid_argument("invalid rank passed in"); + quickreduce::DeviceComms* fptr = new quickreduce::DeviceComms(); + fptr->init(world_size, rank); + return (quickreduce::fptr_t)fptr; +} + +void qr_destroy(quickreduce::fptr_t _fa) { + if (_fa) { + auto fa = reinterpret_cast(_fa); + fa->destroy(); + delete fa; + } +} + +torch::Tensor qr_get_handle(quickreduce::fptr_t _fa) { + auto fa = reinterpret_cast(_fa); + hipIpcMemHandle_t handle = fa->get_handle(); + auto device_index = c10::cuda::current_device(); + auto options = + torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); + auto data_handle = + torch::empty({static_cast(sizeof(hipIpcMemHandle_t))}, options); + std::memcpy(data_handle.data_ptr(), &handle, sizeof(hipIpcMemHandle_t)); + return data_handle; +} + +void qr_open_handles(quickreduce::fptr_t _fa, + const std::vector& handles) { + auto fa = reinterpret_cast(_fa); + std::vector ipc_handles; + ipc_handles.reserve(handles.size()); + for (auto& handle : handles) { + // Ensure the tensor is on the same device as the current device. + hipIpcMemHandle_t ipc_handle; + std::memcpy(&ipc_handle, handle.data_ptr(), sizeof(hipIpcMemHandle_t)); + ipc_handles.push_back(ipc_handle); + } + fa->open_ipc_handles(ipc_handles); +} + +void qr_all_reduce(quickreduce::fptr_t _fa, torch::Tensor& inp, + torch::Tensor& out, int64_t quant_level) { + auto fa = reinterpret_cast(_fa); + const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); + auto stream = at::cuda::getCurrentHIPStreamMasqueradingAsCUDA(); + + TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); + TORCH_CHECK_EQ(inp.numel(), out.numel()); + TORCH_CHECK_LE(out.numel(), quickreduce::DeviceComms::kMaxProblemSize); + if (out.scalar_type() == at::ScalarType::Half) { + fa->allreduce(reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), out.numel(), + quant_level, stream); + } else if (out.scalar_type() == at::ScalarType::BFloat16) { + fa->allreduce( + reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + out.numel(), quant_level, stream); + } else { + throw std::runtime_error( + "quick allreduce only supports float16 and bfloat16"); + } +} + +int64_t qr_max_size() { + return static_cast(quickreduce::DeviceComms::kMaxProblemSize); +} + +#endif // USE_ROCM \ No newline at end of file diff --git a/csrc/ops.h b/csrc/ops.h index f02f5083ac19..85f86f9722d1 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -360,3 +360,13 @@ std::tuple allocate_shared_buffer_and_handle( int64_t size); int64_t open_mem_handle(torch::Tensor& mem_handle); void free_shared_buffer(int64_t buffer); + +#ifdef USE_ROCM +fptr_t init_custom_qr(int64_t rank, int64_t world_size); +void qr_destroy(fptr_t _fa); +torch::Tensor qr_get_handle(fptr_t _fa); +void qr_open_handles(fptr_t _fa, const std::vector& handles); +void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, + int64_t quant_level); +int64_t qr_max_size(); +#endif \ No newline at end of file diff --git a/csrc/quickreduce/base.h b/csrc/quickreduce/base.h new file mode 100644 index 000000000000..c670ff0fdec0 --- /dev/null +++ b/csrc/quickreduce/base.h @@ -0,0 +1,344 @@ +#pragma once + +#include +#include +#include +#include + +#define __quickreduce_device_inline__ __device__ __forceinline__ +#define __quickreduce_launch_bounds_two_shot__ __launch_bounds__(256, 4) +#define __quickreduce_launch_bounds_one_shot__ __launch_bounds__(512, 4) + +namespace quickreduce { + +typedef __hip_bfloat16 nv_bfloat16; +typedef __hip_bfloat162 nv_bfloat162; + +using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; +using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; + +// Setup acquire-release semantics for vector memory reads (mubuf instruction) +// as per architecture. +#if defined(__gfx942__) +// CDNA3: Scope bits sc0, sc1 + #define MUBUF_ACQUIRE 16 + #define MUBUF_RELEASE 16 +#elif (defined(__gfx908__) || defined(__gfx90a__)) +// CDNA1 and CDNA2 - glc bit + #define MUBUF_ACQUIRE 1 + #define MUBUF_RELEASE 0 +#endif + +static constexpr int kNegOne = 0xBC00BC00; // {-1, -1}, fp16x2_t + +// Number of atoms (4xf16x2_t) processed by a single thread +static constexpr int kAtoms = 8; + +// We use a workgroup of 256 threads +static constexpr int kBlockSize = 256; +static constexpr int kAtomStride = kBlockSize; + +// Size and atom stride of source/destination data that the block will +// process. +// Workgroup scope = Tile = (256 threads x 8 atoms x 16B) +static constexpr int kTileSize = kBlockSize * kAtoms * sizeof(int32x4_t); + +// Max number of blocks. 304 CUs on MI300 +static constexpr int kMaxNumBlocks = 304 * 4; + +// Standard CDNA wavefront size. +static constexpr int kWavefront = 64; + +// 256 thread, 4 wavefronts. +static dim3 constexpr kBlockTwoShot = {kWavefront, kBlockSize / kWavefront, 1}; + +// Number of threads in a group for quantization +// It corresponds to 32 F16 elements in quantization block +static constexpr int kThreadGroupSize = 8; + +// Methods +__quickreduce_device_inline__ __host__ unsigned long divceil(unsigned long x, + unsigned long y) { + return ((x + y - 1) / y); +} + +union BufferResource { + __quickreduce_device_inline__ constexpr BufferResource() + : config(0x00020000U) {} + + __quickreduce_device_inline__ constexpr BufferResource(void* buffer_address, + uint32_t buffer_size) + : address(buffer_address), range(buffer_size), config(0x00020000U) {} + + int32x4_t descriptor; + struct { + void* address; // 8B, out of which first 48b is address, and 16b is stride + // (unused) + uint32_t range; // Byte range for the buffer resource + uint32_t config; // Constant, DFMT=32b + }; +}; + +__quickreduce_device_inline__ static int32x4_t buffer_load_dwordx4( + int32x4_t srsrc, int32_t voffset, int32_t soffset, + int32_t aux) __asm("llvm.amdgcn.raw.buffer.load.v4i32"); + +__quickreduce_device_inline__ static void buffer_store_dwordx4( + int32x4_t data, int32x4_t srsrc, int32_t voffset, int32_t soffset, + int32_t aux) __asm("llvm.amdgcn.raw.buffer.store.v4i32"); + +__quickreduce_device_inline__ static void set_fp16_ovfl(bool const value) { + // short size = 0b00001; // Specifies the bit size to modify + // const short offset = 0b10111; // Corrected offset to 23, which is the bit + // position of FP16_OVFL const short hwRegId = 0b000001; // HW register ID for + // MODE const short simm16 = (size << 11) | (offset << 6) | hwRegId; simm16 = + // 0xdc1 + +#if defined(__gfx942__) + if (value) { + asm volatile("s_setreg_imm32_b32 0xdc1, 1;" ::); + } else { + asm volatile("s_setreg_imm32_b32 0xdc1, 0;" ::); + } +#endif +} +union bf162_int_union { + int i; + nv_bfloat162 bf2; +}; + +template +__quickreduce_device_inline__ void packed_assign_add(int32x4_t* A, + int32x4_t* B); + +template <> +__quickreduce_device_inline__ void packed_assign_add(int32x4_t* A, + int32x4_t* B) { + int32x4_t& tR_fragment = A[0]; + int32x4_t& tA_fragment = B[0]; + + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(tR_fragment[0]) + : "v"(tR_fragment[0]), "v"(tA_fragment[0])); + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(tR_fragment[1]) + : "v"(tR_fragment[1]), "v"(tA_fragment[1])); + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(tR_fragment[2]) + : "v"(tR_fragment[2]), "v"(tA_fragment[2])); + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(tR_fragment[3]) + : "v"(tR_fragment[3]), "v"(tA_fragment[3])); +} + +template <> +__quickreduce_device_inline__ void packed_assign_add( + int32x4_t* A, int32x4_t* B) { + nv_bfloat162* tA = reinterpret_cast(A); + nv_bfloat162* tB = reinterpret_cast(B); +#pragma unroll + for (int i = 0; i < 4; i++) { + tA[i] = __hadd2(tA[i], tB[i]); + } +} + +template +__quickreduce_device_inline__ int packed_max(int a, int b); + +template <> +__quickreduce_device_inline__ int packed_max(int a, int b) { + int result; + asm volatile("v_pk_max_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b)); + return result; +} + +template <> +__quickreduce_device_inline__ int packed_max(int a, int b) { + bf162_int_union A, B, R; + A.i = a; + B.i = b; + R.bf2 = __hmax2(A.bf2, B.bf2); + return R.i; +} + +template +__quickreduce_device_inline__ int packed_min(int a, int b); + +template <> +__quickreduce_device_inline__ int packed_min(int a, int b) { + int result; + asm volatile("v_pk_min_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b)); + return result; +} + +template <> +__quickreduce_device_inline__ int packed_min(int a, int b) { + bf162_int_union A, B, R; + A.i = a; + B.i = b; + R.bf2 = __hmin2(A.bf2, B.bf2); + return R.i; +} + +template +__quickreduce_device_inline__ int packed_abs_max(int a, int b); + +template <> +__quickreduce_device_inline__ int packed_abs_max(int a, int b) { + half2 wmaxh2 = __builtin_bit_cast(half2, a); + half2 wminh2 = __builtin_bit_cast(half2, b); + half2 wblockmaxh2; + + wblockmaxh2.x = + __hgt(__habs(wmaxh2.x), __habs(wminh2.x)) ? wmaxh2.x : wminh2.x; + wblockmaxh2.y = + __hgt(__habs(wmaxh2.y), __habs(wminh2.y)) ? wmaxh2.y : wminh2.y; + return __builtin_bit_cast(int, wblockmaxh2); +} + +template <> +__quickreduce_device_inline__ int packed_abs_max(int a, int b) { + bf162_int_union A, B, R; + A.i = a; + B.i = b; + R.bf2.x = __hgt(__habs(A.bf2.x), __habs(B.bf2.x)) ? A.bf2.x : B.bf2.x; + R.bf2.y = __hgt(__habs(A.bf2.y), __habs(B.bf2.y)) ? A.bf2.y : B.bf2.y; + return R.i; +} + +template +__quickreduce_device_inline__ int packed_add(int a, int b); + +template <> +__quickreduce_device_inline__ int packed_add(int a, int b) { + int result; + asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b)); + return result; +} + +template <> +__quickreduce_device_inline__ int packed_add(int a, int b) { + bf162_int_union A, B, R; + A.i = a; + B.i = b; + R.bf2 = __hadd2(A.bf2, B.bf2); + return R.i; +} + +template <> +__quickreduce_device_inline__ int packed_add(int a, int b) { + int result; + asm volatile("v_pk_add_i16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b)); + return result; +} + +template +__quickreduce_device_inline__ int packed_sub(int a, int b); + +template <> +__quickreduce_device_inline__ int packed_sub(int a, int b) { + int result; + + // MI300 lacks packed fp16 sub instruction. So we do -1 * min + max + asm volatile("v_pk_fma_f16 %0, %1, %2 %3" + : "=v"(result) + : "v"(kNegOne), "v"(b), "v"(a)); + return result; +} + +template <> +__quickreduce_device_inline__ int packed_sub(int a, int b) { + bf162_int_union A, B, R; + A.i = a; + B.i = b; + R.bf2 = __hsub2(A.bf2, B.bf2); + return R.i; +} + +template +__quickreduce_device_inline__ int packed_mul(int a, int b); + +template <> +__quickreduce_device_inline__ int packed_mul(int a, int b) { + int result; + asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b)); + return result; +} + +template <> +__quickreduce_device_inline__ int packed_mul(int a, int b) { + nv_bfloat162* tA = reinterpret_cast(&a); + nv_bfloat162* tB = reinterpret_cast(&b); + nv_bfloat162 tR = __hmul2(*tA, *tB); + return *(reinterpret_cast(&tR)); +} + +template +__quickreduce_device_inline__ int packed_rcp(int a); + +template <> +__quickreduce_device_inline__ int packed_rcp(int a) { + return __builtin_bit_cast(int, h2rcp(__builtin_bit_cast(half2, a))); +} + +template <> +__quickreduce_device_inline__ int packed_rcp(int a) { + bf162_int_union A, R; + A.i = a; + R.bf2 = h2rcp(A.bf2); + return R.i; +} + +// changes dtype +__quickreduce_device_inline__ float T2float_cast(half a) { + return __half2float(a); +} + +__quickreduce_device_inline__ float T2float_cast(nv_bfloat16 a) { + return __bfloat162float(a); +} + +template +__quickreduce_device_inline__ int group_abs_max(int32x4_t atom) { + const int group_leader = (threadIdx.x / kThreadGroupSize) * kThreadGroupSize; + + int wmax, wmin, wblockmax; + int a, b; + a = packed_max(atom[0], atom[1]); + b = packed_max(atom[2], atom[3]); + + wmax = packed_max(a, b); + + a = packed_min(atom[0], atom[1]); + b = packed_min(atom[2], atom[3]); + + wmin = packed_min(a, b); + + // Reduce the max among a group of threads + // Note: This is basically 2 blocks of values setup as the + // upper/lower halves of the f16x2_t + for (int i = 1; i < kThreadGroupSize; i <<= 1) { + int x = __shfl_down(wmax, i); + wmax = packed_max(wmax, x); + + int y = __shfl_down(wmin, i); + wmin = packed_min(wmin, y); + } + wblockmax = packed_abs_max(wmax, wmin); + // Share with the cohort + wblockmax = __shfl(wblockmax, group_leader); + return wblockmax; +} + +__quickreduce_device_inline__ void set_sync_flag(uint32_t* flag_ptr, + uint32_t flag) { + __atomic_store_n(flag_ptr, flag, __ATOMIC_RELEASE); +} + +__quickreduce_device_inline__ void wait_sync_flag(uint32_t* flag_ptr, + uint32_t flag) { + while (__atomic_load_n(flag_ptr, __ATOMIC_RELAXED) != flag) { + } +} + +} // namespace quickreduce \ No newline at end of file diff --git a/csrc/quickreduce/quick_reduce.h b/csrc/quickreduce/quick_reduce.h new file mode 100644 index 000000000000..4f6e1c13978f --- /dev/null +++ b/csrc/quickreduce/quick_reduce.h @@ -0,0 +1,193 @@ +#pragma once + +#include +#include +#include "quick_reduce_impl.cuh" + +#define HIP_CHECK(err) \ + do { \ + hipError_t err_ = (err); \ + if (err_ != hipSuccess) { \ + std::printf("HIP error %d at %s:%d. %s\n", err_, __FILE__, __LINE__, \ + hipGetErrorString(err_)); \ + throw std::runtime_error("HIP error"); \ + } \ + } while (0) + +namespace quickreduce { +using fptr_t = int64_t; +static_assert(sizeof(void*) == sizeof(fptr_t)); + +template +__global__ __quickreduce_launch_bounds_two_shot__ static void +allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, uint32_t num_blocks, + int rank, uint8_t** dbuffer_list, + uint32_t data_offset, uint32_t flag_color) { + int block = blockIdx.x; + int grid = gridDim.x; + + while (block < num_blocks) { + AllReduceKernel::run(A, B, N, block, rank, dbuffer_list, data_offset, + flag_color); + block += grid; + flag_color++; + } +} + +#define TWOSHOT_DISPATCH(__codec) \ + if (world_size == 2) { \ + using LineCodec = __codec; \ + using AllReduceKernel = AllReduceTwoshot; \ + hipLaunchKernelGGL((allreduce_prototype_twoshot), \ + dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \ + num_blocks, rank, dbuffer_list, data_offset, \ + flag_color); \ + } else if (world_size == 4) { \ + using LineCodec = __codec; \ + using AllReduceKernel = AllReduceTwoshot; \ + hipLaunchKernelGGL((allreduce_prototype_twoshot), \ + dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \ + num_blocks, rank, dbuffer_list, data_offset, \ + flag_color); \ + } else if (world_size == 8) { \ + using LineCodec = __codec; \ + using AllReduceKernel = AllReduceTwoshot; \ + hipLaunchKernelGGL((allreduce_prototype_twoshot), \ + dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \ + num_blocks, rank, dbuffer_list, data_offset, \ + flag_color); \ + } + +enum QuickReduceQuantLevel { + F16 = 0, + INT8 = 1, + INT6 = 2, + INT4 = 3, +}; + +struct DeviceComms { + // Max problem size is 2GB (in bytes) or half of uint32_t max value. + static int64_t constexpr kMaxProblemSize = 2147483648; + static int64_t constexpr kMaxTiles = kMaxProblemSize / kTileSize; + + // Max TP-8 + static int constexpr kMaxWorldSize = 8; + + bool initialized = false; + uint32_t flag_color = 1; + int world_size; + int rank; + + uint8_t* dbuffer; + uint8_t** dbuffer_list; + hipIpcMemHandle_t buffer_ipc_handle; + std::vector all_buffer_ipc_handles; + std::vector buffer_list; + uint32_t data_offset; + + DeviceComms() : initialized(false), world_size(1), rank(0) {} + ~DeviceComms() { destroy(); } + + void init(int world_size, int rank) { + destroy(); + this->world_size = world_size; + this->rank = rank; + + // Allocate buffer size for worst case: F16 2-stage buffer. + uint32_t flags_buffer_size = + 2 * world_size * kMaxNumBlocks * sizeof(uint32_t); + static constexpr int64_t data_buffer_size = 2 * kMaxProblemSize; + int64_t total_buffer_size = flags_buffer_size + data_buffer_size; + data_offset = flags_buffer_size; + HIP_CHECK(hipExtMallocWithFlags((void**)&dbuffer, total_buffer_size, + hipDeviceMallocUncached)); + + // Clear the flags buffer. + HIP_CHECK(hipMemset(dbuffer, 0, flags_buffer_size)); + + // Device-side list of IPC buffers. + buffer_list.resize(world_size); + HIP_CHECK(hipMalloc(&dbuffer_list, world_size * sizeof(uint8_t*))); + + // Create IPC handles for rank's communication buffer. + all_buffer_ipc_handles.resize(world_size); + HIP_CHECK(hipIpcGetMemHandle(&buffer_ipc_handle, dbuffer)); + + initialized = true; + } + int get_world_size() { return world_size; } + int get_rank() { return rank; } + bool status() { return initialized; } + hipIpcMemHandle_t const get_handle() { return buffer_ipc_handle; } + + void destroy() { + if (initialized) { + for (int i = 0; i < world_size; i++) { + if (i != rank) { + HIP_CHECK(hipIpcCloseMemHandle(dbuffer_list[i])); + } + } + + HIP_CHECK(hipFree(dbuffer)); + HIP_CHECK(hipFree(dbuffer_list)); + + initialized = false; + } + } + + void open_ipc_handles(std::vector const& ipc_handles) { + assert(ipc_handles.size() == all_buffer_ipc_handles.size()); + for (int i = 0; i < world_size; i++) { + all_buffer_ipc_handles[i] = ipc_handles[i]; + } + + // Open device memory access to the IPC communication buffers. + // Note: For our own rank, we do not need to open a handle. + for (int i = 0; i < world_size; i++) { + if (i != rank) { + HIP_CHECK(hipIpcOpenMemHandle((void**)&buffer_list[i], + all_buffer_ipc_handles[i], + hipIpcMemLazyEnablePeerAccess)); + } else { + buffer_list[i] = dbuffer; + } + } + + HIP_CHECK(hipMemcpy(dbuffer_list, buffer_list.data(), + world_size * sizeof(uint8_t*), hipMemcpyHostToDevice)); + } + + template + void allreduce(T const* A, T* B, uint32_t N, int quant_level, + hipStream_t stream) { + if (world_size != 2 && world_size != 4 && world_size != 8) { + throw std::runtime_error("All Reduce not supported for world_size = " + + std::to_string(world_size)); + } + + // Configuration. + uint32_t msg_size = N * sizeof(T); + uint32_t num_blocks = divceil(msg_size, kTileSize); + uint32_t grid = min(kMaxNumBlocks, num_blocks); + auto quant_level_ = static_cast(quant_level); + switch (quant_level_) { + case QuickReduceQuantLevel::INT8: + TWOSHOT_DISPATCH(CodecQ8) + break; + case QuickReduceQuantLevel::INT6: + TWOSHOT_DISPATCH(CodecQ6) + break; + case QuickReduceQuantLevel::INT4: + TWOSHOT_DISPATCH(CodecQ4) + break; + default: + TWOSHOT_DISPATCH(CodecFP) + break; + } + HIP_CHECK(cudaGetLastError()); + // Rotate the flag color. + flag_color += divceil(N, grid); + } +}; + +} // namespace quickreduce \ No newline at end of file diff --git a/csrc/quickreduce/quick_reduce_impl.cuh b/csrc/quickreduce/quick_reduce_impl.cuh new file mode 100644 index 000000000000..f2a13842c4c0 --- /dev/null +++ b/csrc/quickreduce/quick_reduce_impl.cuh @@ -0,0 +1,678 @@ +#pragma once + +#include +#include "base.h" + +namespace quickreduce { + +struct CodecBase { + const int thread; + const int rank; + const int group_leader; + __quickreduce_device_inline__ CodecBase(int thread, int rank) + : thread(thread), + rank(rank), + group_leader((threadIdx.x / kThreadGroupSize) * kThreadGroupSize) { + set_fp16_ovfl(true); + } +}; + +// Default full precision codec. +template +struct CodecFP : public CodecBase { + static constexpr int kWorldSize = world_size; + static constexpr int kRankAtoms = kAtoms / kWorldSize; + + // Codec tile size process by this workgroup. + // Each thread processes atoms of f16x8_t (16B). + static constexpr int kRankTransmittedTileSize = + kBlockSize * kRankAtoms * sizeof(int32x4_t); + static_assert(kRankTransmittedTileSize % 16 == 0, + "kRankTransmittedTileSize must be 16B aligned."); + + // Total tile size for the collective communication. + static constexpr int kTransmittedTileSize = + kRankTransmittedTileSize * kWorldSize; + + __quickreduce_device_inline__ CodecFP(int thread, int rank) + : CodecBase(thread, rank) {} + + __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, + const int32x4_t* __restrict__ data) { + for (int i = 0; i < kRankAtoms; i++) { + __builtin_nontemporal_store(data[i], send_buffer + thread); + send_buffer += kAtomStride; + } + } + + __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, + int32x4_t* __restrict__ data) { + for (int i = 0; i < kRankAtoms; i++) { + data[i] = __builtin_nontemporal_load(*recv_buffer + thread); + *recv_buffer += kAtomStride; + } + } +}; + +// Int4 symmetric quantization codec. +// We quantize the FP16 data to block-scaled Int4 in blocks of 4 * +// kThreadGroupSize. +template +struct CodecQ4 : public CodecBase { + static constexpr int kWorldSize = world_size; + + // Codec tile size process by this workgroup. + // Each threads processes a fragment of fp16x8_t (16B), + // into a int4x8_t (4B) and a fp16 scale shared among 32 values. + static constexpr int kRankAtoms = kAtoms / kWorldSize; + static constexpr int kRankTileStride = 1152; + static constexpr int kRankTileScaleOffset = 1024; + static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms; + static_assert(kRankTransmittedTileSize % 16 == 0, + "kRankTransmittedTileSize must be 16B aligned."); + + static constexpr int kRankBufferTileStride = + kRankTileStride / sizeof(int32x4_t); + + // Total tile size for the collective communication. + static constexpr int kTransmittedTileSize = + kRankTransmittedTileSize * kWorldSize; + + // Constants configuration + + // {-1/8.0h, -1/8.0h}, f16x2_t + static constexpr int kScaleFactor = + std::is_same::value ? 0xB000B000 : 0xBE00BE00; + + // {1e-7, 1e-7}, f16x2_t + static constexpr int kScaleEpsilon = + std::is_same::value ? 0x00010001 : 0x33D733D7; + + // {-8, -8}, f16x2_t + static constexpr int kRangeMin = + std::is_same::value ? 0xC800C800 : 0xC100C100; + + // {+7, +7}, f16x2_t + static constexpr int kRangeMax = + std::is_same::value ? 0x47004700 : 0x40E040E0; + + // {+8, +8}, int16x2_t + static constexpr int kRangeBias = 0x00080008; + + __quickreduce_device_inline__ CodecQ4(int thread, int rank) + : CodecBase(thread, rank) {} + + __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, + const int32x4_t* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + int32x4_t const atom = data[k]; + + // Compute the absolute maximum of the atom in the thread group + // In 2 blocks of values, upper/lower halves of the f16x2_t + int wblockmax = group_abs_max(atom); + + // Derive scales + int decoding_scale; + int encoding_scale; + decoding_scale = packed_mul(wblockmax, kScaleFactor); + encoding_scale = packed_add(decoding_scale, kScaleEpsilon); + encoding_scale = packed_rcp(encoding_scale); + + // Apply scales to get quantized values + int32x4_t w; + for (int i = 0; i < 4; i++) { + w[i] = packed_mul(atom[i], encoding_scale); + w[i] = packed_max(w[i], kRangeMin); + w[i] = packed_min(w[i], kRangeMax); + } + + // Convert from f16x2_t to uint16x2_t + int32x4_t q; + { + int16_t* qi = reinterpret_cast(&q); + T* wh = reinterpret_cast(&w); + for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(T2float_cast(wh[i])); + + for (int i = 0; i < 4; i++) { + q[i] = packed_add(q[i], kRangeBias); + } + } + + // Pack 8 x q4 into int32_t + int qw = q[0] | (q[1] << 4) | (q[2] << 8) | (q[3] << 12); + + // Write quantized atom to send_buffer + // note: only the group leader stores the scale + uint8_t* atom_ptr = + reinterpret_cast(send_buffer + k * kRankBufferTileStride); + int32_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + + __builtin_nontemporal_store(qw, qw_ptr); + if (threadIdx.x == group_leader) { + __builtin_nontemporal_store(decoding_scale, qs_ptr); + } + } + } + + __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, + int32x4_t* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + // Directly read quantized atom from recv_buffer + uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); + int32_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + + int32_t qw = __builtin_nontemporal_load(qw_ptr); + int qs = __builtin_nontemporal_load(qs_ptr); + + *recv_buffer += kRankBufferTileStride; + + // Unpack q4 into f16x8_t + int32x4_t w; + { + static constexpr uint kMask000F = 0x000F000F; + static constexpr uint kHalf2_1024 = + 0x64006400; // {1024.0, 1024.0}, fp16x2_t + static uint constexpr kHalf2_1032 = + 0xE408E408; // {-1032.0, -1032.0}, fp16x2_t + + for (int i = 0; i < 4; i++) { + if constexpr (std::is_same::value) { + int32_t q4 = ((qw >> (i * 4)) & kMask000F) | kHalf2_1024; + w[i] = packed_add(w[i], kHalf2_1032); + } else { + int32_t int16_2 = (qw >> (i * 4)) & kMask000F; + int16_t low = static_cast(int16_2 & 0xFFFF); + int16_t high = static_cast((int16_2 >> 16) & 0xFFFF); + nv_bfloat16 bf_low = __float2bfloat16(static_cast(low)); + nv_bfloat16 bf_high = __float2bfloat16(static_cast(high)); + nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high); + int32_t packed_bf16 = *reinterpret_cast(&bf2); + w[i] = packed_add(packed_bf16, kRangeMin); + } + } + } + + // Apply decoding scales + for (int i = 0; i < 4; i++) { + w[i] = packed_mul(w[i], qs); + } + + data[k] = w; + } + } +}; + +// Int6 symmetric quantization codec. +// We quantize the FP16 data to block-scaled Int6 in blocks of 4 * +// kThreadGroupSize. +template +struct CodecQ6 : public CodecBase { + static constexpr int kWorldSize = world_size; + + // Codec tile size process by this workgroup. + // Each threads processes a fragment of fp16x8_t (16B), + // into a int6x8_t (4B + 2B) and a fp16 scale shared among 32 values. + static constexpr int kRankAtoms = kAtoms / kWorldSize; + static constexpr int kRankTileStride = 1664; + static constexpr int kRankTileQ2Offset = 1024; + static constexpr int kRankTileScaleOffset = 1536; + static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms; + static_assert(kRankTransmittedTileSize % 16 == 0, + "kRankTransmittedTileSize must be 16B aligned."); + + static constexpr int kRankBufferTileStride = + kRankTileStride / sizeof(int32x4_t); + + // Total tile size for the collective communication. + static constexpr int kTransmittedTileSize = + kRankTransmittedTileSize * kWorldSize; + + // Constants configuration + + // {-1/32.0h, -1/32.0h}, fp16x2_t + static constexpr int kScaleFactor = + std::is_same::value ? 0xA800A800 : 0xBD00BD00; + + // {1e-7, 1e-7}, fp16x2_t + static constexpr int kScaleEpsilon = + std::is_same::value ? 0x00010001 : 0x33D733D7; + + // {-32, -32}, fp16x2_t + static constexpr int kRangeMin = + std::is_same::value ? 0xD000D000 : 0xC200C200; + + // {+31, +31}, fp16x2_t + static constexpr int kRangeMax = + std::is_same::value ? 0x4FC04FC0 : 0x41F841F8; + + // {+32, +32}, int16x2_t + static constexpr int kRangeBias = 0x00200020; + + __quickreduce_device_inline__ CodecQ6(int thread, int rank) + : CodecBase(thread, rank) {} + + __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, + const int32x4_t* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + int32x4_t const atom = data[k]; + + // Compute the absolute maximum of the atom in the thread group + // In 2 blocks of values, upper/lower halves of the f16x2_t + int wblockmax = group_abs_max(atom); + + // Derive scales + int decoding_scale; + int encoding_scale; + decoding_scale = packed_mul(wblockmax, kScaleFactor); + encoding_scale = packed_add(decoding_scale, kScaleEpsilon); + encoding_scale = packed_rcp(encoding_scale); + + // Apply scales to get quantized values + int32x4_t w; + for (int i = 0; i < 4; i++) { + w[i] = packed_mul(atom[i], encoding_scale); + w[i] = packed_max(w[i], kRangeMin); + w[i] = packed_min(w[i], kRangeMax); + } + + // Convert from f16x2_t to uint16x2_t + int32x4_t q; + { + int16_t* qi = reinterpret_cast(&q); + T* wh = reinterpret_cast(&w); + for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(T2float_cast(wh[i])); + + for (int i = 0; i < 4; i++) { + q[i] = packed_add(q[i], kRangeBias); + } + } + + // Pack 8 x q6 into int32_t + int16_t + uint32_t q4w; + uint16_t q2w = 0; + q4w = (q[0] & 0x000F000F) | ((q[1] & 0x000F000F) << 4) | + ((q[2] & 0x000F000F) << 8) | ((q[3] & 0x000F000F) << 12); + { + int16_t* tw = reinterpret_cast(&q); +#pragma unroll + for (int i = 0; i < 8; i++) { + q2w |= (tw[i] >> 4) << (i * 2); + } + } + // Write quantized atom to send_buffer + // note: only the group leader stores the scale + uint8_t* atom_ptr = + reinterpret_cast(send_buffer + k * kRankBufferTileStride); + uint32_t* q4w_ptr = reinterpret_cast(atom_ptr) + thread; + uint16_t* q2w_ptr = + reinterpret_cast(atom_ptr + kRankTileQ2Offset) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + + __builtin_nontemporal_store(q4w, q4w_ptr); + __builtin_nontemporal_store(q2w, q2w_ptr); + if (threadIdx.x == group_leader) { + __builtin_nontemporal_store(decoding_scale, qs_ptr); + } + } + } + + __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, + int32x4_t* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + // Directly read quantized atom from recv_buffer + uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); + uint32_t* q4w_ptr = reinterpret_cast(atom_ptr) + thread; + uint16_t* q2w_ptr = + reinterpret_cast(atom_ptr + kRankTileQ2Offset) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + + uint32_t q4w = __builtin_nontemporal_load(q4w_ptr); + uint16_t q2w = __builtin_nontemporal_load(q2w_ptr); + int qs = __builtin_nontemporal_load(qs_ptr); + + *recv_buffer += kRankBufferTileStride; + + // Unpack q6 into fp16x8_t + int32x4_t w; + { + static uint constexpr kMask000F = 0x000F000F; + static uint constexpr kMask00FF = 0x00FF00FF; + static uint constexpr kHalf2_1024 = + 0x64006400; // {1024.0, 1024.0}, fp16x2_t + static uint constexpr kHalf2_1056 = + 0xE420E420; // {-1056.0, -1056.0}, fp16x2_t + +#pragma unroll + for (int i = 0; i < 4; i++) { + int32_t q4 = q4w & kMask000F; + int32_t q2 = (q2w & 0x3) | ((q2w & 0xC) << 14); + q4w >>= 4; + q2w >>= 4; + if constexpr (std::is_same::value) { + int32_t q6 = q4 | (q2 << 4) | kHalf2_1024; + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(w[i]) + : "v"(q6), "v"(kHalf2_1056)); + } else { + int32_t int16_2 = q4 | (q2 << 4); + int16_t low = static_cast(int16_2 & 0xFFFF); + int16_t high = static_cast((int16_2 >> 16) & 0xFFFF); + + nv_bfloat16 bf_low = __float2bfloat16(static_cast(low)); + nv_bfloat16 bf_high = __float2bfloat16(static_cast(high)); + nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high); + int32_t packed_bf16 = *reinterpret_cast(&bf2); + w[i] = packed_add(packed_bf16, kRangeMin); + } + } + } + + // Apply decoding scales + for (int i = 0; i < 4; i++) { + w[i] = packed_mul(w[i], qs); + } + + // That's pretty much it... + data[k] = w; + } + } +}; + +// Int8 symmetric quantization codec. +// We quantize the FP16 data to block-scaled Int8 in blocks of 4 * +// kThreadGroupSize. +template +struct CodecQ8 : public CodecBase { + static constexpr int kWorldSize = world_size; + + // Codec tile size process by this workgroup. + // Each threads processes a fragment of f16x8_t (16B), + // into a int8x8_t (8B) and a f16 scale shared among 32 values. + static constexpr int kRankAtoms = kAtoms / kWorldSize; + static constexpr int kRankTileStride = 2176; + static constexpr int kRankTileScaleOffset = 2048; + static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms; + static_assert(kRankTransmittedTileSize % 16 == 0, + "kRankTileSize must be 16B aligned."); + + static constexpr int kRankBufferTileStride = + kRankTileStride / sizeof(int32x4_t); + + // Total tile size for the collective communication. + static constexpr int kTransmittedTileSize = + kRankTransmittedTileSize * kWorldSize; + + // Constants configuration + + // {-1/128.0h, -1/128.0h}, f16x2_t + static constexpr int kScaleFactor = + std::is_same::value ? 0xA000A000 : 0xBC00BC00; + + // {1e-7, 1e-7}, f16x2_t + static constexpr int kScaleEpsilon = + std::is_same::value ? 0x00010001 : 0x33D733D7; + + // {-128, -128}, f16x2_t + static constexpr int kRangeMin = + std::is_same::value ? 0xD800D800 : 0xC300C300; + // {+127, +127}, f16x2_t + static constexpr int kRangeMax = + std::is_same::value ? 0x57F057F0 : 0x42FE42FE; + + // {+128, +128}, int16x2_t + static constexpr int kRangeBias = 0x00800080; + + __quickreduce_device_inline__ CodecQ8(int thread, int rank) + : CodecBase(thread, rank) {} + + __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, + int32x4_t const* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + int32x4_t const atom = data[k]; + // Compute the absolute maximum of the atom in the thread group + // In 2 blocks of values, upper/lower halves of the f16x2_t + int wblockmax = group_abs_max(atom); + + // Derive scales + int decoding_scale; + int encoding_scale; + decoding_scale = packed_mul(wblockmax, kScaleFactor); + encoding_scale = packed_add(decoding_scale, kScaleEpsilon); + encoding_scale = packed_rcp(encoding_scale); + + // Apply scales to get quantized values + int32x4_t w; + for (int i = 0; i < 4; i++) { + w[i] = packed_mul(atom[i], encoding_scale); + w[i] = packed_max(w[i], kRangeMin); + w[i] = packed_min(w[i], kRangeMax); + } + + // Convert from f16x2_t to uint16x2_t + int32x4_t q; + { + int16_t* qi = reinterpret_cast(&q); + T* wh = reinterpret_cast(&w); + for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(T2float_cast(wh[i])); + + for (int i = 0; i < 4; i++) { + q[i] = packed_add(q[i], kRangeBias); + } + } + + // Pack 8 x q8 into int32x2_t + int32x2_t qw; + qw[0] = q[0] | (q[1] << 8); + qw[1] = q[2] | (q[3] << 8); + + // Write quantized atom to send_buffer + // note: only the group leader stores the scale + uint8_t* atom_ptr = + reinterpret_cast(send_buffer + k * kRankBufferTileStride); + int32x2_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + + __builtin_nontemporal_store(qw, qw_ptr); + if (threadIdx.x == group_leader) { + __builtin_nontemporal_store(decoding_scale, qs_ptr); + } + } + } + + __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, + int32x4_t* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + // Directly read quantized atom from recv_buffer + uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); + int32x2_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + + int32x2_t qw = __builtin_nontemporal_load(qw_ptr); + int qs = __builtin_nontemporal_load(qs_ptr); + + *recv_buffer += kRankBufferTileStride; + + // Unpack q8 into fp16x8_t + int32x4_t w; + { + static uint constexpr kMask00FF = 0x00FF00FF; + + // {1024.0, 1024.0}, fp16x2_t + static uint constexpr kHalf2_1024 = 0x64006400; + + // {-1152.0, -1152.0}, fp16x2_t + static uint constexpr kHalf2_1152 = 0xE480E480; + +#pragma unroll + for (int i = 0; i < 4; i++) { + if constexpr (std::is_same::value) { + int32_t q8 = + ((qw[i / 2] >> ((i % 2) * 8)) & kMask00FF) | kHalf2_1024; + w[i] = packed_add(q8, kHalf2_1152); + } else { + int32_t int16_2 = (qw[i / 2] >> ((i % 2) * 8)) & kMask00FF; + int16_t low = static_cast(int16_2 & 0xFFFF); + int16_t high = static_cast((int16_2 >> 16) & 0xFFFF); + nv_bfloat16 bf_low = __float2bfloat16(static_cast(low)); + nv_bfloat16 bf_high = __float2bfloat16(static_cast(high)); + nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high); + int32_t packed_bf16 = *reinterpret_cast(&bf2); + w[i] = packed_add(packed_bf16, kRangeMin); + } + } + } + + // Apply decoding scales + for (int i = 0; i < 4; i++) { + w[i] = packed_mul(w[i], qs); + } + + data[k] = w; + } + } +}; + +// Twoshot All Reduce +template +struct AllReduceTwoshot { + static_assert(sizeof(T) == 2); + + static constexpr int kWorldSize = Codec::kWorldSize; + + __device__ static void run( + T const* __restrict__ input, T* __restrict__ output, + uint32_t const N, // number of elements + int const block, // block index + int const rank, // rank index + uint8_t** __restrict__ buffer_list, // communication buffers + uint32_t const data_offset, // offset to start of the data buffer + uint32_t flag_color) { + // Topology + int thread = threadIdx.x + threadIdx.y * kWavefront; + uint8_t* rank_buffer = buffer_list[rank]; + Codec codec(thread, rank); + int block_id = blockIdx.x; + int grid_size = gridDim.x; + // -------------------------------------------------------- + // Read input into registers + int32x4_t tA[kAtoms]; + + BufferResource src_buffer(const_cast(input), N * sizeof(T)); + uint32_t src_offset = block * kTileSize + thread * sizeof(int32x4_t); + int32x4_t* src = reinterpret_cast(const_cast(input)); + + for (int i = 0; i < kAtoms; i++) { + tA[i] = buffer_load_dwordx4(src_buffer.descriptor, src_offset, 0, 0); + src_offset += kAtomStride * sizeof(int32x4_t); + } + + // -------------------------------------------------------- + // Phase-1A: Write segment data into the communication buffer of the target + // rank responsible for this segment. + uint32_t comm_data0_offset = + data_offset + block_id * Codec::kTransmittedTileSize; + uint32_t comm_data1_offset = + grid_size * Codec::kTransmittedTileSize + comm_data0_offset; + + uint32_t comm_flags0_offset = block_id * (kWorldSize * sizeof(uint32_t)); + uint32_t comm_flags1_offset = + grid_size * (kWorldSize * sizeof(uint32_t)) + comm_flags0_offset; + + for (int r = 0; r < kWorldSize; r++) { + int32x4_t* send_buffer = + reinterpret_cast(buffer_list[r] + comm_data0_offset + + rank * Codec::kRankTransmittedTileSize); + codec.send(send_buffer, &tA[r * Codec::kRankAtoms]); + } + + __syncthreads(); + if (thread < kWorldSize) { + int r = thread; + uint32_t* flag_ptr = reinterpret_cast( + buffer_list[r] + comm_flags0_offset + rank * sizeof(uint32_t)); + set_sync_flag(flag_ptr, flag_color); + } + // -------------------------------------------------------- + // Phase-1B: Reduce the segment data from the communication buffers. + int32x4_t tR[Codec::kRankAtoms] = {}; + { + // Read the data from the communication buffer. + int32x4_t* recv_buffer = + reinterpret_cast(rank_buffer + comm_data0_offset); + uint32_t* flag_ptr = + reinterpret_cast(rank_buffer + comm_flags0_offset); + + for (int r = 0; r < kWorldSize; r++) { + // Wait for the flags to be set. + if (thread == 0) { + wait_sync_flag(&flag_ptr[r], flag_color); + } + __syncthreads(); + + // note: we reuse tA as temp buffer here + codec.recv(&recv_buffer, tA); + + for (int i = 0; i < Codec::kRankAtoms; i++) { + packed_assign_add(&tR[i], &tA[i]); + } + } + } + + // Phase-2: Write the reduced segment to every other rank + for (int r = 0; r < kWorldSize; r++) { + int32x4_t* send_buffer = + reinterpret_cast(buffer_list[r] + comm_data1_offset + + rank * Codec::kRankTransmittedTileSize); + codec.send(send_buffer, tR); + } + + __syncthreads(); + if (thread < kWorldSize) { + int r = thread; + uint32_t* flag_ptr = reinterpret_cast( + buffer_list[r] + comm_flags1_offset + rank * sizeof(uint32_t)); + set_sync_flag(flag_ptr, flag_color); + } + + // Phase-2: Read the gather segments from the rank's communication buffer. + { + // Read the data from the communication buffer. + int32x4_t* recv_buffer = + reinterpret_cast(rank_buffer + comm_data1_offset); + uint32_t* flag_ptr = + reinterpret_cast(rank_buffer + comm_flags1_offset); + + for (int r = 0; r < kWorldSize; r++) { + // Wait for the flags to be set. + if (thread == 0) { + wait_sync_flag(&flag_ptr[r], flag_color); + } + __syncthreads(); + + // Gather all reduced and final rank segments into tA. + codec.recv(&recv_buffer, &tA[r * Codec::kRankAtoms]); + } + } + + // -------------------------------------------------------- + // Write the result to output. + BufferResource dst_buffer(output, N * sizeof(T)); + uint32_t dst_offset = block * kTileSize + thread * sizeof(int32x4_t); + int32x4_t* dst = reinterpret_cast(output); + + for (int i = 0; i < kAtoms; i++) { + buffer_store_dwordx4(tA[i], dst_buffer.descriptor, dst_offset, 0, 0); + dst_offset += kAtomStride * sizeof(int32x4_t); + } + } +}; + +} // namespace quickreduce \ No newline at end of file diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 1a1896b4c1ee..d14d30536306 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -725,6 +725,23 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { custom_ar.impl("open_mem_handle", torch::kCPU, &open_mem_handle); custom_ar.def("free_shared_buffer", &free_shared_buffer); +#ifdef USE_ROCM + // Quick Reduce all-reduce kernels + custom_ar.def( + "qr_all_reduce(int fa, Tensor inp, Tensor out, int quant_level) -> ()"); + custom_ar.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce); + + custom_ar.def("init_custom_qr", &init_custom_qr); + custom_ar.def("qr_destroy", &qr_destroy); + + custom_ar.def("qr_get_handle", &qr_get_handle); + + custom_ar.def("qr_open_handles(int _fa, Tensor[](b!) handles) -> ()"); + custom_ar.impl("qr_open_handles", torch::kCPU, &qr_open_handles); + + // Max input size in bytes + custom_ar.def("qr_max_size", &qr_max_size); +#endif } REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index fae49c41d5f8..a96027005262 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os import random import pytest @@ -86,7 +87,7 @@ def graph_allreduce( @ray.remote(num_gpus=1, max_calls=1) -def eager_allreduce( +def eager_custom_allreduce( monkeypatch: pytest.MonkeyPatch, tp_size, pp_size, @@ -111,19 +112,51 @@ def eager_allreduce( inp = torch.ones(sz, dtype=torch.float32, device=device) out = inp for _ in range(num_communication): - out = fa.all_reduce(out, registered=False) + out = fa.ca_all_reduce(out, registered=False) torch.testing.assert_close(out, inp * (tp_size**num_communication)) inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device) out = inp for _ in range(num_communication): - out = fa.all_reduce(out, registered=False) + out = fa.ca_all_reduce(out, registered=False) torch.testing.assert_close(out, inp * (tp_size**num_communication)) +@ray.remote(num_gpus=1, max_calls=1) +def eager_quickreduce( + monkeypatch: pytest.MonkeyPatch, + tp_size, + pp_size, + rank, + distributed_init_port, +): + with monkeypatch.context() as m: + m.delenv("CUDA_VISIBLE_DEVICES", raising=False) + os.environ["VLLM_ROCM_QR_QUANT_REGIME"] = "FP" + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + init_test_distributed_environment(tp_size, pp_size, rank, + distributed_init_port) + + sz = 1024 * 1024 + fa = get_tp_group().device_communicator.ca_comm + inp = torch.ones(sz, dtype=torch.float16, device=device) + out = inp + out = fa.qr_all_reduce(out) + torch.testing.assert_close(out, inp * tp_size) + + sz = 1024 * 1024 + inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device) + out = inp + out = fa.qr_all_reduce(out) + torch.testing.assert_close(out, inp * tp_size) + + @pytest.mark.parametrize("tp_size", [2]) @pytest.mark.parametrize("pipeline_parallel_size", [1, 2]) -@pytest.mark.parametrize("test_target", [eager_allreduce, graph_allreduce]) +@pytest.mark.parametrize( + "test_target", + [eager_custom_allreduce, graph_allreduce, eager_quickreduce]) def test_custom_allreduce( monkeypatch: pytest.MonkeyPatch, tp_size, diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index ff992c33b309..dc557349dbad 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1758,6 +1758,32 @@ def free_shared_buffer(ptr: int) -> None: torch.ops._C_custom_ar.free_shared_buffer(ptr) +# quick all reduce +def init_custom_qr(rank: int, world_size: int) -> int: + return torch.ops._C_custom_ar.init_custom_qr(rank, world_size) + + +def qr_destroy(fa: int) -> None: + torch.ops._C_custom_ar.qr_destroy(fa) + + +def qr_all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor, + quant_level: int) -> None: + torch.ops._C_custom_ar.qr_all_reduce(fa, inp, out, quant_level) + + +def qr_get_handle(fa: int) -> torch.Tensor: + return torch.ops._C_custom_ar.qr_get_handle(fa) + + +def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None: + return torch.ops._C_custom_ar.qr_open_handles(fa, handles) + + +def qr_max_size() -> int: + return torch.ops._C_custom_ar.qr_max_size() + + def get_flash_mla_metadata( cache_seqlens: torch.Tensor, num_heads_per_head_k: int, diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 7dd104a4fcc4..1b8907bcd301 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from contextlib import contextmanager +from enum import Enum from typing import Optional, Union import torch @@ -10,6 +11,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops +from vllm.config import get_current_vllm_config from vllm.distributed.device_communicators.custom_all_reduce_utils import ( gpu_p2p_access_check) from vllm.distributed.parallel_state import in_the_same_node_as @@ -23,10 +25,24 @@ except Exception: # For CPUs custom_ar = False +try: + ops.qr_max_size() + quick_ar = True +except Exception: + # For CPUs + quick_ar = False logger = init_logger(__name__) +class QuickReduceRegime(Enum): + FP = 0 + INT8 = 1 + INT6 = 2 + INT4 = 3 + NONE = 4 + + def _can_p2p(rank: int, world_size: int) -> bool: for i in range(world_size): if i == rank: @@ -49,32 +65,65 @@ def is_weak_contiguous(inp: torch.Tensor): class CustomAllreduce: _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] + _QR_SUPPORTED_WORLD_SIZES = [2, 4, 8] + _QR_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16] + + # TODO: We should set a reasonable range for FP. + MB = 1024 * 1024 + _QR_MIN_SIZE = { + (torch.float16, 2): [1 * MB, 2 * MB, 2 * MB, 1 * MB], + (torch.float16, 4): [1 * MB, 16 * MB, 4 * MB, 2 * MB], + (torch.float16, 8): [16 * MB, 4 * MB, 4 * MB, 2 * MB], + (torch.bfloat16, 2): [2 * MB, 8 * MB, 8 * MB, 8 * MB], + (torch.bfloat16, 4): [8 * MB, 64 * MB, 64 * MB, 16 * MB], + (torch.bfloat16, 8): [16 * MB, 2048 * MB, 2048 * MB, 2048 * MB], + } # max_size: max supported allreduce size def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device], - max_size=8192 * 1024) -> None: + cr_max_size=8192 * 1024) -> None: """ + Custom allredcue (cr) is non-destructive acceleration, which is + available for cuda and rocm MI300 series. + Custom quick allreduce (qr) is accelerated by quantization, + currently supports fp16, Q8, Q6, Q4 quantization. + We view qr as complementary to cr, the condition for qr is + even more demanding; qr is initialized, then cr must also + be initialized. If the conditions of cr are not met, qr is + naturally not initialized. + Due to instruction set limitations, only rocm MI300 series + is supported for the time being. Args: group: the process group to work on. If None, it will use the default process group. device: the device to bind the CustomAllreduce to. If None, it will be bind to f"cuda:{local_rank}". + cr_max_size: max supported size of cr. It is the caller's responsibility to make sure each communicator is bind to a unique device, and all communicators in this group are in the same node. """ + self._QR_SHOULD_INIT = True self._IS_CAPTURING = False self.disabled = True + self.qr_disabled = True + self.cr_max_size = cr_max_size + self.qr_max_size = ops.qr_max_size() if not custom_ar: # disable because of missing custom allreduce library # e.g. in a non-GPU environment logger.info("Custom allreduce is disabled because " "of missing custom allreduce library") - return + if not quick_ar: + logger.info("Custom quick allreduce is disabled because " + "of missing quick allreduce library") + self._QR_SHOULD_INIT = False + if not quick_ar and not custom_ar: + return self.group = group assert dist.get_backend(group) != dist.Backend.NCCL, ( @@ -88,10 +137,12 @@ def __init__(self, return rank = dist.get_rank(group=self.group) - self.rank = rank world_size = dist.get_world_size(group=self.group) + self.rank = rank + self.world_size = world_size if world_size == 1: - # No need to initialize custom allreduce for single GPU case. + # No need to initialize custom allreduce or custom quick + # allreduce for single GPU case. return if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES: @@ -102,6 +153,13 @@ def __init__(self, world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES)) return + if self._QR_SHOULD_INIT and \ + world_size not in CustomAllreduce._QR_SUPPORTED_WORLD_SIZES: + self._QR_SHOULD_INIT = False + logger.warning( + "Custom quick allreduce is disabled due to an unsupported " + "world size: %d.", world_size) + if isinstance(device, int): device = torch.device(f"cuda:{device}") elif isinstance(device, str): @@ -131,35 +189,54 @@ def __init__(self, # where custom allreduce is not supported # this checks hardware and driver support for NVLink assert current_platform.is_cuda_alike() - fully_connected = current_platform.is_fully_connected( + self.fully_connected = current_platform.is_fully_connected( physical_device_ids) - if world_size > 2 and not fully_connected: + if world_size > 2 and not self.fully_connected: logger.warning( "Custom allreduce is disabled because it's not supported on" " more than two PCIe-only GPUs. To silence this warning, " "specify disable_custom_all_reduce=True explicitly.") return - # test P2P capability, this checks software/cudaruntime support - # this is expensive to compute at the first time - # then we cache the result - # On AMD GPU, p2p is always enabled between XGMI connected GPUs - if not current_platform.is_rocm() and not _can_p2p(rank, world_size): - logger.warning( - "Custom allreduce is disabled because your platform lacks " - "GPU P2P capability or P2P test failed. To silence this " - "warning, specify disable_custom_all_reduce=True explicitly.") - return - + if not current_platform.is_rocm(): + # First, we only enable quickreduce for MI300 series, + # If it's rocm then it must be MI300 series because cr is only + # available on the mi300 series, qr must be available. + self._QR_SHOULD_INIT = False + + # test P2P capability, this checks software/cudaruntime support + # this is expensive to compute at the first time + # then we cache the result + # On AMD GPU, p2p is always enabled between XGMI connected GPUs + if not _can_p2p(rank, world_size): + logger.warning( + "Custom allreduce is disabled because your platform lacks " + "GPU P2P capability or P2P test failed. To silence this " + "warning, specify disable_custom_all_reduce=True " + "explicitly.") + return + + self.init_custom_allreduce() + # self.disabled is used to indicate cr, if the condition + # of cr is not satisfied, qr must not be satisfied, + # This boolean serves as a uniform identifier for external. self.disabled = False + self.init_custom_quick_allreduce() + + def init_custom_allreduce(self): + """ + Initialize custom allreduce + """ # Buffers memory are owned by this Python class and passed to C++. # Meta data composes of two parts: meta data for synchronization and a # temporary buffer for storing intermediate allreduce results. - self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size, - group=group, + self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + + self.cr_max_size, + group=self.group, uncached=True) # This is a pre-registered IPC buffer. In eager mode, input tensors # are first copied into this buffer before allreduce is performed - self.buffer_ptrs = self.create_shared_buffer(max_size, group=group) + self.buffer_ptrs = self.create_shared_buffer(self.cr_max_size, + group=self.group) # This is a buffer for storing the tuples of pointers pointing to # IPC buffers from all ranks. Each registered tuple has size of # 8*world_size bytes where world_size is at most 8. Allocating 8MB @@ -168,13 +245,54 @@ def __init__(self, self.rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=self.device) - self.max_size = max_size - self.rank = rank - self.world_size = world_size - self.fully_connected = fully_connected - self._ptr = ops.init_custom_ar(self.meta_ptrs, self.rank_data, rank, - self.fully_connected) - ops.register_buffer(self._ptr, self.buffer_ptrs) + self.cr_max_size = self.cr_max_size + + self._cr_ptr = ops.init_custom_ar(self.meta_ptrs, self.rank_data, + self.rank, self.fully_connected) + ops.register_buffer(self._cr_ptr, self.buffer_ptrs) + + def init_custom_quick_allreduce(self): + """ + Initialize a custom quick allreduce implementation for AMD + based on quick reduce (https://github.com/mk1-project/quickreduce). + """ + if not self._QR_SHOULD_INIT: + return + self.use_fp16_kernels = envs.VLLM_ROCM_QR_CAST_BF16_TO_FP16 + regime_str = envs.VLLM_ROCM_QR_QUANT_REGIME + if regime_str not in QuickReduceRegime.__members__: + logger.warning( + "Custom quick allreduce:", + f"Invalid quantization level: {regime_str}. " + "Supported levels: " + f"{list(QuickReduceRegime.__members__.keys())}") + return + + if regime_str == "NONE": + logger.debug("Custom quick allreduce is disabled based " + "on env variable VLLM_ROCM_QR_QUANT_REGIME") + return + + vllm_config = get_current_vllm_config() + if vllm_config is not None and \ + hasattr(vllm_config, "model_config") and \ + hasattr(vllm_config.model_config, "dtype"): + dtype = vllm_config.model_config.dtype + if dtype not in [torch.float16, torch.bfloat16]: + self._QR_SHOULD_INIT = False + # On RocM bfloat16 kernels are slower than fp16 + # due to slower match operations + # If environment variable is not set to 1 we convert input to fp16 + if dtype == torch.bfloat16 and not self.use_fp16_kernels: + logger.info( + "Custom quick allreduce: converting bf16 inputs to " + "fp16 can improve performance" + "set envs.VLLM_ROCM_QR_CAST_BF16_TO_FP16=1 to turn on.") + + self.qr_quant_level = QuickReduceRegime[regime_str] + self._qr_ptr = ops.init_custom_qr(self.rank, self.world_size) + self.create_qr_shared_buffer() + self.qr_disabled = False @contextmanager def capture(self): @@ -192,7 +310,7 @@ def capture(self): self.register_graph_buffers() def register_graph_buffers(self): - handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr) + handle, offset = ops.get_graph_buffer_ipc_meta(self._cr_ptr) logger.info("Registering %d cuda graph addresses", len(offset)) # We cannot directly use `dist.all_gather_object` here # because it is incompatible with `gloo` backend under inference mode. @@ -209,9 +327,32 @@ def register_graph_buffers(self): # Unpack list of tuples to tuple of lists. handles = [d[0] for d in all_data] # type: ignore offsets = [d[1] for d in all_data] # type: ignore - ops.register_graph_buffers(self._ptr, handles, offsets) + ops.register_graph_buffers(self._cr_ptr, handles, offsets) - def should_custom_ar(self, inp: torch.Tensor): + def should_quick_allreduce(self, inp: torch.Tensor): + """ + Check if quickreduce is available + """ + if self.qr_disabled: + return False + if inp.dtype not in self._QR_SUPPORTED_DTYPES: + return False + inp_size = inp.numel() * inp.element_size() + # custom quick allreduce requires input byte size to be + # multiples of 16 + if inp_size % 16 != 0: + return False + if not is_weak_contiguous(inp): + return False + # custom quick allreduce requires input byte size to be multiples of 16 + dtype = inp.dtype + if self.use_fp16_kernels: + dtype = torch.float16 + return inp_size <= self.qr_max_size and \ + inp_size >= self._QR_MIN_SIZE[(dtype, self.world_size)]\ + [self.qr_quant_level.value] + + def should_custom_allreduce(self, inp: torch.Tensor): if self.disabled: return False inp_size = inp.numel() * inp.element_size() @@ -223,15 +364,20 @@ def should_custom_ar(self, inp: torch.Tensor): # for 4 or more non NVLink-capable GPUs, custom allreduce provides # little performance improvement over NCCL. if self.world_size == 2 or self.fully_connected: - return inp_size < self.max_size + return inp_size < self.cr_max_size return False - def all_reduce(self, - inp: torch.Tensor, - *, - out: torch.Tensor = None, - registered: bool = False): - """Performs an out-of-place all reduce. + def should_custom_ar(self, inp: torch.Tensor): + # Determine whether to use qr, or cr or quit + return self.should_quick_allreduce( + inp) or self.should_custom_allreduce(inp) + + def ca_all_reduce(self, + inp: torch.Tensor, + *, + out: torch.Tensor = None, + registered: bool = False): + """Performs an out-of-place custom all reduce. If registered is True, this assumes inp's pointer is already IPC-registered. Otherwise, inp is first copied into a pre-registered @@ -240,37 +386,61 @@ def all_reduce(self, if out is None: out = torch.empty_like(inp) if registered: - ops.all_reduce(self._ptr, inp, out, 0, 0) + ops.all_reduce(self._cr_ptr, inp, out, 0, 0) else: - ops.all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank], - self.max_size) + ops.all_reduce(self._cr_ptr, inp, out, self.buffer_ptrs[self.rank], + self.cr_max_size) return out + def qr_all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None): + """Performs an out-of-place custom quick all reduce.""" + inp_dtype = inp.dtype + if inp_dtype == torch.bfloat16 and self.use_fp16_kernels: + inp = inp.to(torch.float16) + if out is None: + out = torch.empty_like(inp) + ops.qr_all_reduce(self._qr_ptr, inp, out, self.qr_quant_level.value) + return out.to(inp_dtype) + def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: """The main allreduce API that provides support for cuda graph.""" - # When custom allreduce is disabled, this will be None. - if self.disabled or not self.should_custom_ar(input): - return None - if self._IS_CAPTURING: - if torch.cuda.is_current_stream_capturing(): - return self.all_reduce(input, registered=True) + # try custom quick allreduce first, then custom allreduce + if self.should_quick_allreduce(input): + # We don't need the context of quick allreduce to do graph capture + # because the ipc access is already collected in init() and + # we can capture the quick allreduce directly. + return self.qr_all_reduce(input) + + if self.should_custom_allreduce(input): + if self._IS_CAPTURING: + if torch.cuda.is_current_stream_capturing(): + return self.ca_all_reduce(input, registered=True) + else: + # If warm up, mimic the allocation pattern since custom + # allreduce is out-of-place. + return torch.empty_like(input) else: - # If warm up, mimic the allocation pattern since custom - # allreduce is out-of-place. - return torch.empty_like(input) - else: - # Note: outside of cuda graph context, custom allreduce incurs a - # cost of cudaMemcpy, which should be small (<=1% of overall - # latency) compared to the performance gain of using custom kernels - return self.all_reduce(input, registered=False) + # Note: outside of cuda graph context, custom allreduce + # incurs a cost of cudaMemcpy, which should be small + # (<=1% of overall latency) compared to the performance + # gain of using custom kernels + return self.ca_all_reduce(input, registered=False) + + return None def close(self): - if not self.disabled and self._ptr: + if not self.disabled and self._cr_ptr: if ops is not None: - ops.dispose(self._ptr) - self._ptr = 0 + ops.dispose(self._cr_ptr) + self._cr_ptr = 0 self.free_shared_buffer(self.meta_ptrs, rank=self.rank) self.free_shared_buffer(self.buffer_ptrs, rank=self.rank) + self.disabled = True + if not self.qr_disabled and self._qr_ptr: + if ops is not None: + ops.qr_destroy(self._qr_ptr) + self._qr_ptr = 0 + self.qr_disabled = True def __del__(self): self.close() @@ -294,6 +464,17 @@ def create_shared_buffer(size_in_bytes: int, pointers.append(ops.open_mem_handle(h)) return pointers + def create_qr_shared_buffer(self): + """ + Creates a shared buffer for quickreduce. + Has to be called after qr_init_device_collectives + """ + handle = ops.qr_get_handle(self._qr_ptr) + world_size = dist.get_world_size(group=self.group) + handles = [None] * world_size + dist.all_gather_object(handles, handle, group=self.group) + ops.qr_open_handles(self._qr_ptr, handles) + @staticmethod def free_shared_buffer(pointers: list[int], group: Optional[ProcessGroup] = None, diff --git a/vllm/envs.py b/vllm/envs.py index a4a1784f97f9..d9515e5def09 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -129,6 +129,8 @@ VLLM_SLEEP_WHEN_IDLE: bool = False VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16 VLLM_KV_CACHE_LAYOUT: Optional[str] = None + VLLM_ROCM_QR_QUANT_REGIME: str = "NONE" + VLLM_ROCM_QR_CAST_BF16_TO_FP16: bool = False def get_default_cache_root(): @@ -671,6 +673,20 @@ def get_vllm_port() -> Optional[int]: lambda: (os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in ("true", "1")), + # Custom quick allreduce kernel for MI3* cards + # Choice of quantization level: FP, INT8, INT6, INT4 or NONE + # Recommended for large models to get allreduce + "VLLM_ROCM_QR_QUANT_REGIME": + lambda: os.getenv("VLLM_ROCM_QR_QUANT_REGIME", "NONE").upper(), + + # Custom quick allreduce kernel for MI3* cards + # Due to the lack of the bfloat16 asm instruction, bfloat16 + # kernels are slower than fp16, + # If environment variable is not set to 1, we convert input to fp16 + "VLLM_ROCM_QR_CAST_BF16_TO_FP16": + lambda: (os.getenv("VLLM_ROCM_QR_CAST_BF16_TO_FP16", "False").lower() in + ("true", "1")), + # If set, when running in Quark emulation mode, do not dequantize the # weights at load time. Instead, dequantize weights on-the-fly during # kernel execution.