AllReduce Kernel for Small Messages (#322)
Adding allreduce kernel code for message sizes smaller than 32 bytes, i.e.,
when the number of elements is smaller than the number of ranks.
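For scale (a worked example, not part of the commit): with worldSize = 8 the new kernel is chosen whenever sizeof(T) * nelems < 8 * sizeof(int) = 32 bytes, e.g. seven floats. A minimal host-side sketch of that dispatch condition, with illustrative values only:

#include <cstdio>

int main() {
  const int worldSize = 8;                      // ranks on a single node (assumed)
  const int nelems = 7;                         // fewer elements than ranks
  const auto bytes = nelems * sizeof(float);    // 28 bytes
  // Mirrors the new condition in allreduce() below: under one int per rank.
  const bool smallPath = bytes < worldSize * sizeof(int);  // 28 < 32 -> true
  std::printf("small-message kernel: %s\n", smallPath ? "yes" : "no");
  return 0;
}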

---------

Co-authored-by: Caio Rocha <[email protected]>
Co-authored-by: Changho Hwang <[email protected]>
3 people authored Jul 5, 2024
1 parent b5a48f8 commit f4c3c8f
Showing 1 changed file with 72 additions and 16 deletions: apps/nccl/src/allreduce.hpp
@@ -129,10 +129,57 @@ __forceinline__ __device__ void vectorSum(T* dst, T* src, size_t nElem) {
vectorSum(dst, src, nElem, blockIdx.x, gridDim.x);
}

template <typename T>
__global__ void __launch_bounds__(32, 1)
allreduceAllToAll(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
size_t channelDataOffset, int rank, int nRanksPerNode, int worldSize, size_t nelems,
uint32_t flag) {
// This version of allreduce only works for single nodes
if (worldSize != nRanksPerNode) return;
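// For 16-bit element types, repack the count into 4-byte words: each LL8 packet
// carries one uint32_t payload, so two half-precision elements travel per packet
// (the count is rounded up to whole words).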
if (sizeof(T) == 2) nelems = (nelems * sizeof(T) + sizeof(T)) / sizeof(int);
const int nPeers = nRanksPerNode - 1;
const int nBlocksPerPeer = gridDim.x / nPeers;
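// Blocks are split evenly across peers; e.g., with 8 ranks (7 peers) and the
// 7-block launch used below, each peer is served by exactly one 32-thread block.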
const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
const int tid = threadIdx.x + localBlockIdx * blockDim.x;
const int peerIdx = blockIdx.x / nBlocksPerPeer;
const int remoteRank = peerIdx < rank ? peerIdx : peerIdx + 1;
// Double buffering
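// The low bit of flag flips on every call, so back-to-back invocations write to
// disjoint halves of the scratch buffer and cannot race with the previous round.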
size_t scratchBaseOffset = (flag & 1) ? 0 : 4 * worldSize * nelems * sizeof(mscclpp::LL8Packet);
size_t srcOffset = channelDataOffset;
size_t scratchOffset = scratchBaseOffset + rank * nelems * sizeof(mscclpp::LL8Packet);
void* scratchBuff = (void*)((char*)scratch + scratchBaseOffset);
uint32_t* src = (uint32_t*)((char*)buff);
uint32_t* dst = (uint32_t*)((char*)resultBuff);

__shared__ mscclpp::DeviceHandle<mscclpp::SmChannel> channels[NRANKS_PER_NODE - 1];
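// One warp lane per peer copies its channel handle into shared memory;
// __syncwarp() below makes the handles visible to all lanes before use.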
const int lid = tid % WARP_SIZE;
if (lid < nPeers) {
channels[lid] = smChannels[lid];
}
__syncwarp();

// step 1: write data to each peer's scratch buffer
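// Each LL8 packet bundles a 4-byte payload with the 4-byte flag, so the receiver
// can detect arrival by polling the flag instead of waiting on a separate barrier.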
channels[peerIdx].putPackets<mscclpp::LL8Packet>(scratchOffset, srcOffset, nelems * sizeof(uint32_t), tid,
blockDim.x * nBlocksPerPeer, flag);

// step 2: Reduce Data
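// Each thread polls the packets every peer deposited in our scratch (read() spins
// until the stored flag matches), accumulates them, then folds in the local
// contribution and writes the final value.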
for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nelems; idx += blockDim.x * gridDim.x) {
uint32_t data = 0;
for (int index = 0; index < nPeers; index++) {
const int remoteRank = index < rank ? index : index + 1;
mscclpp::LL8Packet* dstPkt = (mscclpp::LL8Packet*)scratchBuff + remoteRank * nelems;
uint32_t val = dstPkt[idx].read(flag, -1);
data = add_vectors<T>(val, data);
}
data = add_vectors<T>(data, src[idx]);
dst[idx] = data;
}
}

template <typename T>
__global__ void __launch_bounds__(1024, 1)
allreduce7(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels, size_t channelDataOffset, int rank,
int nRanksPerNode, int worldSize, size_t nelems, uint32_t flag) {
allreduce7(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
size_t channelDataOffset, int rank, int nRanksPerNode, int worldSize, size_t nelems, uint32_t flag) {
// This version of allreduce only works for single nodes
if (worldSize != nRanksPerNode) return;
nelems = nelems / (sizeof(int) / sizeof(T));
@@ -166,7 +213,7 @@ __global__ void __launch_bounds__(1024, 1)

// step 1: write to scratch buffer
channels[peerIdx].putPackets<mscclpp::LL8Packet>(scratchOffset, srcOffset, nelemsPerRank * sizeof(int), tid,
blockDim.x * nBlocksPerPeer, flag);
// step 2: get data from scratch buffer, reduce data and write result to remote scratch buffer
for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * gridDim.x) {
uint32_t data = 0;
@@ -200,8 +247,8 @@ __global__ void __launch_bounds__(1024, 1)
template <typename T>
__global__ void __launch_bounds__(512, 1)
allreduce8(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels,size_t channelOutDataOffset, int rank, int nRanksPerNode, int worldSize,
size_t nelems) {
mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels, size_t channelOutDataOffset, int rank,
int nRanksPerNode, int worldSize, size_t nelems) {
const int nPeer = nRanksPerNode - 1;
const size_t chanOffset = nPeer * blockIdx.x;
// assume (nelems * sizeof(T)) is divisible by (16 * worldSize)
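// e.g., for 8 ranks and T = float this requires nelems to be a multiple of 32
// (16 * 8 = 128 bytes = 32 floats).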
@@ -216,7 +263,8 @@ __global__ void __launch_bounds__(512, 1)

// Distribute `nInt4PerRank` across all blocks with the unit size `unitNInt4`
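// Worked example with illustrative numbers: nInt4PerRank = 10000 and
// gridDim.x = 35 give ceil(10000 / 35) = 286, rounded up to the 512 unit ->
// maxNInt4PerBlock = 512, so nNeededBlocks = ceil(10000 / 512) = 20.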
constexpr size_t unitNInt4 = 512;
const size_t maxNInt4PerBlock = (((nInt4PerRank + gridDim.x - 1) / gridDim.x) + unitNInt4 - 1) / unitNInt4 * unitNInt4;
const size_t maxNInt4PerBlock =
(((nInt4PerRank + gridDim.x - 1) / gridDim.x) + unitNInt4 - 1) / unitNInt4 * unitNInt4;
size_t offsetOfThisBlock = maxNInt4PerBlock * blockIdx.x;
size_t nInt4OfThisBlock = maxNInt4PerBlock;
size_t nNeededBlocks = (nInt4PerRank + maxNInt4PerBlock - 1) / maxNInt4PerBlock;
@@ -265,7 +313,7 @@ __global__ void __launch_bounds__(512, 1)
}
__syncthreads();

for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) {
int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock];
for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
@@ -274,7 +322,8 @@ __global__ void __launch_bounds__(512, 1)
}
resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data;
for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4), data);
outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4),
data);
}
}
offsetOfThisBlock += nInt4PerChunk;
@@ -309,32 +358,39 @@ __global__ void __launch_bounds__(512, 1)
}
resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data;
for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4), data);
outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4),
data);
}
}
}
}

template <typename T>
cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels, size_t channelInOffset, size_t channelOutOffset, int rank, int nRanksPerNode,
int worldSize, size_t nelems, cudaStream_t stream) {
mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels, size_t channelInOffset,
size_t channelOutOffset, int rank, int nRanksPerNode, int worldSize, size_t nelems,
cudaStream_t stream) {
static uint32_t flag = 1;

if (sizeof(T) * nelems <= (1 << 20)) {
if (sizeof(T) * nelems < worldSize * sizeof(int)) {
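// Smallest messages: the whole payload is under one int per rank, so a 7-block,
// single-warp launch is enough; e.g., on 8 ranks this is one block per peer.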
int nBlocks = 7;
int nThreadsPerBlock = 32;
allreduceAllToAll<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, channelInOffset,
rank, nRanksPerNode, worldSize, nelems, flag++);
} else if (sizeof(T) * nelems <= (1 << 20)) {
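// Mid-size path: anything up to 1 MiB keeps the LL-packet kernel (allreduce7);
// larger messages fall through to the chunked allreduce8 below.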
int nBlocks = 28;
int nThreadsPerBlock = 1024;
if (nelems >= 8192) {
nBlocks = 56;
nThreadsPerBlock = (nelems <= 76800) ? 512 : 1024;
}
allreduce7<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, channelInOffset, rank, nRanksPerNode,
worldSize, nelems, flag++);
allreduce7<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, channelInOffset, rank,
nRanksPerNode, worldSize, nelems, flag++);
} else {
int nBlocks = 35;
int nThreadsPerBlock = 512;
allreduce8<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, smOutChannels, channelOutOffset, rank, nRanksPerNode,
worldSize, nelems);
allreduce8<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, smOutChannels,
channelOutOffset, rank, nRanksPerNode, worldSize, nelems);
}

return cudaGetLastError();
