From f4c3c8f9161d4f85bdcdf4a080d66e672f5701cd Mon Sep 17 00:00:00 2001
From: caiomcbr <164253795+caiomcbr@users.noreply.github.com>
Date: Fri, 5 Jul 2024 14:08:43 -0700
Subject: [PATCH] AllReduce Kernel for Small Messages (#322)

Adding an allreduce kernel for message sizes smaller than 32 bytes, i.e.
when the number of elements is smaller than the number of ranks.

---------

Co-authored-by: Caio Rocha
Co-authored-by: Changho Hwang
---
 apps/nccl/src/allreduce.hpp | 88 ++++++++++++++++++++++++++++++-------
 1 file changed, 72 insertions(+), 16 deletions(-)

diff --git a/apps/nccl/src/allreduce.hpp b/apps/nccl/src/allreduce.hpp
index fd0d932c7..5f27fca78 100644
--- a/apps/nccl/src/allreduce.hpp
+++ b/apps/nccl/src/allreduce.hpp
@@ -129,10 +129,57 @@ __forceinline__ __device__ void vectorSum(T* dst, T* src, size_t nElem) {
   vectorSum(dst, src, nElem, blockIdx.x, gridDim.x);
 }
 
+template <typename T>
+__global__ void __launch_bounds__(32, 1)
+    allreduceAllToAll(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
+                      size_t channelDataOffset, int rank, int nRanksPerNode, int worldSize, size_t nelems,
+                      uint32_t flag) {
+  // This version of allreduce only works for single nodes
+  if (worldSize != nRanksPerNode) return;
+  if (sizeof(T) == 2) nelems = (nelems * sizeof(T) + sizeof(T)) / sizeof(int);
+  const int nPeers = nRanksPerNode - 1;
+  const int nBlocksPerPeer = gridDim.x / nPeers;
+  const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
+  const int tid = threadIdx.x + localBlockIdx * blockDim.x;
+  const int peerIdx = blockIdx.x / nBlocksPerPeer;
+  const int remoteRank = peerIdx < rank ? peerIdx : peerIdx + 1;
+  // Double buffering
+  size_t scratchBaseOffset = (flag & 1) ? 0 : 4 * worldSize * nelems * sizeof(mscclpp::LL8Packet);
+  size_t srcOffset = channelDataOffset;
+  size_t scratchOffset = scratchBaseOffset + rank * nelems * sizeof(mscclpp::LL8Packet);
+  void* scratchBuff = (void*)((char*)scratch + scratchBaseOffset);
+  uint32_t* src = (uint32_t*)((char*)buff);
+  uint32_t* dst = (uint32_t*)((char*)resultBuff);
+
+  __shared__ mscclpp::DeviceHandle<mscclpp::SmChannel> channels[NRANKS_PER_NODE - 1];
+  const int lid = tid % WARP_SIZE;
+  if (lid < nPeers) {
+    channels[lid] = smChannels[lid];
+  }
+  __syncwarp();
+
+  // step 1: write data to each peer's scratch buffer
+  channels[peerIdx].putPackets<mscclpp::LL8Packet>(scratchOffset, srcOffset, nelems * sizeof(uint32_t), tid,
+                                                   blockDim.x * nBlocksPerPeer, flag);
+
+  // step 2: Reduce Data
+  for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nelems; idx += blockDim.x * gridDim.x) {
+    uint32_t data = 0;
+    for (int index = 0; index < nPeers; index++) {
+      const int remoteRank = index < rank ? index : index + 1;
+      mscclpp::LL8Packet* dstPkt = (mscclpp::LL8Packet*)scratchBuff + remoteRank * nelems;
+      uint32_t val = dstPkt[idx].read(flag, -1);
+      data = add_vectors<T>(val, data);
+    }
+    data = add_vectors<T>(data, src[idx]);
+    dst[idx] = data;
+  }
+}
+
 template <typename T>
 __global__ void __launch_bounds__(1024, 1)
-    allreduce7(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels, size_t channelDataOffset, int rank,
-               int nRanksPerNode, int worldSize, size_t nelems, uint32_t flag) {
+    allreduce7(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
+               size_t channelDataOffset, int rank, int nRanksPerNode, int worldSize, size_t nelems, uint32_t flag) {
   // This version of allreduce only works for single nodes
   if (worldSize != nRanksPerNode) return;
   nelems = nelems / (sizeof(int) / sizeof(T));
@@ -166,7 +213,7 @@ __global__ void __launch_bounds__(1024, 1)
 
   // step 1: write to scratch buffer
   channels[peerIdx].putPackets(scratchOffset, srcOffset, nelemsPerRank * sizeof(int), tid,
-                                blockDim.x * nBlocksPerPeer, flag);
+                               blockDim.x * nBlocksPerPeer, flag);
   // step 2: get data from scratch buffer, reduce data and write result to remote scratch buffer
   for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * gridDim.x) {
     uint32_t data = 0;
@@ -200,8 +247,8 @@ __global__ void __launch_bounds__(1024, 1)
 template <typename T>
 __global__ void __launch_bounds__(512, 1)
     allreduce8(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
-               mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels,size_t channelOutDataOffset, int rank, int nRanksPerNode, int worldSize,
-               size_t nelems) {
+               mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels, size_t channelOutDataOffset, int rank,
+               int nRanksPerNode, int worldSize, size_t nelems) {
   const int nPeer = nRanksPerNode - 1;
   const size_t chanOffset = nPeer * blockIdx.x;
   // assume (nelems * sizeof(T)) is divisible by (16 * worldSize)
@@ -216,7 +263,8 @@ __global__ void __launch_bounds__(512, 1)
 
   // Distribute `nInt4PerRank` across all blocks with the unit size `unitNInt4`
   constexpr size_t unitNInt4 = 512;
-  const size_t maxNInt4PerBlock = (((nInt4PerRank + gridDim.x - 1) / gridDim.x) + unitNInt4 - 1) / unitNInt4 * unitNInt4;
+  const size_t maxNInt4PerBlock =
+      (((nInt4PerRank + gridDim.x - 1) / gridDim.x) + unitNInt4 - 1) / unitNInt4 * unitNInt4;
   size_t offsetOfThisBlock = maxNInt4PerBlock * blockIdx.x;
   size_t nInt4OfThisBlock = maxNInt4PerBlock;
   size_t nNeededBlocks = (nInt4PerRank + maxNInt4PerBlock - 1) / maxNInt4PerBlock;
@@ -265,7 +313,7 @@ __global__ void __launch_bounds__(512, 1)
     }
     __syncthreads();
-    for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) {
+    for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) {
       int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock];
       for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
         const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
@@ -274,7 +322,8 @@ __global__ void __launch_bounds__(512, 1)
       }
       resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data;
       for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
-        outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4), data);
+        outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4),
+                                   data);
       }
     }
     offsetOfThisBlock += nInt4PerChunk;
@@ -309,7 +358,8 @@ __global__ void __launch_bounds__(512, 1)
     }
     resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data;
     for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
-      outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4), data);
+      outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4),
+                                 data);
     }
   }
 }
@@ -317,24 +367,30 @@ __global__ void __launch_bounds__(512, 1)
 
 template <typename T>
 cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
-                      mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels, size_t channelInOffset, size_t channelOutOffset, int rank, int nRanksPerNode,
-                      int worldSize, size_t nelems, cudaStream_t stream) {
+                      mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels, size_t channelInOffset,
+                      size_t channelOutOffset, int rank, int nRanksPerNode, int worldSize, size_t nelems,
+                      cudaStream_t stream) {
   static uint32_t flag = 1;
-  if (sizeof(T) * nelems <= (1 << 20)) {
+  if (sizeof(T) * nelems < worldSize * sizeof(int)) {
+    int nBlocks = 7;
+    int nThreadsPerBlock = 32;
+    allreduceAllToAll<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, channelInOffset,
+                                                                rank, nRanksPerNode, worldSize, nelems, flag++);
+  } else if (sizeof(T) * nelems <= (1 << 20)) {
     int nBlocks = 28;
     int nThreadsPerBlock = 1024;
     if (nelems >= 8192) {
      nBlocks = 56;
      nThreadsPerBlock = (nelems <= 76800) ? 512 : 1024;
     }
-    allreduce7<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, channelInOffset, rank, nRanksPerNode,
-                                                         worldSize, nelems, flag++);
+    allreduce7<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, channelInOffset, rank,
+                                                         nRanksPerNode, worldSize, nelems, flag++);
   } else {
     int nBlocks = 35;
     int nThreadsPerBlock = 512;
-    allreduce8<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, smOutChannels, channelOutOffset, rank, nRanksPerNode,
-                                                         worldSize, nelems);
+    allreduce8<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, smOutChannels,
+                                                         channelOutOffset, rank, nRanksPerNode, worldSize, nelems);
   }
 
   return cudaGetLastError();
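
Reviewer note: the new dispatch in allreduce() boils down to a three-way
size check. Below is a minimal host-side sketch of that selection logic,
for illustration only; the Kernel enum and the name pickKernel are
hypothetical and not part of the patch.

  #include <cstddef>

  // Illustration only: mirrors the thresholds in allreduce() above. The
  // LL8 all-to-all path is taken when the message holds fewer 4-byte words
  // than there are ranks (under 32 bytes on an 8-rank node), allreduce7
  // covers everything up to 1 MiB, and allreduce8 covers the rest.
  enum class Kernel { AllToAll, Allreduce7, Allreduce8 };

  template <typename T>
  Kernel pickKernel(size_t nelems, int worldSize) {
    const size_t msgBytes = sizeof(T) * nelems;
    if (msgBytes < worldSize * sizeof(int)) return Kernel::AllToAll;
    if (msgBytes <= (1 << 20)) return Kernel::Allreduce7;
    return Kernel::Allreduce8;
  }

For example, pickKernel<float>(7, 8) returns Kernel::AllToAll: 7 floats
occupy 28 bytes, which is below 8 * sizeof(int) = 32.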
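Two idioms shared by the kernels may also be worth spelling out: peer
index p addresses remote rank p when p < rank and p + 1 otherwise (the
local rank is skipped), and the low bit of the ever-increasing flag picks
which half of the scratch buffer a round writes, so a round never
overwrites packets that the previous round's readers may still be
polling. A standalone sketch, with hypothetical helper names:

  #include <cassert>
  #include <cstddef>
  #include <cstdint>

  // Hypothetical helper: peer index -> remote rank, skipping the local
  // rank (the same expression the kernels use inline).
  inline int remoteRankOf(int p, int rank) { return p < rank ? p : p + 1; }

  // Hypothetical helper for the double buffering: odd flag values use the
  // first half of the scratch buffer, even values the second (same
  // expression as scratchBaseOffset in allreduceAllToAll, with
  // sizeof(mscclpp::LL8Packet) passed as pktSize).
  inline size_t scratchBase(uint32_t flag, int worldSize, size_t nelems,
                            size_t pktSize) {
    return (flag & 1) ? 0 : 4 * worldSize * nelems * pktSize;
  }

  int main() {
    // For rank 2 on a 4-rank node, peers 0, 1, 2 are ranks 0, 1, 3.
    assert(remoteRankOf(0, 2) == 0);
    assert(remoteRankOf(1, 2) == 1);
    assert(remoteRankOf(2, 2) == 3);
    // Consecutive flag values alternate between the two scratch halves.
    assert(scratchBase(1, 8, 4, 8) != scratchBase(2, 8, 4, 8));
    return 0;
  }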