diff --git a/cpp/tensorrt_llm/kernels/fusedMoeCommKernels.cu b/cpp/tensorrt_llm/kernels/fusedMoeCommKernels.cu new file mode 100644 index 00000000000..b04f4280a93 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/fusedMoeCommKernels.cu @@ -0,0 +1,1372 @@ +/* + * Copyright (c) 2019-2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/fusedMoeCommKernels.h" + +#include + +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/logger.h" + +namespace tensorrt_llm +{ +namespace kernels +{ + +static __device__ __forceinline__ uint32_t __as_ptr_smem(void const* __ptr) +{ + // Consider adding debug asserts here. + return static_cast(__cvta_generic_to_shared(__ptr)); +} + +static __device__ __forceinline__ uint64_t __as_ptr_gmem(void const* __ptr) +{ + // Consider adding debug asserts here. + return static_cast(__cvta_generic_to_global(__ptr)); +} + +__device__ __forceinline__ void fence_release_sys() +{ + asm volatile("fence.release.sys;" : : : "memory"); +} + +__device__ __forceinline__ void mbarrier_init(uint64_t* addr, uint32_t const& count) +{ +#if defined(__CUDACC__) && __CUDA_ARCH__ >= 800 + asm("mbarrier.init.shared.b64 [%0], %1;" : : "r"(__as_ptr_smem(addr)), "r"(count) : "memory"); +#endif +} + +__device__ __forceinline__ void mbarrier_expect_tx(uint64_t* addr, const uint32_t txCount) +{ +#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900 + asm("mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" + : + : "r"(__as_ptr_smem(addr)), "r"(txCount) + : "memory"); +#endif +} + +__device__ __forceinline__ uint64_t mbarrier_arrive(uint64_t* addr) +{ +#if defined(__CUDACC__) && __CUDA_ARCH__ >= 800 + uint64_t state; + asm("mbarrier.arrive.shared.b64 %0, [%1];" : "=l"(state) : "r"(__as_ptr_smem(addr)) : "memory"); + return state; +#else + return 0; +#endif +} + +__device__ __forceinline__ uint64_t mbarrier_arrive_expect_tx(uint64_t* addr, const uint32_t txCount) +{ +#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900 + uint64_t state; + asm("mbarrier.arrive.expect_tx.release.cta.shared::cta.b64 %0, [%1], %2;" + : "=l"(state) + : "r"(__as_ptr_smem(addr)), "r"(txCount) + : "memory"); + return state; +#else + return 0; +#endif +} + +__device__ __forceinline__ bool mbarrier_try_wait_parity(uint64_t* addr, uint32_t const& phaseParity) +{ +#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900 + uint32_t waitComplete; + asm("{\n\t .reg .pred P_OUT; \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P_OUT, [%1], %2;\n\t" + "selp.b32 %0, 1, 0, P_OUT; \n" + "}" + : "=r"(waitComplete) + : "r"(__as_ptr_smem(addr)), "r"(phaseParity) + : "memory"); + return static_cast(waitComplete); +#else + return false; +#endif +} + +template +__device__ __forceinline__ void ldgsts(int* dstShm, int const* srcMem, bool predGuard) +{ +#if defined(__CUDACC__) && __CUDA_ARCH__ >= 800 + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int) 
predGuard), + "r"(__as_ptr_smem(dstShm)), "l"(__as_ptr_gmem(srcMem)), "n"(COPY_SIZE)); +#endif +} + +__device__ __forceinline__ void cp_async_commit_group() +{ +#if defined(__CUDACC__) && __CUDA_ARCH__ >= 800 + asm volatile("cp.async.commit_group;" : : :); +#endif +} + +template +__device__ __forceinline__ void cp_async_wait_group() +{ +#if defined(__CUDACC__) && __CUDA_ARCH__ >= 800 + asm volatile("cp.async.wait_group %0;" : : "n"(N) : "memory"); +#endif +} + +__device__ __forceinline__ void cp_async_bulk_g2s(void* dstMem, void const* srcMem, int copySize, uint64_t* smemBar) +{ +#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900 + asm("cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" + : + : "r"(__as_ptr_smem(dstMem)), "l"(__as_ptr_gmem(srcMem)), "r"(copySize), "r"(__as_ptr_smem(smemBar)) + : "memory"); +#endif +} + +__device__ __forceinline__ void cp_async_bulk_s2g(void* dstMem, void const* srcMem, int copySize) +{ +#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900 + asm("cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;" + : + : "l"(__as_ptr_gmem(dstMem)), "r"(__as_ptr_smem(srcMem)), "r"(copySize) + : "memory"); +#endif +} + +__device__ __forceinline__ void cp_async_bulk_commit_group() +{ +#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900 + asm volatile("cp.async.bulk.commit_group;" : : :); +#endif +} + +template +__device__ __forceinline__ void cp_async_bulk_wait_group() +{ +#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900 + asm volatile("cp.async.bulk.wait_group %0;" : : "n"(N) : "memory"); +#endif +} + +template +__device__ __forceinline__ void cp_async_bulk_wait_group_read() +{ +#if defined(__CUDACC__) && __CUDA_ARCH__ >= 900 + asm volatile("cp.async.bulk.wait_group.read %0;" : : "n"(N) : "memory"); +#endif +} + +__host__ void MoeCommFieldInfo::fillFieldInfo(uint8_t* dataPtr, size_t elementSize, int vectorSize, int stride) +{ + TLLM_CHECK(elementSize == 1 || elementSize == 2 || elementSize == 4 || elementSize == 8 || elementSize == 16); + + dataPtrBase = dataPtr; + + uint64_t dataPtrU64 = reinterpret_cast(dataPtr); + + while (elementSize < 16 && dataPtrU64 % (elementSize * 2) == 0 && vectorSize % 2 == 0 && stride % 2 == 0) + { + elementSize *= 2; + vectorSize /= 2; + stride /= 2; + } + + if (elementSize == 16) + { + alignedUnitBit = 4; + } + else if (elementSize == 8) + { + alignedUnitBit = 3; + } + else if (elementSize == 4) + { + alignedUnitBit = 2; + } + else if (elementSize == 2) + { + alignedUnitBit = 1; + } + else + { + alignedUnitBit = 0; + } + + alignedUnitCount = vectorSize; + alignedUnitStride = stride; +} + +class Ll128Proto +{ +public: + static constexpr uint32_t INITIALIZED_VALUE = 0xFFFFFFFFU; + + template + static __device__ __forceinline__ int checkDataReceivedInShm(uint8_t* sharedMemoryBase, uint64_t step, + int countIn128Bytes, int fifoEntry128ByteIndexBase, int loaded128ByteCount, int warpId, int laneId) + { + // return value should be how many package already been received. + // 0 means no data received, -1 means has received finish package(should be the very first 128 Byte). 
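+        // How the flag check below works: every 128-byte block reserves one uint64_t slot as its
+        // LL128-style flag, and the slot position rotates with the block's index inside the FIFO
+        // entry (indexInFifoEntry % UINT64_PER_128B_BLOCK). A block counts as arrived once that
+        // slot equals the current step value (or the finish marker when USE_FINISH is set). The
+        // ballot + __ffs logic then counts only the leading run of valid lanes, so arrival is
+        // acknowledged strictly in order even if later blocks happen to land first.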
+ uint64_t* aligned128BytesShm = reinterpret_cast(sharedMemoryBase); + int totalValidCount = 0; + for (int idxBase = loaded128ByteCount; idxBase < countIn128Bytes; idxBase += WARP_SIZE) + { + int idx = idxBase + laneId; + bool valid = false; + bool finish = false; + if (idx < countIn128Bytes) + { + int indexInFifoEntry = fifoEntry128ByteIndexBase + idx; + uint64_t value = aligned128BytesShm[idx * MoeCommFieldInfo::UINT64_PER_128B_BLOCK + + indexInFifoEntry % MoeCommFieldInfo::UINT64_PER_128B_BLOCK]; + if (USE_FINISH) + { + finish = (value == (step & (1ULL << 63ULL))); + valid = (value == step) || finish; + } + else + { + valid = (value == step); + } + } + __syncwarp(); + unsigned validMask = __ballot_sync(WARP_MASK, valid); + // here we check valid in order, if previous valid is not true, we ignore the current valid. + int validCount = (validMask == WARP_MASK) ? WARP_SIZE : (__ffs(~validMask) - 1); + if (USE_FINISH) + { + unsigned finishedMask = __ballot_sync(WARP_MASK, finish); + // finish should be the very first 128 Byte. + if (finishedMask & 0x1) + { + return -1; + } + } + totalValidCount += validCount; + + if (validCount != WARP_SIZE) + { + break; + } + } + return totalValidCount; + } + + static __device__ __forceinline__ void protoPack(uint8_t* sharedMemoryBase, uint64_t step, int countIn128Bytes, + int fifoEntry128ByteIndexBase, int warpId, int laneId) + { + uint64_t* aligned128BytesShm = reinterpret_cast(sharedMemoryBase); + int halfLaneId = laneId % 16; + int halfIndex = laneId / 16; + int tailOffsetIn128Bytes = countIn128Bytes + halfIndex; + // for LL128 15 * 128 Bytes will be packed to 16 * 128 Bytes, each 16 threads is used for one 15 * 128 bytes. + for (int idxIn128BytesBase = halfIndex * 15; idxIn128BytesBase < countIn128Bytes; idxIn128BytesBase += 30) + { + int tailFlagIndexFromFifoEntry = fifoEntry128ByteIndexBase + tailOffsetIn128Bytes; + int tailFlagInnerIndex = tailFlagIndexFromFifoEntry % MoeCommFieldInfo::UINT64_PER_128B_BLOCK; + int idxIn128Bytes = idxIn128BytesBase + halfLaneId; + int idxFromFifoEntry = fifoEntry128ByteIndexBase + idxIn128Bytes; + uint64_t tailValue = step; + uint64_t tailInnerIndex = (halfLaneId >= tailFlagInnerIndex) ? 
halfLaneId + 1 : halfLaneId; + if (halfLaneId == 15) + { + tailInnerIndex = tailFlagInnerIndex; + } + int targetTailIndex = tailOffsetIn128Bytes * MoeCommFieldInfo::UINT64_PER_128B_BLOCK + tailInnerIndex; + if (idxIn128Bytes < countIn128Bytes && halfLaneId < 15) + { + int flagIndex = idxIn128Bytes * MoeCommFieldInfo::UINT64_PER_128B_BLOCK + + idxFromFifoEntry % MoeCommFieldInfo::UINT64_PER_128B_BLOCK; + tailValue = aligned128BytesShm[flagIndex]; + aligned128BytesShm[flagIndex] = step; + } + aligned128BytesShm[targetTailIndex] = tailValue; + tailOffsetIn128Bytes += 2; + } + __syncwarp(); + } + + static __device__ __forceinline__ void protoUnpack(uint8_t* sharedMemoryBase, uint64_t step, int countIn128Bytes, + int fifoEntry128ByteIndexBase, int loaded128ByteCount, int warpId, int laneId) + { + uint64_t* aligned128BytesShm = reinterpret_cast(sharedMemoryBase); + int halfLaneId = laneId % 16; + int halfIndex = laneId / 16; + int tailOffsetIn128Bytes = countIn128Bytes + halfIndex; + for (int idxIn128BytesBase = halfIndex * 15; idxIn128BytesBase < countIn128Bytes; idxIn128BytesBase += 30) + { + int tailFlagIndexFromFifoEntry = fifoEntry128ByteIndexBase + tailOffsetIn128Bytes; + int tailFlagInnerIndex = tailFlagIndexFromFifoEntry % MoeCommFieldInfo::UINT64_PER_128B_BLOCK; + int idxIn128Bytes = idxIn128BytesBase + halfLaneId; + int idxFromFifoEntry = fifoEntry128ByteIndexBase + idxIn128Bytes; + uint64_t tailValue = 0; + int tailInnerIndex = (halfLaneId >= tailFlagInnerIndex) ? halfLaneId + 1 : halfLaneId; + int targetTailIndex = tailOffsetIn128Bytes * MoeCommFieldInfo::UINT64_PER_128B_BLOCK + tailInnerIndex; + if (halfLaneId < 15) + { + tailValue = aligned128BytesShm[targetTailIndex]; + } + if (idxIn128Bytes < countIn128Bytes && halfLaneId < 15) + { + int flagIndex = idxIn128Bytes * MoeCommFieldInfo::UINT64_PER_128B_BLOCK + + idxFromFifoEntry % MoeCommFieldInfo::UINT64_PER_128B_BLOCK; + aligned128BytesShm[flagIndex] = tailValue; + } + tailOffsetIn128Bytes += 2; + } + __syncwarp(); + } + + static __device__ __forceinline__ void rearm( + uint32_t* u32FifoPtr, uint64_t step, int countIn128Bytes, int fifoEntry128ByteIndexBase, int warpId, int laneId) + { + // LL128 don't need rearm + } + + static __device__ __host__ __forceinline__ int computeProtoTransfer128ByteAlignedSize( + int compact128ByteSizeBeforeProto) + { + // each 15 * 128 byte need one tail 128 byte + int tail128ByteSize = (compact128ByteSizeBeforeProto + 15 * 128 - 1) / (15 * 128) * 128; + return compact128ByteSizeBeforeProto + tail128ByteSize; + } +}; + +using FusedMoeProto = Ll128Proto; + +// using FusedMoeProto = LamportProto; + +namespace fused_moe_impl +{ + +// returns copy size for txCount +__device__ __forceinline__ int startFieldG2S(MoeCommFieldInfo const& fieldInfo, int dataIndex, + uint8_t* sharedMemoryBase, int warpId, int laneId, uint64_t* smemBar) +{ + // we can copy more data than needed, just align to 16 bytes. 
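+    // For example (illustrative numbers): a field whose raw data starts at global address ...0x0A
+    // with 40 payload bytes is fetched as the 64-byte span [...0x00, ...0x40): the start is rounded
+    // down and the size rounded up to 16-byte multiples. The extra leading/trailing bytes are
+    // harmless because the uncompacted shared-memory layout reserves an extra 16-byte head and tail
+    // block for every unaligned field (see getUncompactShmOffset / getFieldUncompactSize).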
+    int alignedShmLoadOffset = fieldInfo.getUncompactShmOffset();
+    uint8_t* sharedMemoryLoadPtr = sharedMemoryBase + alignedShmLoadOffset;
+    int copyByteCount = 0;
+    uint8_t* loadPtr = fieldInfo.get16BAlignedLoadCopyRange(dataIndex, &copyByteCount);
+    if (laneId == 0 && copyByteCount > 0)
+    {
+        cp_async_bulk_g2s(sharedMemoryLoadPtr, loadPtr, copyByteCount, smemBar);
+    }
+    return copyByteCount;
+}
+
+__device__ __forceinline__ void startFieldS2G(
+    MoeCommFieldInfo const& fieldInfo, int dataIndex, uint8_t* sharedMemoryBase, int warpId, int laneId)
+{
+    int alignedShmStoreOffset = fieldInfo.getUncompactShmOffset();
+    uint8_t* sharedMemoryStorePtr = sharedMemoryBase + alignedShmStoreOffset;
+    int copyByteCount = 0;
+    int headTailShmIdx;
+    int headTailGlobalIdx;
+    uint8_t* storePtr
+        = fieldInfo.get16BAlignedStoreCopyRange(dataIndex, &copyByteCount, laneId, &headTailShmIdx, &headTailGlobalIdx);
+    if (copyByteCount > 0 && laneId == 0)
+    {
+        cp_async_bulk_s2g(storePtr, sharedMemoryStorePtr + MoeCommFieldInfo::BYTES_PER_16B_BLOCK, copyByteCount);
+    }
+    if (headTailGlobalIdx >= 0)
+    {
+        // copy the unaligned head and tail bytes
+        fieldInfo.getRawPtr(dataIndex, nullptr)[headTailGlobalIdx] = sharedMemoryStorePtr[headTailShmIdx];
+    }
+    __syncwarp();
+}
+
+// SRC_AFTER_DST is true if src > dst; pack uses this.
+// SRC_AFTER_DST is false if src < dst; unpack uses this.
+template <typename T, bool SRC_AFTER_DST>
+__device__ __forceinline__ void memmoveSharedMemory(uint8_t* dst, uint8_t const* src, int copySize, int laneId)
+{
+    int count = (copySize + sizeof(T) - 1) / sizeof(T);
+    int warpLoopStart = SRC_AFTER_DST ? 0 : (count + WARP_SIZE - 1) / WARP_SIZE - 1;
+    int warpLoopEnd = SRC_AFTER_DST ? (count + WARP_SIZE - 1) / WARP_SIZE : -1;
+    int warpLoopUpdate = SRC_AFTER_DST ? 1 : -1;
+    for (int i = warpLoopStart; i != warpLoopEnd; i += warpLoopUpdate)
+    {
+        int idx = laneId + i * WARP_SIZE;
+        T data = T{};
+        if (idx < count)
+        {
+            data = reinterpret_cast<T const*>(src)[idx];
+        }
+        __syncwarp();
+        if (idx < count)
+        {
+            reinterpret_cast<T*>(dst)[idx] = data;
+        }
+        __syncwarp();
+    }
+}
+
+template <bool IS_PACK>
+__device__ __forceinline__ void memmoveFieldOnSharedMemory(
+    MoeCommFieldInfo const& fieldInfo, int dataIndex, uint8_t* sharedMemoryBase, int laneId)
+{
+    int movOffset = fieldInfo.getMemmoveOffsets(dataIndex);
+    if (movOffset == 0)
+    {
+        // if movOffset is 0, src and dst are the same, so no memmove is needed.
+        return;
+    }
+    int alignedBytes = 1 << fieldInfo.alignedUnitBit;
+    int copySize = fieldInfo.alignedUnitCount * alignedBytes;
+    uint8_t* sharedMemoryCompact = sharedMemoryBase + fieldInfo.getCompactShmOffset();
+    uint8_t* sharedMemoryUncompact = sharedMemoryCompact + movOffset;
+    uint8_t* sharedMemoryDst = IS_PACK ? sharedMemoryCompact : sharedMemoryUncompact;
+    uint8_t* sharedMemorySrc = IS_PACK ?
sharedMemoryUncompact : sharedMemoryCompact; + + if (movOffset % 16 == 0) + { + memmoveSharedMemory(sharedMemoryDst, sharedMemorySrc, copySize, laneId); + } + else if (movOffset % 8 == 0) + { + memmoveSharedMemory(sharedMemoryDst, sharedMemorySrc, copySize, laneId); + } + else if (movOffset % 4 == 0) + { + memmoveSharedMemory(sharedMemoryDst, sharedMemorySrc, copySize, laneId); + } + else if (movOffset % 2 == 0) + { + memmoveSharedMemory(sharedMemoryDst, sharedMemorySrc, copySize, laneId); + } + else + { + memmoveSharedMemory(sharedMemoryDst, sharedMemorySrc, copySize, laneId); + } +} + +template +__device__ __forceinline__ void packAllFields( + FusedMoeFieldInfo const& sendFieldInfo, int dataIndex, uint8_t* sharedMemoryBase, int laneId) +{ +#pragma unroll + for (int i = 0; i < FIELD_COUNT; i++) + { + memmoveFieldOnSharedMemory(sendFieldInfo.fieldsInfo[i], dataIndex, sharedMemoryBase, laneId); + } + __syncwarp(); +} + +template +__device__ __forceinline__ void unpackAllFields( + FusedMoeFieldInfo const& recvFieldInfo, int dataIndex, uint8_t* sharedMemoryBase, int laneId) +{ +#pragma unroll + for (int i = FIELD_COUNT - 1; i >= 0; i--) + { + memmoveFieldOnSharedMemory(recvFieldInfo.fieldsInfo[i], dataIndex, sharedMemoryBase, laneId); + } + __syncwarp(); +} + +__device__ __forceinline__ void initSmemBar(uint64_t* smemBar, int laneId) +{ + if (laneId == 0) + { + mbarrier_init(smemBar, WARP_SIZE); + } + __syncwarp(); +} + +__device__ __forceinline__ void smemBarWait(uint64_t* smemBar, uint32_t* phaseParity) +{ + while (!mbarrier_try_wait_parity(smemBar, *phaseParity)) + { + } + *phaseParity = 1 - *phaseParity; +} + +__device__ __forceinline__ void startWorkspaceS2G( + uint64_t* fifoEntry, uint8_t* sharedMemoryBase, int send128ByteCount, int fifo128ByteOffset, int warpId, int laneId) +{ + int copyByteCount = send128ByteCount * MoeCommFieldInfo::BYTES_PER_128B_BLOCK; + if (laneId == 0) + { + cp_async_bulk_s2g(fifoEntry + fifo128ByteOffset * MoeCommFieldInfo::BYTES_PER_128B_BLOCK / sizeof(int64_t), + sharedMemoryBase, copyByteCount); + } + __syncwarp(); + cp_async_bulk_commit_group(); +} + +__device__ __forceinline__ uint64_t startWorkspaceG2S(uint8_t* sharedMemoryBase, uint64_t* fifoEntry, + int allLoad128ByteCount, int fifo128ByteOffset, int loaded128ByteCount, uint64_t* smemBar, int warpId, int laneId) +{ + int copyByteCount = (allLoad128ByteCount - loaded128ByteCount) * MoeCommFieldInfo::BYTES_PER_128B_BLOCK; + if (laneId == 0) + { + cp_async_bulk_g2s(sharedMemoryBase + loaded128ByteCount * MoeCommFieldInfo::BYTES_PER_128B_BLOCK, + fifoEntry + + (fifo128ByteOffset + loaded128ByteCount) * MoeCommFieldInfo::BYTES_PER_128B_BLOCK / sizeof(int64_t), + copyByteCount, smemBar); + } + return mbarrier_arrive_expect_tx(smemBar, laneId == 0 ? 
copyByteCount : 0); +} + +__device__ __forceinline__ void g2sBasicFields(FusedMoeFieldInfo const& sendFieldInfo, + MoeExpertParallelInfo const& expertParallelInfo, int dataIndex, uint8_t* sharedMemoryBase, int laneId) +{ + int topK = expertParallelInfo.topK; + int* tokenSelectedSlotsPtr = sendFieldInfo.getTokenSelectedSlotsPtr(dataIndex, laneId, topK); + float* scalePtr = sendFieldInfo.getScalePtr(dataIndex, laneId, topK); + ldgsts<4>(reinterpret_cast(sharedMemoryBase) + laneId, tokenSelectedSlotsPtr, laneId < topK); + ldgsts<4>(reinterpret_cast(sharedMemoryBase) + laneId + topK, reinterpret_cast(scalePtr), + laneId < topK && sendFieldInfo.expertScales != nullptr); +} + +// May commit 1 group for basic fields(tokenSelectedSlots and scales) if HAS_BASIC_FIELDS is true +// For other fields, use smemBar. +template +__device__ __forceinline__ uint64_t g2sAllFields(FusedMoeFieldInfo const& sendFieldInfo, + MoeExpertParallelInfo const& expertParallelInfo, int dataIndex, uint8_t* sharedMemoryBase, int warpId, int laneId, + uint64_t* smemBar) +{ + if (HAS_BASIC_FIELDS) + { + g2sBasicFields(sendFieldInfo, expertParallelInfo, dataIndex, sharedMemoryBase, laneId); + cp_async_commit_group(); + } + int asyncLoadSize = 0; +#pragma unroll + for (int i = 0; i < FIELD_COUNT; i++) + { + asyncLoadSize + += startFieldG2S(sendFieldInfo.fieldsInfo[i], dataIndex, sharedMemoryBase, warpId, laneId, smemBar); + } + return mbarrier_arrive_expect_tx(smemBar, laneId == 0 ? asyncLoadSize : 0); +} + +template +__device__ __forceinline__ void waitG2SBasicFields() +{ + if (HAS_BASIC_FIELDS) + { + cp_async_wait_group<0>(); + __syncwarp(); + } +} + +__device__ __forceinline__ void waitG2SOtherFields(uint64_t* memBar, uint32_t* phaseParity) +{ + tensorrt_llm::kernels::fused_moe_impl::smemBarWait(memBar, phaseParity); +} + +template +__device__ __forceinline__ void waitG2SAllFields(uint64_t* memBar, uint32_t* phaseParity) +{ + waitG2SBasicFields(); + waitG2SOtherFields(memBar, phaseParity); +} + +__device__ __forceinline__ void waitS2GBulkRead() +{ + cp_async_bulk_wait_group_read<0>(); + __syncwarp(); +} + +__device__ __forceinline__ void s2gBasicFields(FusedMoeFieldInfo const& recvFieldInfo, + MoeExpertParallelInfo const& expertParallelInfo, int dataIndex, uint8_t* sharedMemoryBase, int warpId, int laneId) +{ + int topK = expertParallelInfo.topK; + int* tokenSelectedSlotsPtr = recvFieldInfo.getTokenSelectedSlotsPtr(dataIndex, laneId, topK); + float* scalePtr = recvFieldInfo.getScalePtr(dataIndex, laneId, topK); + if (laneId < topK) + { + int selectedSlot = reinterpret_cast(sharedMemoryBase)[laneId]; + *tokenSelectedSlotsPtr = selectedSlot; + if (recvFieldInfo.expertScales != nullptr) + { + float scale = reinterpret_cast(sharedMemoryBase)[laneId + topK]; + *scalePtr = scale; + } + } +} + +// Will commit 1 group, for all non-basic fields +template +__device__ __forceinline__ void s2gAllFields(FusedMoeFieldInfo const& recvFieldInfo, + MoeExpertParallelInfo const& expertParallelInfo, int dataIndex, uint8_t* sharedMemoryBase, int warpId, int laneId) +{ + if (HAS_BASIC_FIELDS) + { + s2gBasicFields(recvFieldInfo, expertParallelInfo, dataIndex, sharedMemoryBase, warpId, laneId); + __syncwarp(); + } +#pragma unroll + for (int i = 0; i < FIELD_COUNT; i++) + { + startFieldS2G(recvFieldInfo.fieldsInfo[i], dataIndex, sharedMemoryBase, warpId, laneId); + } + cp_async_bulk_commit_group(); +} + +template +class SingleChannelCommunicator +{ +public: + __device__ __forceinline__ SingleChannelCommunicator(FusedMoeFieldInfo const& fieldInfo, + 
MoeExpertParallelInfo const& expertParallelInfo, MoeSingleCommMeta const& commMeta, + FusedMoeWorkspace const& workspace, FusedMoeWorldInfo const& worldInfo, FusedMoePairInfo const& pairInfo, + uint64_t* smemBar, uint8_t* shmemBase) + : mFieldInfo(fieldInfo) + , mExpertParallelInfo(expertParallelInfo) + , mCommMeta(commMeta) + , mWorkspace(workspace) + , mWorldInfo(worldInfo) + , mPairInfo(pairInfo) + , mSmemBar(smemBar) + , mShmemBase(shmemBase) + { + mWarpId = threadIdx.x / WARP_SIZE; + mLaneId = threadIdx.x % WARP_SIZE; + + mFifoBasePtr = mWorkspace.getFifoBasePtr(mWorldInfo, mPairInfo); + mSenderSideFifoInfo = mWorkspace.getSenderSideFifoInfo(mWorldInfo, mPairInfo); + mReceiverSideFifoInfo = mWorkspace.getReceiverSideFifoInfo(mWorldInfo, mPairInfo); + + mSingleTransfer128ByteCount = mCommMeta.getTransfer128ByteCount(); + mSingleCompactData128ByteCount = mCommMeta.getCompactData128ByteCount(); + // initialize as need new Entry first + mFifoEntry128ByteIndexBase = kFifoEntry128ByteCount; + mFifoEntryIndex = -1; + + tensorrt_llm::kernels::fused_moe_impl::initSmemBar(mSmemBar, mLaneId); + } + + __device__ __forceinline__ uint64_t* getFifoEntryPtr() const + { + return mFifoBasePtr + mFifoEntryIndex * kFifoEntrySizeInU64; + } + + __device__ __forceinline__ bool needNewEntry() const + { + return mFifoEntry128ByteIndexBase + mSingleTransfer128ByteCount > kFifoEntry128ByteCount; + } + + __device__ __forceinline__ void nextToken() + { + mFifoEntry128ByteIndexBase += mSingleTransfer128ByteCount; + } + + __device__ __forceinline__ void senderInitFifo() + { + mHead = mSenderSideFifoInfo->head; + mTail = mSenderSideFifoInfo->tail; + } + + __device__ __forceinline__ void receiverInitFifo() + { + mHead = mReceiverSideFifoInfo->head; + mTail = mReceiverSideFifoInfo->tail; + } + + /* + * Head | 0 | 1 | 2 | 3 | 4 | 4 | 4 | 4 | 4 | 5 | + * Tail | 0 | 0 | 0 | 0 | 0 | 1 | 2 | 3 | 4 | 4 | + * Writable | Y | Y | Y | Y | N | Y | Y | Y | Y | Y | + * Readable | N | Y | Y | Y | Y | Y | Y | Y | N | Y | + */ + + __device__ __forceinline__ void waitEntryWritable() + { + while (mTail + kFifoDepth <= mHead) + { + mTail = mSenderSideFifoInfo->tail; + } + } + + __device__ __forceinline__ void updateWriteEntry() + { + __syncwarp(); + mSenderSideFifoInfo->head = mHead; + } + + __device__ __forceinline__ void waitEntryReadable() + { + // always readable as long as flag matches. + } + + __device__ __forceinline__ void updateReadEntry() + { + mReceiverSideFifoInfo->tail = mTail; + mSenderSideFifoInfo->tail = mTail; + } + + __device__ __forceinline__ void newSendEntry() + { + mFifoEntryIndex = mHead % kFifoDepth; + mFifoEntry128ByteIndexBase = 0; + waitEntryWritable(); + __syncwarp(); + } + + __device__ __forceinline__ void newReceiveEntry() + { + mFifoEntryIndex = mTail % kFifoDepth; + mFifoEntry128ByteIndexBase = 0; + waitEntryReadable(); + __syncwarp(); + } + + __device__ __forceinline__ void doSend(int tokenCount, int* sendIndexMapping) + { + senderInitFifo(); + + int sendIndex = mPairInfo.channel; + uint32_t phaseParity = 0; + for (; sendIndex < tokenCount; sendIndex += mPairInfo.runChannelCount) + { + int tokenIndex = sendIndexMapping == nullptr ? sendIndex : sendIndexMapping[sendIndex]; + tensorrt_llm::kernels::fused_moe_impl::g2sAllFields( + mFieldInfo, mExpertParallelInfo, tokenIndex, mShmemBase, mWarpId, mLaneId, mSmemBar); + if (needNewEntry()) + { + if (mFifoEntryIndex >= 0) + { + // not first entry, update FIFO info from last entry. 
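+                    // mHead counts FIFO entries this sender has filled; bumping it and calling
+                    // updateWriteEntry() records that progress in SenderSideFifoInfo::head (the
+                    // receiver detects new data through the per-128-byte flags, while the sender is
+                    // throttled by the tail the receiver writes back). newSendEntry() below waits
+                    // while (mTail + kFifoDepth <= mHead), i.e. while the ring of kFifoDepth entries
+                    // is full: with kFifoDepth == 4 and mHead == 6, slot 6 % 4 == 2 can only be
+                    // reused after the receiver has advanced the tail to at least 3.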
+ mHead++; + updateWriteEntry(); + } + newSendEntry(); + } + tensorrt_llm::kernels::fused_moe_impl::waitG2SAllFields(mSmemBar, &phaseParity); + tensorrt_llm::kernels::fused_moe_impl::packAllFields( + mFieldInfo, tokenIndex, mShmemBase, mLaneId); + + FusedMoeProto::protoPack( + mShmemBase, mHead, mSingleCompactData128ByteCount, mFifoEntry128ByteIndexBase, mWarpId, mLaneId); + + tensorrt_llm::kernels::fused_moe_impl::startWorkspaceS2G(getFifoEntryPtr(), mShmemBase, + mSingleTransfer128ByteCount, mFifoEntry128ByteIndexBase, mWarpId, mLaneId); + + tensorrt_llm::kernels::fused_moe_impl::waitS2GBulkRead(); + + nextToken(); + } + if (mFifoEntry128ByteIndexBase > 0) + { + mHead++; + updateWriteEntry(); + } + } + + __device__ __forceinline__ void rearmFifoBuffer() + { + constexpr int kUint32CountPer128Byte = 128 / sizeof(uint32_t); + uint32_t* fifoPtr = reinterpret_cast(getFifoEntryPtr()); + fifoPtr += mFifoEntry128ByteIndexBase * kUint32CountPer128Byte; + + FusedMoeProto::rearm(fifoPtr, mTail, mSingleTransfer128ByteCount, mFifoEntry128ByteIndexBase, mWarpId, mLaneId); + __syncwarp(); + } + + __device__ __forceinline__ void doReceive(int tokenCount, int* recvIndexMapping) + { + receiverInitFifo(); + int recvIndex = mPairInfo.channel; + uint32_t phaseParity = 0; + bool needRelease = false; + for (; recvIndex < tokenCount; recvIndex += mPairInfo.runChannelCount) + { + int tokenIndex = recvIndexMapping == nullptr ? recvIndex : recvIndexMapping[recvIndex]; + int loaded128ByteCount = 0; + if (needNewEntry()) + { + if (mFifoEntryIndex >= 0) + { + // not first entry, update FIFO info from last entry. + mTail++; + needRelease = true; + } + newReceiveEntry(); + } + while (loaded128ByteCount < mSingleTransfer128ByteCount) + { + tensorrt_llm::kernels::fused_moe_impl::startWorkspaceG2S(mShmemBase, getFifoEntryPtr(), + mSingleTransfer128ByteCount, mFifoEntry128ByteIndexBase, loaded128ByteCount, mSmemBar, mWarpId, + mLaneId); + if (needRelease) + { + updateReadEntry(); + needRelease = false; + } + tensorrt_llm::kernels::fused_moe_impl::smemBarWait(mSmemBar, &phaseParity); + loaded128ByteCount += FusedMoeProto::template checkDataReceivedInShm(mShmemBase, mTail, + mSingleTransfer128ByteCount, mFifoEntry128ByteIndexBase, loaded128ByteCount, mWarpId, mLaneId); + } + + FusedMoeProto::protoUnpack(mShmemBase, mTail, mSingleCompactData128ByteCount, mFifoEntry128ByteIndexBase, + loaded128ByteCount, mWarpId, mLaneId); + tensorrt_llm::kernels::fused_moe_impl::unpackAllFields( + mFieldInfo, tokenIndex, mShmemBase, mLaneId); + tensorrt_llm::kernels::fused_moe_impl::s2gAllFields( + mFieldInfo, mExpertParallelInfo, tokenIndex, mShmemBase, mWarpId, mLaneId); + tensorrt_llm::kernels::fused_moe_impl::waitS2GBulkRead(); + + rearmFifoBuffer(); + nextToken(); + } + if (mFifoEntry128ByteIndexBase > 0) + { + mTail++; + updateReadEntry(); + } + } + +private: + static constexpr int kFifoEntrySizeInU64 = FusedMoeCommunicator::FIFO_ENTRY_BYTES / sizeof(uint64_t); + static constexpr int kFifoEntry128ByteCount = FusedMoeCommunicator::FIFO_ENTRY_128_BYTE_COUNT; + static constexpr int kFifoDepth = FusedMoeCommunicator::FIFO_DEPTH; + + FusedMoeFieldInfo mFieldInfo; + MoeExpertParallelInfo mExpertParallelInfo; + MoeSingleCommMeta mCommMeta; + FusedMoeWorkspace mWorkspace; + FusedMoeWorldInfo mWorldInfo; + FusedMoePairInfo mPairInfo; + uint64_t* mSmemBar; + uint8_t* mShmemBase; + + int mLaneId; + int mWarpId; + + uint64_t* mFifoBasePtr; + SenderSideFifoInfo* mSenderSideFifoInfo; + ReceiverSideFifoInfo* mReceiverSideFifoInfo; + + int64_t mHead; + 
int64_t mTail; + + int mSingleTransfer128ByteCount; + int mSingleCompactData128ByteCount; + int mFifoEntry128ByteIndexBase; + int mFifoEntryIndex; +}; + +template +__global__ void moeAllToAllKernel(FusedMoeCommKernelParam params, FusedMoeWorkspace workspace, bool hasBasicFields) +{ + __shared__ uint64_t allWarpSmemBar[32]; + extern __shared__ int4 allWarpShm[]; + + bool isSender = blockIdx.z == 0; + int runChannelCount = gridDim.y; + int group = threadIdx.y; + SendRecvIndices dataIndices = isSender ? params.sendIndices : params.recvIndices; + + FusedMoePairInfo pairInfo; + int peerRank = blockIdx.x * blockDim.y + group; + if (peerRank >= params.worldInfo.epInfo.epSize) + { + return; + } + int tokenCount; + int* groupStartPtr = dataIndices.getGroupStart(peerRank, tokenCount); + if (tokenCount == 0) + { + return; + } + + pairInfo.channel = blockIdx.y; + pairInfo.runChannelCount = runChannelCount; + pairInfo.senderRank = isSender ? params.worldInfo.epInfo.epRank : peerRank; + pairInfo.receiverRank = isSender ? peerRank : params.worldInfo.epInfo.epRank; + + if (isSender) + { + int singleShmSize = params.sendCommMeta.getSingleShmSize(); + if (hasBasicFields) + { + SingleChannelCommunicator comm(params.sendFieldInfo, params.expertParallelInfo, + params.sendCommMeta, workspace, params.worldInfo, pairInfo, allWarpSmemBar + group, + reinterpret_cast(allWarpShm) + singleShmSize * group); + comm.doSend(tokenCount, groupStartPtr); + } + else + { + SingleChannelCommunicator comm(params.sendFieldInfo, params.expertParallelInfo, + params.sendCommMeta, workspace, params.worldInfo, pairInfo, allWarpSmemBar + group, + reinterpret_cast(allWarpShm) + singleShmSize * group); + comm.doSend(tokenCount, groupStartPtr); + } + } + else + { + int singleShmSize = params.recvCommMeta.getSingleShmSize(); + if (hasBasicFields) + { + SingleChannelCommunicator comm(params.recvFieldInfo, params.expertParallelInfo, + params.recvCommMeta, workspace, params.worldInfo, pairInfo, allWarpSmemBar + group, + reinterpret_cast(allWarpShm) + singleShmSize * group); + comm.doReceive(tokenCount, groupStartPtr); + } + else + { + SingleChannelCommunicator comm(params.recvFieldInfo, params.expertParallelInfo, + params.recvCommMeta, workspace, params.worldInfo, pairInfo, allWarpSmemBar + group, + reinterpret_cast(allWarpShm) + singleShmSize * group); + comm.doReceive(tokenCount, groupStartPtr); + } + } +} + +int computeMoeAlltoallMaxDynamicSharedMemorySize() +{ + int devId = -1; + TLLM_CUDA_CHECK(cudaGetDevice(&devId)); + cudaFuncAttributes attr{}; + TLLM_CUDA_CHECK(cudaFuncGetAttributes(&attr, (void const*) moeAllToAllKernel<1>)); + int staticSmem = static_cast(attr.sharedSizeBytes); + int maxPerBlockShmOptin = 0; + TLLM_CUDA_CHECK(cudaDeviceGetAttribute(&maxPerBlockShmOptin, cudaDevAttrMaxSharedMemoryPerBlockOptin, devId)); + return maxPerBlockShmOptin - staticSmem; +} + +} // namespace fused_moe_impl + +void FusedMoeFieldInfo::fillMetaInfo( + MoeSingleCommMeta* singleCommMeta, int topK, bool hasScales, bool hasBasicFields) const +{ + singleCommMeta->singleCompactAlignedSize = computeSingleCompactSize(topK, hasScales, hasBasicFields); + singleCommMeta->singleUncompactAlignedSize = computeSingleUncompactSize(topK, hasScales, hasBasicFields); + singleCommMeta->singleTransferAlignedSize + = FusedMoeProto::computeProtoTransfer128ByteAlignedSize(singleCommMeta->singleCompactAlignedSize); +} + +void FusedMoeFieldInfo::fillFieldPlacementInfo(int topK, bool hasBasicFields) +{ + int basicFieldSize = 0; + if (hasBasicFields) + { + basicFieldSize = 
topK * sizeof(int) + (expertScales != nullptr ? topK * sizeof(float) : 0); + // align to 16 bytes + basicFieldSize = (basicFieldSize + MoeCommFieldInfo::BYTES_PER_16B_BLOCK - 1) + / MoeCommFieldInfo::BYTES_PER_16B_BLOCK * MoeCommFieldInfo::BYTES_PER_16B_BLOCK; + } + int offset = basicFieldSize; + int unalignedFieldIndex = 0; + for (int i = 0; i < fieldCount; i++) + { + fieldsInfo[i].compact16BOffset = offset / MoeCommFieldInfo::BYTES_PER_16B_BLOCK; + offset += fieldsInfo[i].getFieldCompactSize(); + fieldsInfo[i].unalignedFieldIndex = unalignedFieldIndex; + if (fieldsInfo[i].alignedUnitBit < 4) + { + unalignedFieldIndex++; + } + } + for (int i = fieldCount; i < MOE_COMM_FIELD_MAX_COUNT; i++) + { + fieldsInfo[i].setUnused(); + } +} + +void FusedMoeWorkspace::initializeLocalWorkspace(FusedMoeWorldInfo const& worldInfo) +{ + int epSize = worldInfo.epInfo.epSize; + int epRank = worldInfo.epInfo.epRank; + size_t fifoSize = static_cast(FusedMoeCommunicator::FIFO_TOTAL_BYTES) * epSize * channelCount; + size_t senderSideInfoSize = sizeof(SenderSideFifoInfo) * epSize * channelCount; + size_t receiverSideInfoSize = sizeof(ReceiverSideFifoInfo) * epSize * channelCount; + uint64_t* localWorkspacePtr = workspacePtr + epRank * rankStrideInU64; + TLLM_CU_CHECK(cuMemsetD32(reinterpret_cast(localWorkspacePtr), FusedMoeProto::INITIALIZED_VALUE, + fifoSize / sizeof(uint32_t))); + TLLM_CUDA_CHECK(cudaMemset( + reinterpret_cast(localWorkspacePtr) + fifoSize, 0, senderSideInfoSize + receiverSideInfoSize)); +} + +void moeAllToAll(FusedMoeCommKernelParam params, FusedMoeWorkspace workspace, cudaStream_t stream) +{ + bool hasBasicFields = params.sendFieldInfo.tokenSelectedSlots != nullptr; + int warpSendShmSize = params.sendCommMeta.getSingleShmSize(); + int warpRecvShmSize = params.recvCommMeta.getSingleShmSize(); + int warpShmSize = warpSendShmSize; + int epSize = params.worldInfo.epInfo.epSize; + TLLM_CHECK_WITH_INFO(warpSendShmSize == warpRecvShmSize, "warpSendShmSize(%d) not same as warpRecvShmSize(%d)", + warpSendShmSize, warpRecvShmSize); + int maxGroupCountPerCta = std::min(params.worldInfo.epInfo.epSize, FusedMoeCommunicator::MAX_GROUP_COUNT_PER_BLOCK); + static int maxDynamicShmSize = fused_moe_impl::computeMoeAlltoallMaxDynamicSharedMemorySize(); + int groupCountPerCta = std::min(maxGroupCountPerCta, maxDynamicShmSize / warpShmSize); + + int maxFieldCount = std::max(params.sendFieldInfo.fieldCount, params.recvFieldInfo.fieldCount); + auto getFunc = [](int fieldCount) + { + switch (fieldCount) + { + case 1: return fused_moe_impl::moeAllToAllKernel<1>; + case 2: return fused_moe_impl::moeAllToAllKernel<2>; + case 3: return fused_moe_impl::moeAllToAllKernel<3>; + case 4: return fused_moe_impl::moeAllToAllKernel<4>; + case 5: return fused_moe_impl::moeAllToAllKernel<5>; + case 6: return fused_moe_impl::moeAllToAllKernel<6>; + case 7: return fused_moe_impl::moeAllToAllKernel<7>; + case 8: return fused_moe_impl::moeAllToAllKernel<8>; + default: return fused_moe_impl::moeAllToAllKernel<8>; + } + return fused_moe_impl::moeAllToAllKernel<8>; + }; + auto* kernelFn = getFunc(maxFieldCount); + + if (groupCountPerCta * warpShmSize > 48 * 1024) + { + TLLM_CUDA_CHECK(cudaFuncSetAttribute( + kernelFn, cudaFuncAttributeMaxDynamicSharedMemorySize, groupCountPerCta * warpShmSize)); + } + for (; groupCountPerCta > 0; groupCountPerCta--) + { + int dynamicShmSize = groupCountPerCta * warpShmSize; + int numBlocks = 0; + if (cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &numBlocks, kernelFn, WARP_SIZE * groupCountPerCta, 
dynamicShmSize) + != cudaSuccess) + { + continue; + } + if (numBlocks >= 1) + { + break; + } + } + TLLM_CHECK_WITH_INFO( + groupCountPerCta >= 1, "computed groupCount=%d, warpShmSize=%d", groupCountPerCta, warpShmSize); + int ctaPerChannel = (epSize + groupCountPerCta - 1) / groupCountPerCta; + groupCountPerCta = (epSize + ctaPerChannel - 1) / ctaPerChannel; + int totalDynamicShmSize = warpShmSize * groupCountPerCta; + + dim3 block = FusedMoeCommunicator::getLaunchBlockDim(groupCountPerCta); + dim3 grid = FusedMoeCommunicator::getLaunchGridDim(params.worldInfo.epInfo.epSize, groupCountPerCta); + kernelFn<<>>(params, workspace, hasBasicFields); + TLLM_CUDA_CHECK(cudaGetLastError()); +} + +int FusedMoeCommunicator::maxSmCount = -1; +bool FusedMoeCommunicator::maxSmCountUsed = false; + +void setMaxUsableSmCount(int smCount) +{ + FusedMoeCommunicator::setMaxUsableSmCount(smCount); +} + +size_t getFusedMoeCommWorkspaceSize(int epSize) +{ + int channelCount = FusedMoeCommunicator::getMoeCommChannelCount(epSize); + size_t workspaceSize = FusedMoeWorkspace::computeWorkspaceSizePreRank(epSize, channelCount); + return workspaceSize; +} + +void constructWorkspace(FusedMoeWorkspace* workspace, uint64_t* workspacePtr, size_t rankStrideInU64, int epSize) +{ + workspace->workspacePtr = workspacePtr; + workspace->rankStrideInU64 = rankStrideInU64; + workspace->channelCount = FusedMoeCommunicator::getMoeCommChannelCount(epSize); +} + +void initializeFusedMoeLocalWorkspace(FusedMoeWorkspace* workspace, FusedMoeWorldInfo const& worldInfo) +{ + workspace->initializeLocalWorkspace(worldInfo); +} + +namespace fused_moe_comm_tests +{ + +__global__ void g2sKernel(FusedMoeFieldInfo allFieldInfo, MoeExpertParallelInfo expertParallelInfo, + MoeSingleCommMeta singleCommMeta, int tokenCount, int* shmDump, bool hasBasicFields) +{ + __shared__ uint64_t allWarpSmemBar[32]; + extern __shared__ int4 allWarpShm[]; + int laneId = threadIdx.x % WARP_SIZE; + int warpId = threadIdx.x / WARP_SIZE; + int warpCount = blockDim.x / WARP_SIZE; + int tokenIndex = warpId + blockIdx.x * warpCount; + if (tokenIndex >= tokenCount) + { + return; + } + + int singleShmSize = singleCommMeta.singleUncompactAlignedSize; + + tensorrt_llm::kernels::fused_moe_impl::initSmemBar(&allWarpSmemBar[warpId], laneId); + uint32_t phaseParity = 0; + + uint8_t* sharedMemoryBase = reinterpret_cast(allWarpShm) + singleShmSize * warpId; + + if (hasBasicFields) + { + tensorrt_llm::kernels::fused_moe_impl::g2sAllFields( + allFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId, &allWarpSmemBar[warpId]); + tensorrt_llm::kernels::fused_moe_impl::waitG2SAllFields(&allWarpSmemBar[warpId], &phaseParity); + } + else + { + tensorrt_llm::kernels::fused_moe_impl::g2sAllFields( + allFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId, &allWarpSmemBar[warpId]); + tensorrt_llm::kernels::fused_moe_impl::waitG2SAllFields(&allWarpSmemBar[warpId], &phaseParity); + } + + for (int offset = laneId; offset < singleShmSize / sizeof(int); offset += WARP_SIZE) + { + shmDump[tokenIndex * singleShmSize / sizeof(int) + offset] = reinterpret_cast(sharedMemoryBase)[offset]; + } +} + +void launchSingleG2S(FusedMoeFieldInfo const& sendFieldInfo, MoeExpertParallelInfo const& expertParallelInfo, + int tokenCount, int* shmDump, int warpsPerBlock, bool hasBasicFields, cudaStream_t stream) +{ + int warpShmSize = sendFieldInfo.computeSingleUncompactSize( + expertParallelInfo.topK, sendFieldInfo.expertScales != nullptr, hasBasicFields); + dim3 
blockDim(WARP_SIZE * warpsPerBlock, 1, 1); + dim3 gridDim((tokenCount + warpsPerBlock - 1) / warpsPerBlock, 1, 1); + MoeSingleCommMeta singleCommMeta; + sendFieldInfo.fillMetaInfo( + &singleCommMeta, expertParallelInfo.topK, sendFieldInfo.expertScales != nullptr, hasBasicFields); + TLLM_CUDA_CHECK( + cudaFuncSetAttribute(g2sKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, warpShmSize * warpsPerBlock)); + g2sKernel<<>>( + sendFieldInfo, expertParallelInfo, singleCommMeta, tokenCount, shmDump, hasBasicFields); + TLLM_CUDA_CHECK(cudaGetLastError()); +} + +__global__ void s2gKernel(FusedMoeFieldInfo recvFieldInfo, MoeExpertParallelInfo expertParallelInfo, + MoeSingleCommMeta singleCommMeta, int tokenCount, int* shmPreload, bool hasBasicFields) +{ + extern __shared__ int4 allWarpShm[]; + int laneId = threadIdx.x % WARP_SIZE; + int warpId = threadIdx.x / WARP_SIZE; + int warpCount = blockDim.x / WARP_SIZE; + int tokenIndex = warpId + blockIdx.x * warpCount; + if (tokenIndex >= tokenCount) + { + return; + } + int singleShmSize = singleCommMeta.singleUncompactAlignedSize; + uint8_t* sharedMemoryBase = reinterpret_cast(allWarpShm) + singleShmSize * warpId; + + for (int offset = laneId; offset < singleShmSize / sizeof(int); offset += WARP_SIZE) + { + reinterpret_cast(sharedMemoryBase)[offset] + = shmPreload[tokenIndex * singleShmSize / sizeof(int) + offset]; + } + __syncwarp(); + + if (hasBasicFields) + { + tensorrt_llm::kernels::fused_moe_impl::s2gAllFields( + recvFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId); + } + else + { + tensorrt_llm::kernels::fused_moe_impl::s2gAllFields( + recvFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId); + } + + tensorrt_llm::kernels::fused_moe_impl::waitS2GBulkRead(); +} + +void launchSingleS2G(FusedMoeFieldInfo const& recvFieldInfo, MoeExpertParallelInfo const& expertParallelInfo, + int tokenCount, int* shmPreload, int warpsPerBlock, bool hasBasicFields, cudaStream_t stream) +{ + int warpShmSize = recvFieldInfo.computeSingleUncompactSize( + expertParallelInfo.topK, recvFieldInfo.expertScales != nullptr, hasBasicFields); + dim3 blockDim(WARP_SIZE * warpsPerBlock, 1, 1); + dim3 gridDim((tokenCount + warpsPerBlock - 1) / warpsPerBlock, 1, 1); + MoeSingleCommMeta singleCommMeta; + recvFieldInfo.fillMetaInfo( + &singleCommMeta, expertParallelInfo.topK, recvFieldInfo.expertScales != nullptr, hasBasicFields); + TLLM_CUDA_CHECK( + cudaFuncSetAttribute(s2gKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, warpShmSize * warpsPerBlock)); + s2gKernel<<>>( + recvFieldInfo, expertParallelInfo, singleCommMeta, tokenCount, shmPreload, hasBasicFields); + TLLM_CUDA_CHECK(cudaGetLastError()); +} + +__global__ void loopbackKernel(FusedMoeFieldInfo sendFieldInfo, FusedMoeFieldInfo recvFieldInfo, + MoeExpertParallelInfo expertParallelInfo, MoeSingleCommMeta sendCommMeta, MoeSingleCommMeta recvCommMeta, + int* recvIndexMapping, int tokenCount, bool hasBasicFields) +{ + __shared__ uint64_t allWarpSmemBar[32]; + extern __shared__ int4 allWarpShm[]; + int laneId = threadIdx.x % WARP_SIZE; + int warpId = threadIdx.x / WARP_SIZE; + int warpCount = blockDim.x / WARP_SIZE; + int tokenIndex = warpId + blockIdx.x * warpCount; + if (tokenIndex >= tokenCount) + { + return; + } + + int recvTokenIndex = recvIndexMapping[tokenIndex]; + + tensorrt_llm::kernels::fused_moe_impl::initSmemBar(&allWarpSmemBar[warpId], laneId); + uint32_t phaseParity = 0; + + int singleShmSize = sendCommMeta.getSingleShmSize(); + + uint8_t* sharedMemoryBase = 
reinterpret_cast(allWarpShm) + singleShmSize * warpId; + + if (hasBasicFields) + { + tensorrt_llm::kernels::fused_moe_impl::g2sAllFields( + sendFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId, &allWarpSmemBar[warpId]); + } + else + { + tensorrt_llm::kernels::fused_moe_impl::g2sAllFields( + sendFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId, &allWarpSmemBar[warpId]); + } + + if (hasBasicFields) + { + tensorrt_llm::kernels::fused_moe_impl::waitG2SAllFields(&allWarpSmemBar[warpId], &phaseParity); + } + else + { + tensorrt_llm::kernels::fused_moe_impl::waitG2SAllFields(&allWarpSmemBar[warpId], &phaseParity); + } + + tensorrt_llm::kernels::fused_moe_impl::packAllFields(sendFieldInfo, tokenIndex, sharedMemoryBase, laneId); + + tokenIndex = recvTokenIndex; // switch to recvTokenIndex; + + tensorrt_llm::kernels::fused_moe_impl::unpackAllFields(recvFieldInfo, tokenIndex, sharedMemoryBase, laneId); + + if (hasBasicFields) + { + tensorrt_llm::kernels::fused_moe_impl::s2gAllFields( + recvFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId); + } + else + { + tensorrt_llm::kernels::fused_moe_impl::s2gAllFields( + recvFieldInfo, expertParallelInfo, tokenIndex, sharedMemoryBase, warpId, laneId); + } + + cp_async_bulk_wait_group_read<0>(); + __syncwarp(); +} + +// G2S -> Pack -> Unpack -> S2G +void launchLoopback(FusedMoeFieldInfo const& sendFieldInfo, FusedMoeFieldInfo const& recvFieldInfo, + MoeExpertParallelInfo const& expertParallelInfo, int* recvIndexMapping, int tokenCount, int warpsPerBlock, + bool hasBasicFields, cudaStream_t stream) +{ + MoeSingleCommMeta sendCommMeta, recvCommMeta; + sendFieldInfo.fillMetaInfo( + &sendCommMeta, expertParallelInfo.topK, sendFieldInfo.expertScales != nullptr, hasBasicFields); + recvFieldInfo.fillMetaInfo( + &recvCommMeta, expertParallelInfo.topK, recvFieldInfo.expertScales != nullptr, hasBasicFields); + int warpSendShmSize = sendCommMeta.getSingleShmSize(); + int warpRecvShmSize = recvCommMeta.getSingleShmSize(); + int warpShmSize = warpSendShmSize; + TLLM_CHECK_WITH_INFO(warpSendShmSize == warpRecvShmSize, "warpSendShmSize(%d) not same as warpRecvShmSize(%d)", + warpSendShmSize, warpRecvShmSize); + dim3 blockDim(WARP_SIZE * warpsPerBlock, 1, 1); + dim3 gridDim((tokenCount + warpsPerBlock - 1) / warpsPerBlock, 1, 1); + TLLM_CUDA_CHECK( + cudaFuncSetAttribute(loopbackKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, warpShmSize * warpsPerBlock)); + loopbackKernel<<>>(sendFieldInfo, recvFieldInfo, + expertParallelInfo, sendCommMeta, recvCommMeta, recvIndexMapping, tokenCount, hasBasicFields); + TLLM_CUDA_CHECK(cudaGetLastError()); +} + +template +__global__ void localFifoSendRecvKernel(FusedMoeFieldInfo sendFieldInfo, FusedMoeFieldInfo recvFieldInfo, + MoeExpertParallelInfo expertParallelInfo, MoeSingleCommMeta sendCommMeta, MoeSingleCommMeta recvCommMeta, + FusedMoeWorkspace fusedMoeWorkspace, int* sendIndexMapping, int* recvIndexMapping, int tokenCount) +{ + __shared__ uint64_t allWarpSmemBar[32]; + extern __shared__ int4 allWarpShm[]; + + FusedMoeWorldInfo worldInfo; + worldInfo.epInfo.epRank = 0; + worldInfo.epInfo.epSize = 1; + + int warpId = threadIdx.x / WARP_SIZE; + int warpCount = blockDim.x / WARP_SIZE; + + FusedMoePairInfo pairInfo; + pairInfo.senderRank = 0; + pairInfo.receiverRank = 0; + pairInfo.channel = blockIdx.z * warpCount + warpId; + pairInfo.runChannelCount = gridDim.z * warpCount; + + if (blockIdx.y == 0) + { + 
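+        // launchLocalFifoSendRecv launches this kernel with gridDim.y == 2: blocks with
+        // blockIdx.y == 0 act as the sender side of the local FIFO and blocks with blockIdx.y == 1
+        // act as the receiver side, so the full FIFO protocol is exercised on a single rank.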
tensorrt_llm::kernels::fused_moe_impl::SingleChannelCommunicator + senderComm(sendFieldInfo, expertParallelInfo, sendCommMeta, fusedMoeWorkspace, worldInfo, pairInfo, + &allWarpSmemBar[warpId], + reinterpret_cast(&allWarpShm[0]) + warpId * sendCommMeta.getSingleShmSize()); + senderComm.doSend(tokenCount, sendIndexMapping); + } + else + { + tensorrt_llm::kernels::fused_moe_impl::SingleChannelCommunicator + recverComm(recvFieldInfo, expertParallelInfo, recvCommMeta, fusedMoeWorkspace, worldInfo, pairInfo, + &allWarpSmemBar[warpId], + reinterpret_cast(&allWarpShm[0]) + warpId * recvCommMeta.getSingleShmSize()); + recverComm.doReceive(tokenCount, recvIndexMapping); + } +} + +void launchLocalFifoSendRecv(FusedMoeFieldInfo const& sendFieldInfo, FusedMoeFieldInfo const& recvFieldInfo, + MoeExpertParallelInfo const& expertParallelInfo, int* sendIndexMapping, int* recvIndexMapping, + FusedMoeWorkspace fusedMoeWorkspace, int tokenCount, int warpsPerBlock, int blockChannelCount, bool hasBasicFields, + cudaStream_t stream) +{ + MoeSingleCommMeta sendCommMeta, recvCommMeta; + sendFieldInfo.fillMetaInfo( + &sendCommMeta, expertParallelInfo.topK, sendFieldInfo.expertScales != nullptr, hasBasicFields); + recvFieldInfo.fillMetaInfo( + &recvCommMeta, expertParallelInfo.topK, recvFieldInfo.expertScales != nullptr, hasBasicFields); + int warpSendShmSize = sendCommMeta.getSingleShmSize(); + int warpRecvShmSize = recvCommMeta.getSingleShmSize(); + int warpShmSize = warpSendShmSize; + TLLM_CHECK_WITH_INFO(warpSendShmSize == warpRecvShmSize, "warpSendShmSize(%d) not same as warpRecvShmSize(%d)", + warpSendShmSize, warpRecvShmSize); + dim3 blockDim(WARP_SIZE * warpsPerBlock, 1, 1); + dim3 gridDim(1, 2, blockChannelCount); + auto* kernelFn = localFifoSendRecvKernel<>; + if (hasBasicFields) + { + kernelFn = localFifoSendRecvKernel; + } + else + { + kernelFn = localFifoSendRecvKernel; + } + TLLM_CUDA_CHECK( + cudaFuncSetAttribute(kernelFn, cudaFuncAttributeMaxDynamicSharedMemorySize, warpShmSize * warpsPerBlock)); + kernelFn<<>>(sendFieldInfo, recvFieldInfo, + expertParallelInfo, sendCommMeta, recvCommMeta, fusedMoeWorkspace, sendIndexMapping, recvIndexMapping, + tokenCount); + TLLM_CUDA_CHECK(cudaGetLastError()); +} + +} // namespace fused_moe_comm_tests + +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/fusedMoeCommKernels.h b/cpp/tensorrt_llm/kernels/fusedMoeCommKernels.h new file mode 100644 index 00000000000..91f0b92c6d7 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/fusedMoeCommKernels.h @@ -0,0 +1,562 @@ +/* + * Copyright (c) 2019-2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include + +#include + +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/kernels/moeCommKernelsCommon.h" + +namespace tensorrt_llm +{ +namespace kernels +{ + +struct ALIGN_256 SenderSideFifoInfo +{ + volatile uint64_t head; // write position + volatile uint64_t tail; // read position +}; + +struct ALIGN_256 ReceiverSideFifoInfo +{ + volatile uint64_t head; // write position do we use this? + volatile uint64_t tail; // read position +}; + +// struct holding Send/Recv data pointer and its displacement information. +struct SendRecvIndices +{ + int const* rankCountCumSum; // length = epSize + int* rankLocalIndices; // length = rankCountCumSum[epRank] - rankCountCumSum[epRank - 1] if epRank > 0 else + // rankCountCumSum[epRank] + +#ifdef __CUDACC__ + __inline__ __device__ int getCount(int rank) const + { + return rank == 0 ? rankCountCumSum[rank] : rankCountCumSum[rank] - rankCountCumSum[rank - 1]; + } + + __inline__ __device__ int getRankStart(int rank) const + { + return rank == 0 ? 0 : rankCountCumSum[rank - 1]; + } + + __inline__ __device__ int* getGroupStart(int rank, int& tokenCount) const + { + tokenCount = getCount(rank); + int rankStart = getRankStart(rank); + return rankLocalIndices + rankStart; + } +#endif +}; + +struct MoeCommFieldInfo +{ + uint8_t* dataPtrBase; + uint8_t alignedUnitBit; // 0, 1, 2, 3, 4 (for 1, 2, 4, 8, 16 Bytes), smallest aligned unit. + uint16_t alignedUnitCount; // data count in aligned unit + uint16_t alignedUnitStride; // data stride in aligned unit + + uint8_t unalignedFieldIndex; // the index of unaligned Field, no decrease with field index + uint16_t compact16BOffset; // aligned to 16 Bytes, offset is count of 16 Byte + + static constexpr uint64_t kAlign16BytePtrMask = (1ULL << 4) - 1; + static constexpr uint32_t kAligned16BMask = (1 << 4) - 1; + + // Constants for memory alignment and access + static constexpr int BYTES_PER_128B_BLOCK = 128; + static constexpr int INTS_PER_128B_BLOCK = BYTES_PER_128B_BLOCK / sizeof(int); + static constexpr int UINT64_PER_128B_BLOCK = BYTES_PER_128B_BLOCK / sizeof(uint64_t); + static constexpr int BYTES_PER_16B_BLOCK = 16; + // Will pad one 16 byte for each unaligned field, then head and tail 16 byte might not be aligned + + // Fill single field info, the fields that need global info is not filled here. 
+ __host__ void fillFieldInfo(uint8_t* dataPtr, size_t elementSize, int vectorSize, int stride); + + __host__ void setUnused() + { + dataPtrBase = nullptr; + alignedUnitBit = 4; + alignedUnitCount = 0; + alignedUnitStride = 0; + unalignedFieldIndex = 0; + compact16BOffset = 0; + } + + template + __host__ void fillFieldInfo(T* dataPtr, int vectorSize, int stride) + { + size_t elementSize = sizeof(T); + fillFieldInfo(reinterpret_cast(dataPtr), elementSize, vectorSize, stride); + } + + __device__ __host__ __forceinline__ int getFieldUncompactSize() const + { + int alignedUnitBytes = 1 << alignedUnitBit; + int currentFieldSize = alignedUnitCount * alignedUnitBytes; + if (alignedUnitBytes != 16) + { + constexpr int alignedUnitBytes = BYTES_PER_16B_BLOCK; + currentFieldSize = currentFieldSize / alignedUnitBytes * alignedUnitBytes; + currentFieldSize += alignedUnitBytes * 2; + } + return currentFieldSize; + } + + __device__ __host__ __forceinline__ int getFieldCompactSize() const + { + int alignedUnitBytes = 1 << alignedUnitBit; + int currentFieldSize = alignedUnitCount * alignedUnitBytes; + // Align to 16 bytes for compact size + return (currentFieldSize + BYTES_PER_16B_BLOCK - 1) / BYTES_PER_16B_BLOCK * BYTES_PER_16B_BLOCK; + } + + __device__ __forceinline__ int getCompactShmOffset() const + { + return compact16BOffset * BYTES_PER_16B_BLOCK; + } + + __device__ __forceinline__ int getUncompactShmOffset() const + { + // each unaligned field need 16 byte head and 16 byte tail + return compact16BOffset * BYTES_PER_16B_BLOCK + unalignedFieldIndex * BYTES_PER_16B_BLOCK; + } + + __device__ __forceinline__ int getMemmoveOffsets(int index) const + { + int alignedBytes = 1 << alignedUnitBit; + uint8_t* dataPtr = dataPtrBase + index * alignedBytes * alignedUnitStride; + int offset = reinterpret_cast(dataPtr) & kAlign16BytePtrMask; + return offset + unalignedFieldIndex * BYTES_PER_16B_BLOCK; + } + + __device__ __forceinline__ uint8_t* getRawPtr(int index, int* rawSize) const + { + int alignedBytes = 1 << alignedUnitBit; + uint8_t* dataPtr = dataPtrBase + static_cast(index) * alignedBytes * alignedUnitStride; + if (rawSize != nullptr) + { + *rawSize = alignedUnitCount * alignedBytes; + } + return dataPtr; + } + + __device__ __forceinline__ uint8_t* get16BAlignedLoadCopyRange(int index, int* copyByteCount) const + { + int rawSize; + uint8_t* rawDataPtr = getRawPtr(index, &rawSize); + uint8_t* rawEndPtr = rawDataPtr + rawSize; + uint8_t* alignedDataPtr + = reinterpret_cast(reinterpret_cast(rawDataPtr) & (~kAlign16BytePtrMask)); + uint32_t copySize = rawEndPtr - alignedDataPtr; + *copyByteCount + = (copySize & kAligned16BMask) != 0 ? 
(copySize & (~kAligned16BMask)) + BYTES_PER_16B_BLOCK : copySize; + return alignedDataPtr; + } + + __device__ __forceinline__ uint8_t* get16BAlignedStoreCopyRange( + int index, int* copyByteCount, int laneId, int* headTailShmIdx, int* headTailGlobalIdx) const + { + int rawSize; + uint8_t* rawDataPtr = getRawPtr(index, &rawSize); + uint8_t* rawEndPtr = rawDataPtr + rawSize; + int offset = reinterpret_cast(rawDataPtr) & kAlign16BytePtrMask; + uint8_t* alignedDataPtr + = reinterpret_cast(reinterpret_cast(rawDataPtr) + BYTES_PER_16B_BLOCK - offset); + uint8_t* alignedEndPtr + = reinterpret_cast(reinterpret_cast(rawEndPtr) & (~kAlign16BytePtrMask)); + int alignedCopyBytes = alignedEndPtr - alignedDataPtr; + if (alignedCopyBytes < 0) + { + alignedCopyBytes = 0; + } + *copyByteCount = alignedCopyBytes; + + if (laneId < BYTES_PER_16B_BLOCK) + { + *headTailShmIdx = laneId; + } + else + { + *headTailShmIdx = laneId + alignedCopyBytes; + } + *headTailGlobalIdx = *headTailShmIdx - offset; + if (*headTailGlobalIdx < 0 || *headTailGlobalIdx >= rawSize) + { + *headTailGlobalIdx = -1; + *headTailShmIdx = -1; + } + return alignedDataPtr; + } +}; + +// Maximum number of field supported, except tokenSelectedExpert and expertScales +static constexpr int MOE_COMM_FIELD_MAX_COUNT = 8; + +struct MoeSingleCommMeta +{ + int singleTransferAlignedSize; // transfer size aligned to 128 bytes. + int singleCompactAlignedSize; // compact buffer is always aligned to 128 bytes + int singleUncompactAlignedSize; // uncompact shared memory size, aligned to 128 bytes, might be larger than compact + // buffer if unaligned field exist. + + // TODO: Do we need reduce shared memory usage, make it able to be smaller, and enable multiple wave? + + __device__ __host__ __forceinline__ int getTransfer128ByteCount() const + { + return singleTransferAlignedSize / MoeCommFieldInfo::BYTES_PER_128B_BLOCK; + } + + __device__ __host__ __forceinline__ int getCompactData128ByteCount() const + { + return singleCompactAlignedSize / MoeCommFieldInfo::BYTES_PER_128B_BLOCK; + } + + __device__ __host__ __forceinline__ int getSingleShmSize() const + { + return std::max(singleUncompactAlignedSize, singleTransferAlignedSize); + } +}; + +struct FusedMoeWorldInfo +{ + MoeEpWorldInfo epInfo; +}; + +struct FusedMoePairInfo +{ + int senderRank; + int receiverRank; + int channel; + int runChannelCount; +}; + +class FusedMoeCommunicator +{ +public: + static constexpr int FIFO_DEPTH = 4; + static constexpr int FIFO_ENTRY_BYTES = 256 * 1024; + static constexpr int FIFO_ENTRY_128_BYTE_COUNT = FIFO_ENTRY_BYTES / 128; + static constexpr int FIFO_TOTAL_BYTES = FIFO_ENTRY_BYTES * FIFO_DEPTH; + static constexpr int FIFO_TOTAL_U64 = FIFO_TOTAL_BYTES / sizeof(uint64_t); + static constexpr int MAX_GROUP_COUNT_PER_BLOCK = 8; + + static constexpr int WARP_SIZE = 32; + + static int maxSmCount; + static bool maxSmCountUsed; + + static void setMaxUsableSmCount(int maxUsableSmCount) + { + TLLM_CHECK_WITH_INFO( + FusedMoeCommunicator::maxSmCountUsed == false, "setMaxUsableSmCount can be called only before it is used"); + int smCount = tensorrt_llm::common::getMultiProcessorCount(); + if (maxUsableSmCount > smCount) + { + TLLM_LOG_WARNING("setMaxUsableSmCount, maxUsableSmCount=%d, larger than smCount=%d, using smCount instead", + maxUsableSmCount, smCount); + maxUsableSmCount = smCount; + } + FusedMoeCommunicator::maxSmCount = maxUsableSmCount; + } + + static int getMaxUsableSmCount() + { + FusedMoeCommunicator::maxSmCountUsed = true; + if (FusedMoeCommunicator::maxSmCount == 
-1) + { + int smCount = tensorrt_llm::common::getMultiProcessorCount(); + FusedMoeCommunicator::maxSmCount = smCount; + } + return FusedMoeCommunicator::maxSmCount; + } + + static int computeMoeCommChannelCount(int epSize) + { + int smCount = getMaxUsableSmCount(); + int blockCountPerChannel = (epSize + MAX_GROUP_COUNT_PER_BLOCK - 1) / MAX_GROUP_COUNT_PER_BLOCK; + blockCountPerChannel *= 2; // for send and recv + TLLM_CHECK_WITH_INFO( + blockCountPerChannel <= smCount, "GPU should support at least one channel, usableSmCount=%d", smCount); + int preferredChannel = smCount / 2 / blockCountPerChannel; // use half of the SMs for communication + int channelCount = std::max(preferredChannel, 1); // at least one channel + return channelCount; + } + + static int getMoeCommChannelCount(int epSize) + { + static std::map<int, int> channelCountMap{}; + auto iter = channelCountMap.find(epSize); + if (iter == channelCountMap.end()) + { + auto channelCount = FusedMoeCommunicator::computeMoeCommChannelCount(epSize); + channelCountMap[epSize] = channelCount; + return channelCount; + } + return iter->second; + } + + static dim3 getLaunchBlockDim(int groupCountPerCta) + { + return dim3(WARP_SIZE, groupCountPerCta); + } + + static dim3 getLaunchGridDim(int epSize, int groupCountPerCta) + { + int maxChannelCount = FusedMoeCommunicator::getMoeCommChannelCount(epSize); + int targetCtaCount = (epSize + MAX_GROUP_COUNT_PER_BLOCK - 1) / MAX_GROUP_COUNT_PER_BLOCK * maxChannelCount * 2; + int ctaPerChannel = (epSize + groupCountPerCta - 1) / groupCountPerCta; + int ctaLimitedChannelCount = targetCtaCount / 2 / ctaPerChannel; + ctaLimitedChannelCount = std::max(1, ctaLimitedChannelCount); + int channelCount = std::min(ctaLimitedChannelCount, maxChannelCount); + return dim3(ctaPerChannel, channelCount, 2); + } +}; + +size_t getFusedMoeCommWorkspaceSize(int epSize); + +struct FusedMoeFieldInfo +{ + int8_t isBasicInterleaved; // are tokenSelectedSlots and expertScales interleaved? + int32_t* tokenSelectedSlots; + float* expertScales; // can be nullptr if no scales are used (all 1.0); if so, isBasicInterleaved should be 0 + int fieldCount; + MoeCommFieldInfo fieldsInfo[MOE_COMM_FIELD_MAX_COUNT]; + + __host__ int computeSingleCompactSize(int topK, bool hasScales, bool hasBasicFields) const + { + int basicFieldSize = 0; + if (hasBasicFields) + { + basicFieldSize = topK * sizeof(int) + (hasScales ? topK * sizeof(float) : 0); + // align to 16 bytes + basicFieldSize = (basicFieldSize + MoeCommFieldInfo::BYTES_PER_16B_BLOCK - 1) + / MoeCommFieldInfo::BYTES_PER_16B_BLOCK * MoeCommFieldInfo::BYTES_PER_16B_BLOCK; + } + int otherFieldSize = 0; + for (int i = 0; i < fieldCount; i++) + { + MoeCommFieldInfo const& fieldInfo = fieldsInfo[i]; + otherFieldSize += fieldInfo.getFieldCompactSize(); + } + int totalSize = basicFieldSize + otherFieldSize; + constexpr int totalSizeAlignment = MoeCommFieldInfo::BYTES_PER_128B_BLOCK; + totalSize = (totalSize + totalSizeAlignment - 1) / totalSizeAlignment * totalSizeAlignment; + return totalSize; + } + + __host__ int computeSingleUncompactSize(int topK, bool hasScales, bool hasBasicFields) const + { + int basicFieldSize = 0; + if (hasBasicFields) + { + basicFieldSize = topK * sizeof(int) + (hasScales ?
topK * sizeof(float) : 0); + // align to 16 bytes + basicFieldSize = (basicFieldSize + MoeCommFieldInfo::BYTES_PER_16B_BLOCK - 1) + / MoeCommFieldInfo::BYTES_PER_16B_BLOCK * MoeCommFieldInfo::BYTES_PER_16B_BLOCK; + } + int otherFieldSize = 0; + for (int i = 0; i < fieldCount; i++) + { + MoeCommFieldInfo const& fieldInfo = fieldsInfo[i]; + otherFieldSize += fieldInfo.getFieldUncompactSize(); + } + int totalSize = basicFieldSize + otherFieldSize; + constexpr int totalSizeAlignment = MoeCommFieldInfo::BYTES_PER_128B_BLOCK; + totalSize = (totalSize + totalSizeAlignment - 1) / totalSizeAlignment * totalSizeAlignment; + return totalSize; + } + + template + __device__ __forceinline__ T* getBasicFieldPtr(int tokenIndex, int selectedIndex, int topK) const + { + T* fieldPtr = nullptr; + fieldPtr = IS_SLOTS ? reinterpret_cast(tokenSelectedSlots) : reinterpret_cast(expertScales); + if (fieldPtr == nullptr || selectedIndex >= topK) + { + return nullptr; + } + int tokenStride = isBasicInterleaved ? topK * 2 : topK; + int elementStride = isBasicInterleaved ? 2 : 1; + return fieldPtr + tokenIndex * tokenStride + selectedIndex * elementStride; + } + + __device__ __forceinline__ int* getTokenSelectedSlotsPtr(int tokenIndex, int selectedIndex, int topK) const + { + return getBasicFieldPtr(tokenIndex, selectedIndex, topK); + } + + __device__ __forceinline__ float* getScalePtr(int tokenIndex, int selectedIndex, int topK) const + { + return getBasicFieldPtr(tokenIndex, selectedIndex, topK); + } + + void fillMetaInfo(MoeSingleCommMeta* singleCommMeta, int topK, bool hasScales, bool hasBasicFields) const; + + void fillFieldPlacementInfo(int topK, bool hasBasicFields); +}; + +struct FusedMoeCommKernelParam +{ + FusedMoeWorldInfo worldInfo; + MoeExpertParallelInfo expertParallelInfo; // expertCount inside should be slotCount if using redundant experts. + MoeSingleCommMeta sendCommMeta; + MoeSingleCommMeta recvCommMeta; + SendRecvIndices sendIndices; + SendRecvIndices recvIndices; + FusedMoeFieldInfo sendFieldInfo; + FusedMoeFieldInfo recvFieldInfo; +}; + +/* + * Workspace Layout: + * Ri: Rank i + * N: Number of GPUs, e.g. EpSize or WorldSize, n = N - 1 + * Ci: Channel i + * M: Number of Channels, m = M - 1 + * MMr: Memory Mapped from Rank r, physically located at rank r, and mapped to all ranks. + * + * Whole workspace memory space: + * --------------------------------------------------------------------------------------------------- + * |<-- MM0 --> |<-- MM1 --> |<-- MM2 --> | ...... |<-- MMn --> | + * ^ ^ ^ ^ ^ ^ + * 0 rankStrideInU64 2*rankStrideInU64 3*rankStrideInU64 n*rankStrideInU64 N*rankStrideInU64 + * + * For each MMr, the layout is: + * ------------------------------------------------------------------------------------------------- + * |<--- FIFO memory --->|<--- SenderSideFifoInfo memory --->|<--- ReceiverSideFifoInfo memory --->| + * ------------------------------------------------------------------------------------------------- + * + * For each FIFO memory, it is physically placed at the receiver rank. + * To find the FIFO whose receiver is rank r, we need to find that in the FIFO memory of MMr. + * The layout of FIFO memory of each MMR is(here rank is the sender rank): + * ------------------------------------------------------------------------------------------------- + * | R0C0 | R0C1 | .... | R0Cm | R1C0 | R1C1 | .... | R1Cm | .... .... | RnC0 | RnC1 | .... 
| RnCm | + * |<- Channels for Rank 0 ->|<- Channels for Rank 1 ->| |<- Channels for Rank n ->| + * ------------------------------------------------------------------------------------------------- + * Each R*C* has a length of FIFO_TOTAL_U64 in uint64_t, which is internally divided into FIFO_DEPTH entries of + * size FIFO_ENTRY_BYTES each. + * + * For each SenderSideFifoInfo memory, it is physically placed at the sender rank. + * To find the SenderSideFifoInfo whose sender is rank r, we need to look in the SenderSideFifoInfo memory of MMr. + * The layout of the SenderSideFifoInfo memory of each MMr is (here, rank is the receiver rank): + * ------------------------------------------------------------------------------------------------- + * | R0C0 | R0C1 | .... | R0Cm | R1C0 | R1C1 | .... | R1Cm | .... .... | RnC0 | RnC1 | .... | RnCm | + * |<- Channels for Rank 0 ->|<- Channels for Rank 1 ->| |<- Channels for Rank n ->| + * ------------------------------------------------------------------------------------------------- + * Each R*C* is one SenderSideFifoInfo struct. There are M * N SenderSideFifoInfo structs in total in each MMr. + * + * For each ReceiverSideFifoInfo memory, it is physically placed at the receiver rank. + * To find the ReceiverSideFifoInfo whose receiver is rank r, we need to look in the ReceiverSideFifoInfo memory of MMr. + * The layout of the ReceiverSideFifoInfo memory of each MMr is (here, rank is the sender rank): + * ------------------------------------------------------------------------------------------------- + * | R0C0 | R0C1 | .... | R0Cm | R1C0 | R1C1 | .... | R1Cm | .... .... | RnC0 | RnC1 | .... | RnCm | + * |<- Channels for Rank 0 ->|<- Channels for Rank 1 ->| |<- Channels for Rank n ->| + * ------------------------------------------------------------------------------------------------- + * Each R*C* is one ReceiverSideFifoInfo struct. There are M * N ReceiverSideFifoInfo structs in total in each MMr. + */ + +struct FusedMoeWorkspace +{ + uint64_t* workspacePtr; + size_t rankStrideInU64; + int channelCount; + + template <bool isSenderSideBuffer> + __device__ __forceinline__ uint8_t* commonGetPtrBase( + FusedMoePairInfo const& pairInfo, size_t fieldOffset, int fieldSingleSize) const + { + int mappedMemoryrank = isSenderSideBuffer ? pairInfo.senderRank : pairInfo.receiverRank; + int rankInsideMappedMemory = isSenderSideBuffer ?
pairInfo.receiverRank : pairInfo.senderRank; + auto* mappedMemory = reinterpret_cast(workspacePtr + mappedMemoryrank * rankStrideInU64); + mappedMemory += fieldOffset; + mappedMemory += rankInsideMappedMemory * channelCount * fieldSingleSize; + mappedMemory += pairInfo.channel * fieldSingleSize; + return mappedMemory; + } + + __device__ __forceinline__ uint64_t* getFifoBasePtr( + FusedMoeWorldInfo const& worldInfo, FusedMoePairInfo const& pairInfo) const + { + constexpr int fieldSingleSize = FusedMoeCommunicator::FIFO_TOTAL_BYTES; + return reinterpret_cast(commonGetPtrBase(pairInfo, 0, fieldSingleSize)); + } + + __device__ __forceinline__ SenderSideFifoInfo* getSenderSideFifoInfo( + FusedMoeWorldInfo const& worldInfo, FusedMoePairInfo const& pairInfo) const + { + constexpr int fieldSingleSize = sizeof(SenderSideFifoInfo); + size_t fieldOffset + = static_cast(FusedMoeCommunicator::FIFO_TOTAL_BYTES) * worldInfo.epInfo.epSize * channelCount; + return reinterpret_cast(commonGetPtrBase(pairInfo, fieldOffset, fieldSingleSize)); + } + + __device__ __forceinline__ ReceiverSideFifoInfo* getReceiverSideFifoInfo( + FusedMoeWorldInfo const& worldInfo, FusedMoePairInfo const& pairInfo) const + { + constexpr int fieldSingleSize = sizeof(ReceiverSideFifoInfo); + size_t fieldOffset + = static_cast(FusedMoeCommunicator::FIFO_TOTAL_BYTES) * worldInfo.epInfo.epSize * channelCount + + sizeof(SenderSideFifoInfo) * worldInfo.epInfo.epSize * channelCount; + return reinterpret_cast(commonGetPtrBase(pairInfo, fieldOffset, fieldSingleSize)); + } + + static size_t computeWorkspaceSizePreRank(int epSize, int channelCount) + { + size_t fifoSize = static_cast(FusedMoeCommunicator::FIFO_TOTAL_BYTES) * epSize * channelCount; + size_t senderSideInfoSize = sizeof(SenderSideFifoInfo) * epSize * channelCount; + size_t receiverSideInfoSize = sizeof(ReceiverSideFifoInfo) * epSize * channelCount; + return fifoSize + senderSideInfoSize + receiverSideInfoSize; + } + + void initializeLocalWorkspace(FusedMoeWorldInfo const& worldInfo); +}; + +void setMaxUsableSmCount(int smCount); + +void moeAllToAll(FusedMoeCommKernelParam params, FusedMoeWorkspace workspace, cudaStream_t stream); + +void constructWorkspace(FusedMoeWorkspace* workspace, uint64_t* workspacePtr, size_t rankStrideInU64, int epSize); + +void initializeFusedMoeLocalWorkspace(FusedMoeWorkspace* workspace, FusedMoeWorldInfo const& worldInfo); + +namespace fused_moe_comm_tests +{ + +// Functions for testing + +void launchSingleG2S(FusedMoeFieldInfo const& sendFieldInfo, MoeExpertParallelInfo const& expertParallelInfo, + int tokenCount, int* shmDump, int warpsPerBlock, bool hasBasicFields, cudaStream_t stream); + +void launchSingleS2G(FusedMoeFieldInfo const& recvFieldInfo, MoeExpertParallelInfo const& expertParallelInfo, + int tokenCount, int* shmPreload, int warpsPerBlock, bool hasBasicFields, cudaStream_t stream); + +void launchLoopback(FusedMoeFieldInfo const& sendFieldInfo, FusedMoeFieldInfo const& recvFieldInfo, + MoeExpertParallelInfo const& expertParallelInfo, int* recvIndexMapping, int tokenCount, int warpsPerBlock, + bool hasBasicFields, cudaStream_t stream); + +void launchLocalFifoSendRecv(FusedMoeFieldInfo const& sendFieldInfo, FusedMoeFieldInfo const& recvFieldInfo, + MoeExpertParallelInfo const& expertParallelInfo, int* sendIndexMapping, int* recvIndexMapping, + FusedMoeWorkspace fusedMoeWorkspace, int tokenCount, int warpsPerBlock, int blockChannelCount, bool hasBasicFields, + cudaStream_t stream); + +} // namespace fused_moe_comm_tests + +} // namespace 
kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/moeCommKernels.cu b/cpp/tensorrt_llm/kernels/moeCommKernels.cu deleted file mode 100644 index 66cdacf5163..00000000000 --- a/cpp/tensorrt_llm/kernels/moeCommKernels.cu +++ /dev/null @@ -1,804 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "moeCommKernels.h" - -#include - -#include -#include - -namespace cg = cooperative_groups; - -namespace tensorrt_llm::kernels -{ - -__device__ inline void barrier_sync(int name, int nThreads) -{ - asm volatile("barrier.sync.aligned %0, %1;" ::"r"(name), "r"(nThreads) : "memory"); -} - -inline __device__ void load128(uint64_t const* ptr, uint64_t& v0, uint64_t& v1) -{ - asm volatile("ld.volatile.global.v2.u64 {%0,%1}, [%2];" : "=l"(v0), "=l"(v1) : "l"(ptr) : "memory"); -} - -inline __device__ void store128(uint64_t* ptr, uint64_t v0, uint64_t v1) -{ - asm volatile("st.volatile.global.v2.u64 [%2], {%0,%1};" ::"l"(v0), "l"(v1), "l"(ptr) : "memory"); -} - -template -class AllToAllChannelCommunicator : public AllToAllChannelCommunicatorBase -{ -private: - int const tid; // thread index in primitives group - int const nthreads; // number of threads in primitives group - int const wid; // lane index in warp - int const warp; // warp index in primitives group - const MoeEpWorldInfo worldInfo; - const MoeCommWorkspace workspace; - const SendRecvDataInfo sendRecvDataInfo; - const SendRecvDispls dataDispls; - int peerRank; // peer rank index - bool const flagThread; - int const group; // primitives group index - int const channel; // channel index - int const channelCount; // count of channels - - MoeCommFifoConnInfo* fifoConnInfoPtr; - uint64_t* fifoBasePtr; // pointer to fifo base address - uint64_t step; - uint64_t tailStepCache; - uint64_t regs[U64_DATA_REG_PER_THREAD]; - GroupSharedBuffer* groupSharedBuffer; - - int groupStartIndice; - int groupEndIndice; - - int sliceStartIndice; - int sliceEndIndice; - - uint64_t* stepFifoEntryPtr; - -public: - __inline__ __device__ uint64_t getFlag() - { - return step + 1; - } - - __inline__ __device__ AllToAllChannelCommunicator(MoeEpWorldInfo const& worldInfo, MoeCommWorkspace workspace, - SendRecvDataInfo sendRecvDataInfo, SendRecvDispls dataDispls, GroupSharedBuffer* groupSharedBuffer, - int channelCount) - : worldInfo(worldInfo) - , nthreads(blockDim.x) - , tid(threadIdx.x) - , workspace(workspace) - , sendRecvDataInfo(sendRecvDataInfo) - , dataDispls(dataDispls) - , wid(threadIdx.x % WARP_SIZE) - , warp(threadIdx.x / WARP_SIZE) - , peerRank(blockIdx.x * GROUP_COUNT_PER_BLOCK + threadIdx.y) - , group(threadIdx.y) - , channel(blockIdx.y) - , flagThread(threadIdx.x % 8 == 7) - , fifoConnInfoPtr(nullptr) - , fifoBasePtr(nullptr) - , step(0) - , tailStepCache(0) - , groupSharedBuffer(groupSharedBuffer) - , channelCount(channelCount) - { - } - - __inline__ __device__ void init() - { - fifoBasePtr = workspace.getFifoBasePtr(isSender, 
worldInfo.epRank, peerRank, channel, channelCount); - fifoConnInfoPtr - = workspace.getFifoConnInfo(isSender, worldInfo.epRank, peerRank, channel, worldInfo.epSize, channelCount); - step = isSender ? fifoConnInfoPtr->head : fifoConnInfoPtr->tail; - tailStepCache = isSender ? fifoConnInfoPtr->tail : 0; - } - - __inline__ __device__ void computeGroupTransferRange() - { - if (tid == 0) - { - int rankCount = dataDispls.getCount(peerRank); - int rankStart = dataDispls.getRankStart(peerRank); - int countPerChannel = (rankCount + channelCount - 1) / channelCount; - int groupEnd = min(rankStart + (channel + 1) * countPerChannel, rankStart + rankCount); - int groupStart = min(rankStart + channel * countPerChannel, rankStart + rankCount); - groupSharedBuffer->groupStartIndice = groupStart; - groupSharedBuffer->groupEndIndice = groupEnd; - } - barrier(); - groupStartIndice = groupSharedBuffer->groupStartIndice; - groupEndIndice = groupSharedBuffer->groupEndIndice; - } - - __inline__ __device__ void loadTransferIndices() - { - sliceStartIndice = groupStartIndice; - sliceEndIndice = min(groupStartIndice + sendRecvDataInfo.vectorCountPerFifoEntry, groupEndIndice); - for (int i = groupStartIndice + tid; i < sliceEndIndice; i += WARP_SIZE * WARP_PER_GROUP) - { - groupSharedBuffer->groupIndiceBuffer[i - groupStartIndice] = dataDispls.getRealVectorIndice(i); - } - groupStartIndice = sliceEndIndice; - barrier(); - } - - __inline__ __device__ void computeSlicePtr() - { - stepFifoEntryPtr = fifoBasePtr + RECV_FIFO_ENTRY_U64 * (step % RECV_FIFO_DEPTH); - } - - __inline__ __device__ void sendSlice() - { - waitSend(); - int EltPer16B = 2; - int eltN = sendRecvDataInfo.vectorSizeInU64; - for (int vecId = warp + sliceStartIndice; vecId < sliceEndIndice; vecId += WARP_PER_GROUP) - { - int idxInSlice = vecId - sliceStartIndice; - int vecRealIdx = groupSharedBuffer->groupIndiceBuffer[idxInSlice]; - uint64_t* src = dataDispls.getVectorDataPtr(vecRealIdx); - uint64_t* slicePtr = stepFifoEntryPtr - + idxInSlice * sendRecvDataInfo.dataPacketCountPerVector * PACKET_SIZE_IN_U64 + 2 * wid; - for (int packetId = 0; packetId < sendRecvDataInfo.dataPacketCountPerVector; packetId++) - { - int vecOff = packetId * DATA_PAYLOAD_SIZE_PER_PACKET_IN_U64; -#pragma unroll - for (int g = 0; g < U64_DATA_REG_PER_THREAD / 2; g++) - { - int ix = g * WARP_SIZE - 4 * (g / 2) + wid - (g % 2) * (wid / 8); - __syncwarp(); - if (!flagThread || g % 2 == 0) - { - if (ix * EltPer16B + vecOff < eltN) - { - load128((uint64_t*) (src + ix * EltPer16B + vecOff), regs[2 * g + 0], regs[2 * g + 1]); - } - } - __syncwarp(); - } -#pragma unroll - for (int g = 1; g < U64_DATA_REG_PER_THREAD / 2; g += 2) - { - if (flagThread) - regs[2 * g] = regs[2 * g - 1]; - } - - uint64_t flag = getFlag(); - uint64_t* packetPtr = slicePtr + packetId * PACKET_SIZE_IN_U64; - __syncwarp(); -#pragma unroll - for (int u = 0; u < U64_DATA_REG_PER_THREAD; u += 2) - { - store128(packetPtr + u * WARP_SIZE, regs[u], flagThread ? flag : regs[u + 1]); - } - } - } - updateSend(); - } - - __inline__ __device__ void recvSlice() - { - // receiver don't need to wait since we have flag. 
- int EltPer16B = 2; - int eltN = sendRecvDataInfo.vectorSizeInU64; - for (int vecId = warp + sliceStartIndice; vecId < sliceEndIndice; vecId += WARP_PER_GROUP) - { - int idxInSlice = vecId - sliceStartIndice; - int vecRealIdx = groupSharedBuffer->groupIndiceBuffer[idxInSlice]; - - uint64_t* dst = dataDispls.getVectorDataPtr(vecRealIdx); - uint64_t* slicePtr = stepFifoEntryPtr - + idxInSlice * sendRecvDataInfo.dataPacketCountPerVector * PACKET_SIZE_IN_U64 + 2 * wid; - for (int packetId = 0; packetId < sendRecvDataInfo.dataPacketCountPerVector; packetId++) - { - uint64_t* packetPtr = slicePtr + packetId * PACKET_SIZE_IN_U64; - int vecOff = packetId * DATA_PAYLOAD_SIZE_PER_PACKET_IN_U64; - - bool needReload; - uint64_t flag = getFlag(); - __syncwarp(); - do - { - needReload = false; -#pragma unroll - for (int u = 0; u < U64_DATA_REG_PER_THREAD; u += 2) - { - load128(packetPtr + u * WARP_SIZE, regs[u], regs[u + 1]); - needReload |= flagThread && (regs[u + 1] != flag); - } - } while (__any_sync(WARP_MASK, needReload)); -#pragma unroll - for (int g = 1; g < U64_DATA_REG_PER_THREAD / 2; g += 2) - { - if (flagThread) - regs[2 * g - 1] = regs[2 * g]; - } - -#pragma unroll - for (int g = 0; g < U64_DATA_REG_PER_THREAD / 2; g++) - { - int ix = g * WARP_SIZE - 4 * (g / 2) + wid - (g % 2) * (wid / 8); - __syncwarp(); - if (!flagThread || g % 2 == 0) - { - if (ix * EltPer16B + vecOff < eltN) - { - store128((uint64_t*) (dst + ix * EltPer16B + vecOff), regs[2 * g + 0], regs[2 * g + 1]); - } - } - __syncwarp(); - } - } - } - updateRecv(); - } - - __inline__ __device__ void run() - { - if (peerRank >= worldInfo.epSize) - { - return; - } - init(); - computeGroupTransferRange(); - while (groupStartIndice < groupEndIndice) - { - loadTransferIndices(); - computeSlicePtr(); - if (isSender) - { - sendSlice(); - } - else - { - recvSlice(); - } - } - } - - __inline__ __device__ ~AllToAllChannelCommunicator() {} - - __inline__ __device__ void barrier() - { - barrier_sync(15 - group, nthreads); - } - - __inline__ __device__ void waitSend() - { - barrier(); - while (tailStepCache + RECV_FIFO_DEPTH < step + 1) - { - tailStepCache = fifoConnInfoPtr->tail; - } - barrier(); - } - - __inline__ __device__ void updateSend() - { - barrier(); - if (tid == 0) - { - atomicAdd_system((unsigned long long*) &fifoConnInfoPtr->head, 1); - } - barrier(); - step++; - } - - __inline__ __device__ void updateRecv() - { - barrier(); - if (tid == 0) - { - atomicAdd_system((unsigned long long*) &fifoConnInfoPtr->tail, 1); - } - barrier(); - step++; - } -}; - -__global__ void moeAllToAllKernel(MoeEpWorldInfo worldInfo, MoeCommWorkspace workspace, - SendRecvDataInfo sendRecvDataInfo, SendRecvDispls sendDispls, SendRecvDispls recvDispls) -{ - __shared__ AllToAllChannelCommunicatorBase::GroupSharedBuffer - allGroupSharedBuffer[AllToAllChannelCommunicatorBase::GROUP_COUNT_PER_BLOCK]; - bool isSender = blockIdx.z == 0; - int channelCount = gridDim.y; - int group = threadIdx.y; - SendRecvDispls dataDispls = isSender ? 
sendDispls : recvDispls; - AllToAllChannelCommunicatorBase::GroupSharedBuffer* groupSharedBuffer = &allGroupSharedBuffer[group]; - if (isSender) - { - AllToAllChannelCommunicator comm( - worldInfo, workspace, sendRecvDataInfo, dataDispls, groupSharedBuffer, channelCount); - comm.run(); - } - else - { - AllToAllChannelCommunicator comm( - worldInfo, workspace, sendRecvDataInfo, dataDispls, groupSharedBuffer, channelCount); - comm.run(); - } -} - -void moeAllToAll(MoeEpWorldInfo worldInfo, SendRecvDataInfo sendRecvDataInfo, SendRecvDispls sendDispls, - SendRecvDispls recvDispls, MoeCommWorkspace workspace, cudaStream_t stream) -{ - sendRecvDataInfo.DoPreCompute(); - TLLM_CHECK_WITH_INFO( - reinterpret_cast(sendDispls.dataPtr) % 16 == 0, "sendDispls.dataPtr must be 16-byte aligned"); - TLLM_CHECK_WITH_INFO( - reinterpret_cast(recvDispls.dataPtr) % 16 == 0, "recvDispls.dataPtr must be 16-byte aligned"); - dim3 block = AllToAllChannelCommunicatorBase::getLaunchBlockDim(); - dim3 grid = AllToAllChannelCommunicatorBase::getLaunchGridDim(worldInfo.epSize); - moeAllToAllKernel<<>>(worldInfo, workspace, sendRecvDataInfo, sendDispls, recvDispls); -} - -template -__inline__ __device__ void computeSendRecvRankCountDevice(MoeEpWorldInfo worldInfo, - MoeExpertParallelInfo expertParallelInfo, int maxTokenCountPerRank, int const* realRankTokenCountCumSum, - int const* gatheredTargetRankIds, int* sharedSendRecvRankCount, int* sendRecvRankCount) -{ - cg::thread_block_tile tile = cg::tiled_partition(cg::this_thread_block()); - int laneInTile = tile.thread_rank(); - int tileId = threadIdx.x / kThreadsGroupSize; - int tileCountPerBlock = blockDim.x / kThreadsGroupSize; - - int topK = expertParallelInfo.topK; - int epRank = worldInfo.epRank; - int epSize = worldInfo.epSize; - - if (threadIdx.x == 0) - { - *sharedSendRecvRankCount = 0; - } - - __syncthreads(); - int readRank = isSend ? epRank : blockIdx.x; - int compareRankId = isSend ? blockIdx.x : epRank; - int const* readRankTargetRankIds = gatheredTargetRankIds + readRank * maxTokenCountPerRank * topK; - int readRankTokenCount = maxTokenCountPerRank; - if (realRankTokenCountCumSum != nullptr) - { - int readRankStart = readRank == 0 ? 0 : realRankTokenCountCumSum[readRank - 1]; - readRankTargetRankIds = gatheredTargetRankIds + readRankStart * topK; - readRankTokenCount = realRankTokenCountCumSum[readRank] - readRankStart; - } - - for (int i = tileId + blockIdx.z * tileCountPerBlock; i < readRankTokenCount; i += tileCountPerBlock * gridDim.z) - { - int targetRankId = laneInTile < topK ? 
readRankTargetRankIds[i * topK + laneInTile] : epSize; - bool rankMatched = (targetRankId == compareRankId); - bool hasRankMatched = tile.any(rankMatched); - if (hasRankMatched && laneInTile == 0) - { - atomicAdd_block(sharedSendRecvRankCount, 1); - } - tile.sync(); - } - __syncthreads(); - if (threadIdx.x == 0) - { - atomicAdd_system(sendRecvRankCount + blockIdx.x, *sharedSendRecvRankCount); - } -} - -template -__global__ void computeSendRecvRankCountKernel(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo, - int maxTokenCountPerRank, int const* realRankTokenCountCumSum, int const* gatheredTargetRankIds, int* sendRankCount, - int* recvRankCount) -{ - static_assert(kThreadsGroupSize == 1 || kThreadsGroupSize == 2 || kThreadsGroupSize == 4 || kThreadsGroupSize == 8 - || kThreadsGroupSize == 16 || kThreadsGroupSize == 32, - "Only 1, 2, 4, 8, 16, 32 threads group size supported now."); - __shared__ int sharedSendRecvRankCount; - if (blockIdx.y == 0) - { - // compute send rank count - computeSendRecvRankCountDevice(worldInfo, expertParallelInfo, maxTokenCountPerRank, - realRankTokenCountCumSum, gatheredTargetRankIds, &sharedSendRecvRankCount, sendRankCount); - } - else - { - // compute recv rank count - computeSendRecvRankCountDevice(worldInfo, expertParallelInfo, maxTokenCountPerRank, - realRankTokenCountCumSum, gatheredTargetRankIds, &sharedSendRecvRankCount, recvRankCount); - } -} - -void computeSendRecvRankCount(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo, - int maxTokenCountPerRank, int const* realRankTokenCountCumSum, int const* gatheredTargetRankIds, int* sendRankCount, - int* recvRankCount, cudaStream_t stream) -{ - TLLM_CHECK_WITH_INFO(expertParallelInfo.topK <= 32, "Only topK less than or equal to 32 supported now."); - int threadsPerBlock = 1024; - auto* kernelPtr = computeSendRecvRankCountKernel<32>; - if (expertParallelInfo.topK <= 1) - { - kernelPtr = computeSendRecvRankCountKernel<1>; - } - else if (expertParallelInfo.topK <= 2) - { - kernelPtr = computeSendRecvRankCountKernel<2>; - } - else if (expertParallelInfo.topK <= 4) - { - kernelPtr = computeSendRecvRankCountKernel<4>; - } - else if (expertParallelInfo.topK <= 8) - { - kernelPtr = computeSendRecvRankCountKernel<8>; - } - else if (expertParallelInfo.topK <= 16) - { - kernelPtr = computeSendRecvRankCountKernel<16>; - } - dim3 block(worldInfo.epSize, 2, 1); - kernelPtr<<>>(worldInfo, expertParallelInfo, maxTokenCountPerRank, - realRankTokenCountCumSum, gatheredTargetRankIds, sendRankCount, recvRankCount); -} - -template -__global__ void inplaceSendRecvRankCumSumKernel(MoeEpWorldInfo worldInfo, int* sendRankCount, int* recvRankCount) -{ - int* inputOutputPtr = blockIdx.x == 0 ? sendRankCount : recvRankCount; - typedef cub::BlockScan BlockScan; - __shared__ typename BlockScan::TempStorage temp_storage; - - int tid = threadIdx.x; - int threadData = tid < worldInfo.epSize ? 
inputOutputPtr[tid] : 0; - - BlockScan(temp_storage).InclusiveSum(threadData, threadData); - if (tid < worldInfo.epSize) - { - inputOutputPtr[tid] = threadData; - } -} - -void inplaceSendRecvRankCumSum(MoeEpWorldInfo worldInfo, int* sendRankCount, int* recvRankCount, cudaStream_t stream) -{ - TLLM_CHECK_WITH_INFO(worldInfo.epSize <= 1024, "Only worldInfo.epSize less than or equal to 1024 supported now."); - auto* kernelPtr = inplaceSendRecvRankCumSumKernel<1024>; - int blockSize = 1024; - if (worldInfo.epSize <= 32) - { - kernelPtr = inplaceSendRecvRankCumSumKernel<32>; - blockSize = 32; - } - else if (worldInfo.epSize <= 64) - { - kernelPtr = inplaceSendRecvRankCumSumKernel<64>; - blockSize = 64; - } - else if (worldInfo.epSize <= 128) - { - kernelPtr = inplaceSendRecvRankCumSumKernel<128>; - blockSize = 128; - } - else if (worldInfo.epSize <= 256) - { - kernelPtr = inplaceSendRecvRankCumSumKernel<256>; - blockSize = 256; - } - else if (worldInfo.epSize <= 512) - { - kernelPtr = inplaceSendRecvRankCumSumKernel<512>; - blockSize = 512; - } - kernelPtr<<<2, blockSize, 0, stream>>>(worldInfo, sendRankCount, recvRankCount); -} - -template -__inline__ __device__ void computeSendRecvIndicesDevice(MoeEpWorldInfo worldInfo, - MoeExpertParallelInfo expertParallelInfo, int maxTokenCountPerRank, int const* realRankTokenCountCumSum, - int const* gatheredTargetRankIds, int const* sendRecvCumSum, - int* sendRecvIndices, // send or receive - int* localGatherIndices, // receive only - int* backwardRecvRankLocalIndices, // send only - int* sharedSendRecvRankStart, typename cub::BlockScan::TempStorage& tempStorage) -{ - cg::thread_block_tile tile = cg::tiled_partition(cg::this_thread_block()); - int laneInTile = tile.thread_rank(); - int tileId = threadIdx.x / kThreadsGroupSize; - int tileCountPerBlock = blockDim.x / kThreadsGroupSize; - - int topK = expertParallelInfo.topK; - int epRank = worldInfo.epRank; - int epSize = worldInfo.epSize; - - if (threadIdx.x == 0) - { - *sharedSendRecvRankStart = blockIdx.x == 0 ? 0 : sendRecvCumSum[blockIdx.x - 1]; - } - - __syncthreads(); - int readRank = isSend ? epRank : blockIdx.x; - int compareRankId = isSend ? blockIdx.x : epRank; - int readRankStart = readRank * maxTokenCountPerRank; - int const* readRankTargetRankIds = gatheredTargetRankIds + readRankStart * topK; - int readRankTokenCount = maxTokenCountPerRank; - if (realRankTokenCountCumSum != nullptr) - { - readRankStart = readRank == 0 ? 0 : realRankTokenCountCumSum[readRank - 1]; - readRankTargetRankIds = gatheredTargetRankIds + readRankStart * topK; - readRankTokenCount = realRankTokenCountCumSum[readRank] - readRankStart; - } - - for (int blockStartId = blockIdx.z * tileCountPerBlock; blockStartId < readRankTokenCount; - blockStartId += tileCountPerBlock * gridDim.z) - { - int stepStartIndice = *sharedSendRecvRankStart; - int i = blockStartId + tileId; - int targetRankId - = (laneInTile < topK && i < readRankTokenCount) ? readRankTargetRankIds[i * topK + laneInTile] : epSize; - bool rankMatched = (targetRankId == compareRankId); - bool hasRankMatched = tile.any(rankMatched); - unsigned int laneMask = tile.ballot(rankMatched); - int lowestLane = __ffs(laneMask) - 1; - int isMatchedLane = (hasRankMatched && laneInTile == lowestLane) ? 
1 : 0; - int indice; - typedef cub::BlockScan BlockScan; - BlockScan(tempStorage).ExclusiveSum(isMatchedLane, indice); - indice += stepStartIndice; - __syncthreads(); - - if (isMatchedLane == 1) - { - atomicAdd_block(sharedSendRecvRankStart, 1); - if (isSend) - { - sendRecvIndices[indice] = i; - backwardRecvRankLocalIndices[indice] = i * topK + lowestLane; - } - else - { - sendRecvIndices[indice] = indice; - localGatherIndices[indice] = readRankStart + i; - } - } - __syncthreads(); - } -} - -template -__global__ void computeSendRecvIndicesKernel(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo, - int maxTokenCountPerRank, int const* realRankTokenCountCumSum, int const* gatheredTargetRankIds, - int const* sendRankCountCumSum, int const* recvRankCountCumSum, int* localGatherIndices, int* sendRankLocalIndices, - int* recvRankLocalIndices, int* backwardRecvRankLocalIndices) -{ - static_assert(kThreadsGroupSize == 1 || kThreadsGroupSize == 2 || kThreadsGroupSize == 4 || kThreadsGroupSize == 8 - || kThreadsGroupSize == 16 || kThreadsGroupSize == 32, - "Only 1, 2, 4, 8, 16, 32 threads group size supported now."); - __shared__ int sharedSendRecvRankStart; - __shared__ typename cub::BlockScan::TempStorage tempStorage; - if (blockIdx.y == 0) - { - // compute send rank count - computeSendRecvIndicesDevice(worldInfo, expertParallelInfo, - maxTokenCountPerRank, realRankTokenCountCumSum, gatheredTargetRankIds, sendRankCountCumSum, - sendRankLocalIndices, localGatherIndices, backwardRecvRankLocalIndices, &sharedSendRecvRankStart, - tempStorage); - } - else - { - // compute recv rank count - computeSendRecvIndicesDevice(worldInfo, expertParallelInfo, - maxTokenCountPerRank, realRankTokenCountCumSum, gatheredTargetRankIds, recvRankCountCumSum, - recvRankLocalIndices, localGatherIndices, backwardRecvRankLocalIndices, &sharedSendRecvRankStart, - tempStorage); - } -} - -void computeSendRecvIndices(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo, - int maxTokenCountPerRank, int const* realRankTokenCountCumSum, int const* gatheredTargetRankIds, - int const* sendRankCountCumSum, int const* recvRankCountCumSum, int* localGatherIndices, int* sendRankLocalIndices, - int* recvRankLocalIndices, int* backwardRecvRankLocalIndices, cudaStream_t stream) -{ - TLLM_CHECK_WITH_INFO(expertParallelInfo.topK <= 32, "Only topK less than or equal to 32 supported now."); - int threadsPerBlock = 1024; - auto* kernelPtr = computeSendRecvIndicesKernel<32, 1024>; - if (expertParallelInfo.topK <= 1) - { - kernelPtr = computeSendRecvIndicesKernel<1, 1024>; - } - else if (expertParallelInfo.topK <= 2) - { - kernelPtr = computeSendRecvIndicesKernel<2, 1024>; - } - else if (expertParallelInfo.topK <= 4) - { - kernelPtr = computeSendRecvIndicesKernel<4, 1024>; - } - else if (expertParallelInfo.topK <= 8) - { - kernelPtr = computeSendRecvIndicesKernel<8, 1024>; - } - else if (expertParallelInfo.topK <= 16) - { - kernelPtr = computeSendRecvIndicesKernel<16, 1024>; - } - else if (expertParallelInfo.topK <= 32) - { - kernelPtr = computeSendRecvIndicesKernel<32, 1024>; - } - dim3 block(worldInfo.epSize, 2, 1); - kernelPtr<<>>(worldInfo, expertParallelInfo, maxTokenCountPerRank, - realRankTokenCountCumSum, gatheredTargetRankIds, sendRankCountCumSum, recvRankCountCumSum, localGatherIndices, - sendRankLocalIndices, recvRankLocalIndices, backwardRecvRankLocalIndices); -} - -__global__ void moeAllToAllMemsetKernel(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo, - int maxTokenCountPerRank, int* 
sendRankCountCumSum, int* recvRankCountCumSum, int* localGatherIndices, - int* sendRankLocalIndices, int* recvRankLocalIndices, int* backwardRecvRankLocalIndices) -{ - int maxSendRanksPerToken = std::max(worldInfo.epSize, expertParallelInfo.topK); - int idx = threadIdx.x + blockIdx.x * blockDim.x; - int maxRankRecvTokenCount = maxTokenCountPerRank * worldInfo.epSize; - int maxRankSendTokenCount = maxTokenCountPerRank * maxSendRanksPerToken; - if (idx < worldInfo.epSize) - { - sendRankCountCumSum[idx] = 0; - recvRankCountCumSum[idx] = 0; - } - if (idx < maxRankRecvTokenCount) - { - localGatherIndices[idx] = -1; - recvRankLocalIndices[idx] = -1; - } - if (idx < maxRankSendTokenCount) - { - sendRankLocalIndices[idx] = -1; - backwardRecvRankLocalIndices[idx] = -1; - } -} - -void moeAllToAllMemset(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo, int maxTokenCountPerRank, - int* sendRankCountCumSum, int* recvRankCountCumSum, int* localGatherIndices, int* sendRankLocalIndices, - int* recvRankLocalIndices, int* backwardRecvRankLocalIndices, cudaStream_t stream) -{ - int maxSendRanksPerToken = std::max(worldInfo.epSize, expertParallelInfo.topK); - int maxRankRecvTokenCount = maxTokenCountPerRank * worldInfo.epSize; - int maxRankSendTokenCount = maxTokenCountPerRank * maxSendRanksPerToken; - int maxEltCount = std::max(maxRankRecvTokenCount, maxRankSendTokenCount); - maxEltCount = std::max(maxEltCount, worldInfo.epSize); - static constexpr int kBlockSize = 256; - int blockCount = (maxEltCount + kBlockSize - 1) / kBlockSize; - dim3 grid(blockCount, 1); - moeAllToAllMemsetKernel<<>>(worldInfo, expertParallelInfo, maxTokenCountPerRank, - sendRankCountCumSum, recvRankCountCumSum, localGatherIndices, sendRankLocalIndices, recvRankLocalIndices, - backwardRecvRankLocalIndices); -} - -void moeAllToAllPrepareIndices(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo, - int maxTokenCountPerRank, int const* gatheredTargetRankIds, int const* realRankTokenCountCumSum, - // indices of gatheredTargetRankIds that has the local rank in topK - int* localGatherIndices, // max length = maxTokenCountPerRank * worldInfo.epSize when all ranks send to current - // rank - int* sendRankCountCumSum, // max length = worldInfo.epSize - int* sendRankLocalIndices, // max length = maxTokenCountPerRank * expertParallelInfo.expertCount when current rank - // has maxTokenCountPerRank tokens to send and all has expertCount dest - int* recvRankCountCumSum, // max length = worldInfo.epSize - int* recvRankLocalIndices, // max length = maxTokenCountPerRank * worldInfo.epSize when all ranks send to current - // rank - // the rankCountCumSum of combineRecv should be the same as sendRankCountCumSum - int* - backwardRecvRankLocalIndices, // max length = maxTokenCountPerRank * expertParallelInfo.expertCount when current - // rank has maxTokenCountPerRank tokens to send and all has expertCount dest - cudaStream_t stream) -{ - moeAllToAllMemset(worldInfo, expertParallelInfo, maxTokenCountPerRank, sendRankCountCumSum, recvRankCountCumSum, - localGatherIndices, sendRankLocalIndices, recvRankLocalIndices, backwardRecvRankLocalIndices, stream); - TLLM_CHECK_WITH_INFO(worldInfo.epSize <= 1024, "Only worldInfo.epSize less than or equal to 1024 supported now."); - computeSendRecvRankCount(worldInfo, expertParallelInfo, maxTokenCountPerRank, realRankTokenCountCumSum, - gatheredTargetRankIds, sendRankCountCumSum, recvRankCountCumSum, stream); - inplaceSendRecvRankCumSum(worldInfo, sendRankCountCumSum, 
recvRankCountCumSum, stream); - computeSendRecvIndices(worldInfo, expertParallelInfo, maxTokenCountPerRank, realRankTokenCountCumSum, - gatheredTargetRankIds, sendRankCountCumSum, recvRankCountCumSum, localGatherIndices, sendRankLocalIndices, - recvRankLocalIndices, backwardRecvRankLocalIndices, stream); -} - -template -__global__ void moeLocalGatherDevice(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo, - int maxTokenCountPerRank, int localMaxTokenCount, int const* recvRankCountCumSum, int const* localGatherIndices, - int const* gatheredExpertIds, float const* gatheredScales, int* localExpertIds, float* localScales) -{ - cg::thread_block_tile tile = cg::tiled_partition(cg::this_thread_block()); - int laneInTile = tile.thread_rank(); - int tileId = threadIdx.x / kThreadsGroupSize; - int tileCountPerBlock = blockDim.x / kThreadsGroupSize; - - int epSize = worldInfo.epSize; - int rankTokenCount = recvRankCountCumSum[epSize - 1]; - if (laneInTile >= expertParallelInfo.topK) - { - return; - } - - for (int index = tileId + blockIdx.x * tileCountPerBlock; index < localMaxTokenCount; - index += tileCountPerBlock * gridDim.x) - { - int localTokenIndice = localGatherIndices[index]; - int expertId = index < rankTokenCount - ? gatheredExpertIds[localTokenIndice * expertParallelInfo.topK + laneInTile] - : expertParallelInfo.expertCount; - localExpertIds[index * expertParallelInfo.topK + laneInTile] = expertId; - if (gatheredScales) - { - float scale = index < rankTokenCount - ? gatheredScales[localTokenIndice * expertParallelInfo.topK + laneInTile] - : 0.0f; - localScales[index * expertParallelInfo.topK + laneInTile] = scale; - } - } -} - -void moeLocalGather(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo, int maxTokenCountPerRank, - int localMaxTokenCount, int const* recvRankCountCumSum, int const* localGatherIndices, int const* gatheredExpertIds, - float const* gatheredScales, int* localExpertIds, float* localScales, cudaStream_t stream) -{ - TLLM_CHECK_WITH_INFO(expertParallelInfo.topK <= 32, "Only topK less than or equal to 32 supported now."); - auto* kernelPtr = moeLocalGatherDevice<32>; - int paddedTopK = 32; - if (expertParallelInfo.topK <= 1) - { - paddedTopK = 1; - kernelPtr = moeLocalGatherDevice<1>; - } - else if (expertParallelInfo.topK <= 2) - { - paddedTopK = 2; - kernelPtr = moeLocalGatherDevice<2>; - } - else if (expertParallelInfo.topK <= 4) - { - paddedTopK = 4; - kernelPtr = moeLocalGatherDevice<4>; - } - else if (expertParallelInfo.topK <= 8) - { - paddedTopK = 8; - kernelPtr = moeLocalGatherDevice<8>; - } - else if (expertParallelInfo.topK <= 16) - { - paddedTopK = 16; - kernelPtr = moeLocalGatherDevice<16>; - } - - int threadsPerBlock = 512; - int tokenPerBlock = threadsPerBlock / paddedTopK; - int blockCount = (localMaxTokenCount + tokenPerBlock - 1) / tokenPerBlock * 2; - - kernelPtr<<>>(worldInfo, expertParallelInfo, maxTokenCountPerRank, - localMaxTokenCount, recvRankCountCumSum, localGatherIndices, gatheredExpertIds, gatheredScales, localExpertIds, - localScales); -} - -int AllToAllChannelCommunicatorBase::maxSmCount = -1; -bool AllToAllChannelCommunicatorBase::maxSmCountUsed = false; - -void setMaxUsableSmCount(int smCount) -{ - AllToAllChannelCommunicatorBase::setMaxUsableSmCount(smCount); -} - -} // namespace tensorrt_llm::kernels diff --git a/cpp/tensorrt_llm/kernels/moeCommKernels.h b/cpp/tensorrt_llm/kernels/moeCommKernels.h deleted file mode 100644 index fec9dee7bd0..00000000000 --- a/cpp/tensorrt_llm/kernels/moeCommKernels.h 
+++ /dev/null @@ -1,268 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -#include "tensorrt_llm/common/cudaUtils.h" - -namespace tensorrt_llm::kernels -{ - -#ifdef __CUDACC__ -#define ALIGN_256 __align__(256) -#else -#define ALIGN_256 alignas(256) -#endif - -struct ALIGN_256 MoeCommFifoConnInfo -{ - volatile uint64_t head; // write position - volatile uint64_t tail; // read position -}; - -constexpr int WARP_SIZE = 32; -constexpr uint32_t WARP_MASK = 0xffffffff; - -constexpr int RECV_FIFO_DEPTH = 8; -constexpr int RECV_FIFO_ENTRY_BYTES = 256 * 1024; -constexpr int RECV_FIFO_ENTRY_U64 = RECV_FIFO_ENTRY_BYTES / sizeof(uint64_t); -constexpr int RECV_FIFO_TOTAL_BYTES = RECV_FIFO_DEPTH * RECV_FIFO_ENTRY_BYTES; -constexpr int RECV_FIFO_TOTAL_U64 = RECV_FIFO_TOTAL_BYTES / sizeof(uint64_t); - -class AllToAllChannelCommunicatorBase -{ -public: - static constexpr int GROUP_COUNT_PER_BLOCK = 8; - static_assert(GROUP_COUNT_PER_BLOCK <= 8, "GROUP_COUNT_PER_BLOCK must be less than or equal to 8"); - static constexpr int WARP_PER_GROUP = 2; - static constexpr int U64_DATA_REG_PER_THREAD = 8; - // A packet is a warp-sized chunk of data that is sent or received in one go, - // but may be split into multiple 64-bit registers, the number of which is U64_DATA_REG_PER_THREAD. 
- static constexpr int PACKET_SIZE_IN_U64 = WARP_SIZE * U64_DATA_REG_PER_THREAD; - static constexpr int PACKET_SIZE_IN_BYTES = PACKET_SIZE_IN_U64 * sizeof(uint64_t); - static constexpr int DATA_PAYLOAD_SIZE_PER_PACKET_IN_U64 = (WARP_SIZE - 2) * U64_DATA_REG_PER_THREAD; - static constexpr int DATA_PAYLOAD_SIZE_PER_PACKET = DATA_PAYLOAD_SIZE_PER_PACKET_IN_U64 * sizeof(uint64_t); - static constexpr int U64_ELT_COUNT_PER_PACKET = PACKET_SIZE_IN_BYTES / sizeof(uint64_t); - - static constexpr int PACKET_COUNT_PER_FIFO_ENTRY = RECV_FIFO_ENTRY_BYTES / PACKET_SIZE_IN_BYTES; - - static constexpr int GROUP_MAX_INDICE_COUNT - = RECV_FIFO_ENTRY_BYTES / sizeof(uint64_t) / (WARP_SIZE * U64_DATA_REG_PER_THREAD); - - struct GroupSharedBuffer - { - int groupIndiceBuffer[GROUP_MAX_INDICE_COUNT]; - int groupStartIndice; - int groupEndIndice; - }; - - static void setMaxUsableSmCount(int maxUsableSmCount) - { - TLLM_CHECK_WITH_INFO(AllToAllChannelCommunicatorBase::maxSmCountUsed == false, - "setMaxUsableSmCount can be called only before it is used"); - int smCount = tensorrt_llm::common::getMultiProcessorCount(); - if (maxUsableSmCount > smCount) - { - TLLM_LOG_WARNING("setMaxUsableSmCount, maxUsableSmCount=%d, larger than smCount=%d, using smCount instead", - maxUsableSmCount, smCount); - maxUsableSmCount = smCount; - } - AllToAllChannelCommunicatorBase::maxSmCount = maxUsableSmCount; - } - - static int getMaxUsableSmCount() - { - AllToAllChannelCommunicatorBase::maxSmCountUsed = true; - if (AllToAllChannelCommunicatorBase::maxSmCount == -1) - { - int smCount = tensorrt_llm::common::getMultiProcessorCount(); - AllToAllChannelCommunicatorBase::maxSmCount = smCount; - } - return AllToAllChannelCommunicatorBase::maxSmCount; - } - - static int computeMoeCommChannelCount(int epSize) - { - int smCount = getMaxUsableSmCount(); - int blockCountPerChannel = (epSize + GROUP_COUNT_PER_BLOCK - 1) / GROUP_COUNT_PER_BLOCK; - blockCountPerChannel *= 2; // for send and recv - TLLM_CHECK_WITH_INFO( - blockCountPerChannel <= smCount, "GPU should support at lease one channel, usableSmCount=%d", smCount); - int perferredChannel = smCount / 2 / blockCountPerChannel; // use half SMs for communication - int channelCount = std::max(perferredChannel, 1); // at lease one channel - return channelCount; - } - - static int getMoeCommChannelCount(int epSize) - { - static std::map channelCountMap{}; - auto iter = channelCountMap.find(epSize); - if (iter == channelCountMap.end()) - { - auto channelCount = AllToAllChannelCommunicatorBase::computeMoeCommChannelCount(epSize); - channelCountMap[epSize] = channelCount; - return channelCount; - } - return iter->second; - } - - static dim3 getLaunchBlockDim() - { - return dim3(WARP_SIZE * WARP_PER_GROUP, GROUP_COUNT_PER_BLOCK); - } - - static dim3 getLaunchGridDim(int epSize) - { - int channelCount = AllToAllChannelCommunicatorBase::getMoeCommChannelCount(epSize); - return dim3((epSize + GROUP_COUNT_PER_BLOCK - 1) / GROUP_COUNT_PER_BLOCK, channelCount, 2); - } - -protected: - static int maxSmCount; - static bool maxSmCountUsed; -}; - -inline size_t getMoeCommWorkspaceSize(int epSize) -{ - int channelCount = AllToAllChannelCommunicatorBase::getMoeCommChannelCount(epSize); - return RECV_FIFO_TOTAL_BYTES * epSize * channelCount + sizeof(MoeCommFifoConnInfo) * epSize * channelCount; -} - -struct MoeEpWorldInfo -{ - int epSize; - int epRank; -}; - -struct MoeExpertParallelInfo -{ - int expertCount = -1; - int topK = 1; -}; - -struct SendRecvDataInfo -{ - int vectorSizeInU64; - // pre-computed at host 
side for GPU kernel - int dataPacketCountPerVector; - int vectorCountPerFifoEntry; - - void ComputeDataPacketCountPerVector() - { - dataPacketCountPerVector - = (vectorSizeInU64 * sizeof(uint64_t) + AllToAllChannelCommunicatorBase::DATA_PAYLOAD_SIZE_PER_PACKET - 1) - / AllToAllChannelCommunicatorBase::DATA_PAYLOAD_SIZE_PER_PACKET; - } - - void ComputeVectorCountPerFifoEntry() - { - ComputeDataPacketCountPerVector(); - vectorCountPerFifoEntry - = AllToAllChannelCommunicatorBase::PACKET_COUNT_PER_FIFO_ENTRY / dataPacketCountPerVector; - } - - void DoPreCompute() - { - ComputeDataPacketCountPerVector(); - ComputeVectorCountPerFifoEntry(); - assert(vectorCountPerFifoEntry <= AllToAllChannelCommunicatorBase::GROUP_MAX_INDICE_COUNT); - } -}; - -// struct holding Send/Recv data pointer and its displacement information. -struct SendRecvDispls -{ - uint64_t* dataPtr; - int const* rankCountCumSum; // length = epSize - int const* rankLocalIndices; // length = rankCountCumSum[epRank] - rankCountCumSum[epRank - 1] if epRank > 0 else - // rankCountCumSum[epRank] - int vectorStrideInU64; - -#ifdef __CUDACC__ - __inline__ __device__ int getCount(int rank) const - { - return rank == 0 ? rankCountCumSum[rank] : rankCountCumSum[rank] - rankCountCumSum[rank - 1]; - } - - __inline__ __device__ int getRankStart(int rank) const - { - return rank == 0 ? 0 : rankCountCumSum[rank - 1]; - } - - __inline__ __device__ int getRealVectorIndice(int globalVectorIndex) const - { - return rankLocalIndices[globalVectorIndex]; - } - - __inline__ __device__ uint64_t* getVectorDataPtr(int realVectorIndex) const - { - return dataPtr + realVectorIndex * vectorStrideInU64; - } -#endif -}; - -struct MoeCommWorkspace -{ - uint64_t* workspacePtr; - size_t rankStrideInU64; -#ifdef __CUDACC__ - __inline__ __device__ uint64_t* getFifoBasePtr( - bool isSender, int epRank, int peerRank, int channel, int channelCount) const - { - // fifo itself is in receiver's side. - if (isSender) - { - return workspacePtr + peerRank * rankStrideInU64 + (epRank * channelCount + channel) * RECV_FIFO_TOTAL_U64; - } - else - { - return workspacePtr + epRank * rankStrideInU64 + (peerRank * channelCount + channel) * RECV_FIFO_TOTAL_U64; - } - } - - __inline__ __device__ MoeCommFifoConnInfo* getFifoConnInfo( - bool isSender, int epRank, int peerRank, int channel, int epSize, int channelCount) const - { - // fifoInfo is in sender's side. - uint64_t* fifoInfoPtrU64 = workspacePtr + RECV_FIFO_TOTAL_U64 * channelCount * epSize; - int strideIndice = isSender ? epRank : peerRank; - int fifoInfoIndice = isSender ? 
peerRank : epRank; - fifoInfoPtrU64 += strideIndice * rankStrideInU64; - MoeCommFifoConnInfo* fifoInfoPtr = (MoeCommFifoConnInfo*) fifoInfoPtrU64; - return fifoInfoPtr + fifoInfoIndice * channelCount + channel; - } -#endif -}; - -void setMaxUsableSmCount(int smCount); - -void moeAllToAll(MoeEpWorldInfo worldInfo, SendRecvDataInfo sendRecvDataInfo, SendRecvDispls sendDispls, - SendRecvDispls recvDispls, MoeCommWorkspace workspace, cudaStream_t stream); - -void moeAllToAllPrepareIndices(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo, - int maxTokenCountPerRank, int const* gatheredTargetRankIds, int const* realRankTokenCountCumSum, - int* localGatheredIndices, // indices of gatheredTargetRankIds that has the local rank in topK - int* sendRankCountCumSum, int* sendRankLocalIndices, int* recvRankCountCumSum, int* recvRankLocalIndices, - // the rankCountCumSum of combineRecv should be the same as sendRankCountCumSum - int* backwardRecvRankLocalIndices, cudaStream_t stream); - -void moeLocalGather(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo, int maxTokenCountPerRank, - int localMaxTokenCount, int const* recvRankCountCumSum, int const* localGatherIndices, int const* gatheredExpertIds, - float const* gatheredScales, int* localExpertIds, float* localScales, cudaStream_t stream); - -} // namespace tensorrt_llm::kernels diff --git a/cpp/tensorrt_llm/kernels/moeCommKernelsCommon.h b/cpp/tensorrt_llm/kernels/moeCommKernelsCommon.h new file mode 100644 index 00000000000..7d4310764b1 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/moeCommKernelsCommon.h @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2019-2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include + +namespace tensorrt_llm +{ +namespace kernels +{ + +#ifdef __CUDACC__ +#define ALIGN_256 __align__(256) +#else +#define ALIGN_256 alignas(256) +#endif + +constexpr int WARP_SIZE = 32; +constexpr uint32_t WARP_MASK = 0xffffffff; + +struct MoeEpWorldInfo +{ + int epSize; + int epRank; +}; + +struct MoeExpertParallelInfo +{ + int expertCount = -1; + int topK = 1; +}; + +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/moePrepareKernels.cu b/cpp/tensorrt_llm/kernels/moePrepareKernels.cu index 6ca40a948aa..aea271dab58 100644 --- a/cpp/tensorrt_llm/kernels/moePrepareKernels.cu +++ b/cpp/tensorrt_llm/kernels/moePrepareKernels.cu @@ -49,86 +49,6 @@ __device__ __forceinline__ int ld_acquire_sys_global_int(int volatile* ptr) return ret; } -class StepCommunicatorBase -{ -public: - static constexpr int META_SIZE = sizeof(MoeCommFifoConnInfo); - - __device__ __inline__ StepCommunicatorBase(MoeCommFifoConnInfo* fifoConnInfo) - : fifoConnInfo(fifoConnInfo) - , localCachedHead(0) - , localCachedTail(0) - { - } - - __forceinline__ __device__ void reset() - { - fifoConnInfo->head = 0; - fifoConnInfo->tail = 0; - } - - __forceinline__ __device__ void releaseSendStep() - { - localCachedHead += 1; - st_release_sys_global(&(fifoConnInfo->head), uint64_t(localCachedHead)); - } - - __forceinline__ __device__ void releaseRecvStep() - { - localCachedTail += 1; - st_release_sys_global(&(fifoConnInfo->tail), uint64_t(localCachedTail)); - } - - __forceinline__ __device__ uint64_t acquireTail() - { - uint64_t tail = ld_acquire_sys_global(&(fifoConnInfo->tail)); - localCachedTail = tail; - return tail; - } - - __forceinline__ __device__ uint64_t acquireHead() - { - uint64_t head = ld_acquire_sys_global(&(fifoConnInfo->head)); - localCachedHead = head; - return head; - } - - __forceinline__ __device__ int acquireNewSendStep() - { - - int64_t tail; - do - { - tail = acquireTail(); - } while (localCachedHead >= tail + STEP_DEPTH); - // depth = 2, head = 1, tail = 0 , ok - // depth = 2, head = 2, tail = 0, should wait - - return localCachedHead % STEP_DEPTH; - } - - __forceinline__ __device__ int acquireNewRecvStep() - { - int64_t head = 0; - do - { - head = acquireHead(); - } while (localCachedTail >= head); - - return localCachedTail % STEP_DEPTH; - } - -public: - MoeCommFifoConnInfo* fifoConnInfo; - uint64_t localCachedHead; - uint64_t localCachedTail; - int rank; - int targetRank; -}; - -// Use MoeCommFifoConnInfo as media to transfer a counter number. -// Use the "head" field as flag. -// Use the "tail" field to transfer the counter number. 
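+// Descriptive comment for the updated CounterCommunicator, based on the new implementation below:
+// MoeCommFifoConnInfo is still used as the media to transfer counter values, but each value now goes
+// into fifoConnInfo->values[index], stored as (value + 1) so that 0 can serve as the "not yet written"
+// marker; the receiver busy-waits on the slot and resets it to 0 after reading.
+// Index 0 carries the send/recv token count, and indices 1..expertCount carry the optional expert statistics.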
class CounterCommunicator { public: @@ -137,23 +57,23 @@ public: { } - __forceinline__ __device__ void releaseValue(uint64_t value) + __forceinline__ __device__ void releaseValue(uint64_t value, int index) { // Avoid block on 0 - st_release_sys_global(&(fifoConnInfo->count), value + 1); + fifoConnInfo->values[index] = value + 1; } - __forceinline__ __device__ uint64_t acquireValue() + __forceinline__ __device__ uint64_t acquireValue(int index) { - uint64_t localCount = 0; + uint64_t localValue = 0; do { - localCount = ld_acquire_sys_global(&(fifoConnInfo->count)); - } while (localCount == 0); + localValue = fifoConnInfo->values[index]; + } while (localValue == 0); - fifoConnInfo->count = 0; // reset the count + fifoConnInfo->values[index] = 0; // reset the value - return localCount - 1; + return localValue - 1; } protected: @@ -161,15 +81,16 @@ protected: }; template -__device__ __forceinline__ void computeCountAndSend(int* experts, int tokenCount, int* sharedSendRecvRankCount, - int* sendCounts, int* sendIndiceWorkspace, int* backwardIndiceWorkspace, MoeCommWorkspace workspace, - int maxTokenCountPerRank, int expertCount, int topK, int epRank, int epSize) +__device__ __forceinline__ void computeCountAndSendStatics(int* experts, int tokenCount, int* sharedSendRecvRankCount, + int* sendCounts, int* sendIndiceWorkspace, int* backwardIndiceWorkspace, int* expertStatics, + MoeCommWorkspace workspace, int maxTokenCountPerRank, int slotCount, int expertCount, int topK, int epRank, + int epSize) { cg::thread_block_tile tile = cg::tiled_partition(cg::this_thread_block()); int laneInTile = tile.thread_rank(); int tileId = threadIdx.x / kThreadsGroupSize; int tileCountPerBlock = blockDim.x / kThreadsGroupSize; - int expertCountPerRank = expertCount / epSize; + int expertCountPerRank = slotCount / epSize; if (threadIdx.x == 0) { *sharedSendRecvRankCount = 0; @@ -201,18 +122,24 @@ __device__ __forceinline__ void computeCountAndSend(int* experts, int tokenCount tile.sync(); } __syncthreads(); - if (threadIdx.x == 0) + + CounterCommunicator counter(workspace.getFifoConnInfo(true, epRank, targetRankId, 0, epSize, 1)); + + int communicationCount = expertStatics == nullptr ? 1 : expertCount + 1; + for (int i = threadIdx.x; i < communicationCount; i += blockDim.x) { - CounterCommunicator counter(workspace.getFifoConnInfo(true, epRank, targetRankId, 0, epSize, 1)); - int count = *(sharedSendRecvRankCount); - // printf("sendRecvCount: %d, rankId: %d, targetRankId: %d\n", count, rankId, targetRankId); - counter.releaseValue(uint64_t(count)); - *(sendCounts + targetRankId) = count; + int value = i == 0 ? 
*(sharedSendRecvRankCount) : *(expertStatics + i - 1); + counter.releaseValue(value, i); + if (i == 0) + { + *(sendCounts + targetRankId) = value; + } } } -__device__ __forceinline__ void recvCount(int* recvIndiceWorkspace, int* recvCounts, int* sharedCountsBase, - MoeCommWorkspace workspace, int maxTokenCountPerRank, int rankId, int rankCount) +__device__ __forceinline__ void recvCountAndStatics(int* recvIndiceWorkspace, int* recvCounts, int* sharedCountsBase, + int* gatheredExpertStatics, MoeCommWorkspace workspace, int expertCount, int maxTokenCountPerRank, int rankId, + int rankCount) { int rankOffset = threadIdx.x / THREADS_PER_PIPELINE; if (rankOffset >= PIPELINE_PER_CTA) @@ -229,18 +156,25 @@ __device__ __forceinline__ void recvCount(int* recvIndiceWorkspace, int* recvCou cg::thread_block_tile rankTile = cg::tiled_partition(cg::this_thread_block()); int* localRecvIndice = recvIndiceWorkspace + targetRankId * maxTokenCountPerRank; - int rankRecvCount; - if (rankTile.thread_rank() == 0) + + CounterCommunicator counter(workspace.getFifoConnInfo(false, rankId, targetRankId, 0, rankCount, 1)); + int communicationCount = gatheredExpertStatics == nullptr ? 1 : expertCount + 1; + for (int i = rankTile.thread_rank(); i < communicationCount; i += THREADS_PER_PIPELINE) { - CounterCommunicator counter(workspace.getFifoConnInfo(false, rankId, targetRankId, 0, rankCount, 1)); - rankRecvCount = int(counter.acquireValue()); - // printf("rankRecvCount: %d, rankId: %d, targetRankId: %d\n", rankRecvCount, rankId, targetRankId); - *(recvCounts + targetRankId) = rankRecvCount; - *(sharedCountsThisRank) = rankRecvCount; + int recvValue = counter.acquireValue(i); + if (i == 0) + { + *(recvCounts + targetRankId) = recvValue; + *(sharedCountsThisRank) = recvValue; + } + else + { + *(gatheredExpertStatics + targetRankId * expertCount + i - 1) = recvValue; + } } rankTile.sync(); - rankRecvCount = *(sharedCountsThisRank); + int rankRecvCount = *(sharedCountsThisRank); for (int tokenId = unitId; tokenId < rankRecvCount; tokenId += UNIT_PER_PIPELINE) { *(localRecvIndice + tokenId) = tokenId; @@ -249,20 +183,22 @@ __device__ __forceinline__ void recvCount(int* recvIndiceWorkspace, int* recvCou template __global__ void computeCountAndIndiceDevice(int* experts, int* sendCounts, int* recvCounts, int* sendIndiceWorkspace, - int* backwardIndiceWorkspace, int* recvIndiceWorkspace, MoeCommWorkspace workspace, int tokenCount, - int maxTokenCountPerRank, int topK, int expertCount, int rankId, int rankCount) + int* backwardIndiceWorkspace, int* recvIndiceWorkspace, int* expertStatics, int* gatheredExpertStatics, + MoeCommWorkspace workspace, int tokenCount, int maxTokenCountPerRank, int topK, int slotCount, int expertCount, + int rankId, int rankCount) { __shared__ int sharedCounts[PIPELINE_PER_CTA]; bool isSender = blockIdx.x < rankCount; if (isSender) { - computeCountAndSend(experts, tokenCount, &sharedCounts[0], sendCounts, sendIndiceWorkspace, - backwardIndiceWorkspace, workspace, maxTokenCountPerRank, expertCount, topK, rankId, rankCount); + computeCountAndSendStatics(experts, tokenCount, &sharedCounts[0], sendCounts, + sendIndiceWorkspace, backwardIndiceWorkspace, expertStatics, workspace, maxTokenCountPerRank, slotCount, + expertCount, topK, rankId, rankCount); } else { - recvCount( - recvIndiceWorkspace, recvCounts, &sharedCounts[0], workspace, maxTokenCountPerRank, rankId, rankCount); + recvCountAndStatics(recvIndiceWorkspace, recvCounts, &sharedCounts[0], gatheredExpertStatics, workspace, + expertCount, 
maxTokenCountPerRank, rankId, rankCount); } } @@ -307,259 +243,12 @@ __global__ void computeCumsumDevice(int* sendCountsCumsum, int* recvCountsCumsum int tid = threadIdx.x; int threadData = tid < rankCount ? inputOutputPtr[tid] : 0; - int count = threadData; __syncthreads(); BlockScan(temp_storage).InclusiveSum(threadData, threadData); if (tid < rankCount) { inputOutputPtr[tid] = threadData; - // printf("cumsum, send? : %d, rankId:%d, tid:%d, threadData:%d, count:%d\n", blockIdx.x == 0, rankId, tid, - // threadData, count); - } -} - -template -class PacketPipeline -{ -public: - __device__ __inline__ PacketPipeline( - void* bufferBase, StepCommunicatorBase* stepCommunicator, int* sharedNewStepPtr, bool isSender) - : bufferBase(bufferBase) - , stepCommunicator(stepCommunicator) - , shared_new_step(sharedNewStepPtr) - { - step = 0; - needRelease = false; - packetId = isSender ? 0 : PipelineConfig::PACKET_PER_STEP - 1; - } - - __device__ __forceinline__ void* getFirstSendPacket() - { - return bufferBase; - } - - __device__ __inline__ void* finishSendPacket(bool acquireNewStep) - { - - packetId++; - if (packetId < PipelineConfig::PACKET_PER_STEP) - { - return acquireNewStep ? bufferBase + step * PipelineConfig::PACKET_PER_STEP * PipelineConfig::PACKET_SIZE - + packetId * PipelineConfig::PACKET_SIZE - : nullptr; - } - - __syncthreads(); - if (threadIdx.x == 0) - { - stepCommunicator->releaseSendStep(); - if (acquireNewStep) - { - step = stepCommunicator->acquireNewSendStep(); - *(shared_new_step) = step; - } - } - __syncthreads(); - - if (acquireNewStep) - { - step = *(shared_new_step); - packetId = 0; - return bufferBase + step * PipelineConfig::PACKET_SIZE * PipelineConfig::PACKET_PER_STEP; - } - - return nullptr; - } - - __device__ __forceinline__ void* sendFinalize() - { - if (packetId > 0 && threadIdx.x == 0) - { - stepCommunicator->releaseSendStep(); - } - } - - __device__ __inline__ void* getNewRecvPacket() - { - packetId++; - if (packetId < PipelineConfig::PACKET_PER_STEP) - { - return bufferBase + step * PipelineConfig::PACKET_PER_STEP * PipelineConfig::PACKET_SIZE - + packetId * PipelineConfig::PACKET_SIZE; - } - - __syncthreads(); - if (threadIdx.x == 0) - { - if (needRelease) - { - stepCommunicator->releaseRecvStep(); - } - step = stepCommunicator->acquireNewRecvStep(); - needRelease = true; - *(shared_new_step) = step; - } - __syncthreads(); - packetId = 0; - step = *(shared_new_step); - void* packetPtr = bufferBase + step * PipelineConfig::PACKET_SIZE * PipelineConfig::PACKET_PER_STEP; - - return packetPtr; - } - - __device__ __forceinline__ void reset() - { - if (threadIdx.x == 0) - { - stepCommunicator->reset(); - } - } - - void* bufferBase; - StepCommunicatorBase* stepCommunicator; - int step; - int packetId; - bool needRelease; - int* shared_new_step; -}; - -template -__global__ void allToAllMetadataDevice(int* sendExperts, int* recvExperts, float* sendScales, float* recvScales, - int* localExpertStatics, int* gatheredExpertStatics, MoeCommWorkspace workspace, int* sendCountsCumsum, - int* localSendIndice, int* recvCountsCumsum, int* localRecvIndice, int tokenCount, int maxTokenCountPerRank, - int topK, int expertCount, int slotCount, int rankId, int rankCount) -{ - bool isSender = (blockIdx.y == 0); - int targetRankId = blockIdx.x; - int slotCountPerRank = slotCount / rankCount; - int groupSize = topK / PipelineConfig::UNIT_SIZE; - - __shared__ int sharedNewStep; - __align__(16) int experts[PipelineConfig::UNIT_SIZE]; - __align__(16) float scales[PipelineConfig::UNIT_SIZE]; - 
- uint8_t* bufferBase = (uint8_t*) (workspace.getFifoBasePtr(isSender, rankId, targetRankId, 0, 1)); - StepCommunicatorBase stepCommunicator(workspace.getFifoConnInfo(isSender, rankId, targetRankId, 0, rankCount, 1)); - PacketPipeline pipeline(bufferBase, &stepCommunicator, &sharedNewStep, isSender); - - if (isSender) - { - int baseCumsum = targetRankId == 0 ? 0 : *(sendCountsCumsum + targetRankId - 1); - int sendTokenCount = *(sendCountsCumsum + targetRankId) - baseCumsum; - int unitCount = sendTokenCount * topK / PipelineConfig::UNIT_SIZE; - - void* packPtr = pipeline.getFirstSendPacket(); - int indexBase = 0; - int staticCopyBase = 0; - bool acquireNewStep = unitCount > 0 || (localExpertStatics != nullptr && expertCount > 0); - while (acquireNewStep) - { - if (threadIdx.x < UNIT_PER_ITER) - { - int index = indexBase + threadIdx.x; - int groupId = index % groupSize; - if (index < unitCount) - { - int tokenId = *(localSendIndice + maxTokenCountPerRank * targetRankId + (index / groupSize)); - *((ExpertType*) (experts)) - = *(ExpertType*) (sendExperts + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE); - -#pragma unroll - for (int j = 0; j < PipelineConfig::UNIT_SIZE; j++) - { - int expertId = experts[j]; - if (expertId / slotCountPerRank != targetRankId) - { - experts[j] = slotCount; - } - } - - int* expertsPtr = (int*) (packPtr) + threadIdx.x * PipelineConfig::UNIT_SIZE; - *((ExpertType*) (expertsPtr)) = *((ExpertType*) (experts)); - if (sendScales != nullptr) - { - *((ScaleType*) (scales)) - = *(ScaleType*) (sendScales + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE); - float* scaleBasePtr = (float*) (packPtr + PipelineConfig::SCALE_OFFSET); - float* scalesPtr = (float*) (scaleBasePtr) + threadIdx.x * PipelineConfig::UNIT_SIZE; - *((ScaleType*) (scalesPtr)) = *((ScaleType*) (scales)); - } - } - } - else if (localExpertStatics != nullptr) - { - int staticCopyIdx = threadIdx.x - UNIT_PER_ITER; - if (staticCopyBase + staticCopyIdx * 4 < expertCount) - { - int4* staticBasePtr = (int4*) (packPtr + PipelineConfig::STATIC_COPY_OFFSET); - int4 staticData = *(int4*) (localExpertStatics + staticCopyBase + staticCopyIdx * 4); - *(staticBasePtr + staticCopyIdx) = staticData; - } - } - - indexBase += UNIT_PER_ITER; - staticCopyBase += STATIC_COPY_PER_ITER * 4; - acquireNewStep = indexBase < unitCount || staticCopyBase < expertCount; - packPtr = pipeline.finishSendPacket(acquireNewStep); - } - - pipeline.sendFinalize(); - } - else - { - int baseCumsum = targetRankId == 0 ? 0 : *(recvCountsCumsum + targetRankId - 1); - int recvTokenCount = *(recvCountsCumsum + targetRankId) - baseCumsum; - int recvUnitCount = recvTokenCount * groupSize; - - int unitIdBase = 0; - int staticCopyBase = 0; - while (unitIdBase < recvUnitCount || (localExpertStatics != nullptr && staticCopyBase < expertCount)) - { - void* packetPtr = pipeline.getNewRecvPacket(); - int packetUnitCount - = unitIdBase + UNIT_PER_ITER < recvUnitCount ? 
UNIT_PER_ITER : recvUnitCount - unitIdBase; - packetUnitCount = max(packetUnitCount, 0); - if (threadIdx.x < UNIT_PER_ITER) - { - if (threadIdx.x < packetUnitCount) - { - int tokenId = baseCumsum + (unitIdBase + threadIdx.x) / groupSize; - int groupId = (unitIdBase + threadIdx.x) % groupSize; - int* expertsPtr = (int*) (packetPtr) + threadIdx.x * PipelineConfig::UNIT_SIZE; - *((ExpertType*) (experts)) = *((ExpertType*) (expertsPtr)); - ExpertType* dstExpertsPtr - = (ExpertType*) (recvExperts + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE); - *dstExpertsPtr = *((ExpertType*) (experts)); - - if (recvScales != nullptr) - { - float* scaleBasePtr = (float*) (packetPtr + PipelineConfig::SCALE_OFFSET); - float* scalesPtr = scaleBasePtr + threadIdx.x * PipelineConfig::UNIT_SIZE; - *((ScaleType*) (scales)) = *((ScaleType*) (scalesPtr)); - ScaleType* dstScalesPtr - = (ScaleType*) (recvScales + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE); - *dstScalesPtr = *((ScaleType*) (scales)); - } - } - } - else if (localExpertStatics != nullptr) - { - int staticCopyIdx = threadIdx.x - UNIT_PER_ITER; - if (staticCopyBase + staticCopyIdx * 4 < expertCount) - { - int4* staticBasePtr = (int4*) (packetPtr + PipelineConfig::STATIC_COPY_OFFSET); - int4 staticData = *(staticBasePtr + staticCopyIdx); - *(int4*) (gatheredExpertStatics + targetRankId * expertCount + staticCopyBase + staticCopyIdx * 4) - = staticData; - } - } - - unitIdBase += packetUnitCount; - staticCopyBase += STATIC_COPY_PER_ITER * 4; - } - - pipeline.reset(); } } @@ -576,8 +265,9 @@ __global__ void memsetExpertIdsDevice( } void computeCountAndIndice(int* experts, int* sendCounts, int* recvCounts, int* sendIndiceWorkspace, - int* backwardIndiceWorkspace, int* recvIndiceWorkspace, MoeCommWorkspace workspace, int tokenCount, - int maxTokenCountPerRank, int topK, int expert_count, int rankId, int rankCount, cudaStream_t stream) + int* backwardIndiceWorkspace, int* recvIndiceWorkspace, int* expertStatics, int* gatheredExpertStatics, + MoeCommWorkspace workspace, int tokenCount, int maxTokenCountPerRank, int topK, int slotCount, int expertCount, + int rankId, int rankCount, cudaStream_t stream) { // first rankCount CTAs for count and send, then rankCount / PIPELINE_PER_CTA CTAs only for receive int grid_x = rankCount + (rankCount + PIPELINE_PER_CTA - 1) / PIPELINE_PER_CTA; @@ -607,7 +297,8 @@ void computeCountAndIndice(int* experts, int* sendCounts, int* recvCounts, int* kernelFn = computeCountAndIndiceDevice<2>; } kernelFn<<>>(experts, sendCounts, recvCounts, sendIndiceWorkspace, backwardIndiceWorkspace, - recvIndiceWorkspace, workspace, tokenCount, maxTokenCountPerRank, topK, expert_count, rankId, rankCount); + recvIndiceWorkspace, expertStatics, gatheredExpertStatics, workspace, tokenCount, maxTokenCountPerRank, topK, + slotCount, expertCount, rankId, rankCount); } void computeCumsum(int* sendCountsCumsum, int* recvCountsCumsum, int rankId, int rankCount, cudaStream_t stream) @@ -628,46 +319,18 @@ void moveIndice(int* sendCountsCumsum, int* recvCountsCumsum, int* sendIndice, i backwardIndice, gatherBackwardIndice, recvIndice, gatherRecvIndice, maxTokenCountPerRank); } -void allToAllMetadata(int* sendExperts, int* recvExperts, float* sendScales, float* recvScales, int* localExpertStatics, - int* gatheredExpertStatics, MoeCommWorkspace workspace, int* sendCountsCumsum, int* localSendIndice, - int* recvCountsCumsum, int* localRecvIndice, int tokenCount, int maxTokenCountPerRank, int topK, int expertCount, - int slotCount, int rankId, int 
rankCount, cudaStream_t stream) +void memsetExpertIds(int* expertIds, int* recvCountsCumsum, int maxTokenCountPerRank, int topK, int slotCount, + int rankCount, cudaStream_t stream) { - int block_size = localExpertStatics == nullptr ? UNIT_PER_ITER : UNIT_PER_ITER + STATIC_COPY_PER_ITER; - dim3 block(block_size); - dim3 grid(rankCount, 2); - - if (topK % 4 == 0) - { - using PipelineConfig = PipelineConfig<4, 16>; - static_assert( - PipelineConfig::PACKET_SIZE_IN_U64 * PipelineConfig::PACKET_PER_STEP * STEP_DEPTH <= FIFO_SIZE_IN_U64, - "FIFO size is too small"); - allToAllMetadataDevice<<>>(sendExperts, recvExperts, - sendScales, recvScales, localExpertStatics, gatheredExpertStatics, workspace, sendCountsCumsum, - localSendIndice, recvCountsCumsum, localRecvIndice, tokenCount, maxTokenCountPerRank, topK, expertCount, - slotCount, rankId, rankCount); - } - else - { - using PipelineConfig = PipelineConfig<1, 64>; - static_assert( - PipelineConfig::PACKET_SIZE_IN_U64 * PipelineConfig::PACKET_PER_STEP * STEP_DEPTH <= FIFO_SIZE_IN_U64, - "FIFO size is too small"); - allToAllMetadataDevice<<>>(sendExperts, recvExperts, - sendScales, recvScales, localExpertStatics, gatheredExpertStatics, workspace, sendCountsCumsum, - localSendIndice, recvCountsCumsum, localRecvIndice, tokenCount, maxTokenCountPerRank, topK, expertCount, - slotCount, rankId, rankCount); - } - int smCount = tensorrt_llm::common::getMultiProcessorCount(); - memsetExpertIdsDevice<<>>( - recvExperts, recvCountsCumsum, maxTokenCountPerRank, topK, slotCount, rankCount); + int block_size = 256; + memsetExpertIdsDevice<<>>( + expertIds, recvCountsCumsum, maxTokenCountPerRank, topK, slotCount, rankCount); } size_t getMoePrepareWorkspaceSize(int epSize) { - return (FIFO_SIZE_IN_U64 * 8 + StepCommunicatorBase::META_SIZE) * epSize; + return sizeof(MoeCommFifoConnInfo) * epSize; } } // namespace moe_prepare diff --git a/cpp/tensorrt_llm/kernels/moePrepareKernels.h b/cpp/tensorrt_llm/kernels/moePrepareKernels.h index 0635397970f..7fbb7be6bd9 100644 --- a/cpp/tensorrt_llm/kernels/moePrepareKernels.h +++ b/cpp/tensorrt_llm/kernels/moePrepareKernels.h @@ -28,36 +28,11 @@ namespace tensorrt_llm::kernels namespace moe_prepare { -#define STEP_DEPTH 2 -#define THREADS_PER_UNIT 1 #define UNIT_PER_PIPELINE 128 #define PIPELINE_PER_CTA 4 -#define EXPERT_BYTES_PER_UNIT 32 -#define SCALE_BYTES_PER_UNIT 32 -#define UNIT_COUNT_PER_PACKET 1024 -#define BYTES_COUNTER 8 #define CUMSUM_THREADS_PER_BLOCK 128 -#define UNIT_PER_ITER 256 -#define STATIC_COPY_PER_ITER 128 - -static constexpr int THREADS_PER_PIPELINE = THREADS_PER_UNIT * UNIT_PER_PIPELINE; -static constexpr int THREADS_PER_CTA = THREADS_PER_PIPELINE * PIPELINE_PER_CTA; - -template -struct PipelineConfig -{ - static constexpr int UNIT_SIZE = UNIT_SIZE_INPUT; - static constexpr int PACKET_PER_STEP = PACKET_PER_STEP_INPUT; - static constexpr int UNIT_BYTES_SIZE = UNIT_SIZE * UNIT_PER_ITER * (sizeof(int) + sizeof(float)); - static constexpr int SCALE_OFFSET = UNIT_SIZE * UNIT_PER_ITER * sizeof(int); - static constexpr int STATIC_COPY_OFFSET = UNIT_SIZE * UNIT_PER_ITER * (sizeof(int) + sizeof(float)); - static constexpr int PACKET_SIZE = UNIT_BYTES_SIZE + STATIC_COPY_PER_ITER * 4 * sizeof(int); - static constexpr int PACKET_SIZE_IN_U64 = (PACKET_SIZE / 8); -}; - -// 1MB FIFO size -static constexpr int FIFO_SIZE_IN_U64 = 1024 * 1024 / 8; +static constexpr int THREADS_PER_PIPELINE = UNIT_PER_PIPELINE; #ifdef __CUDACC__ #define ALIGN_256 __align__(256) @@ -67,9 +42,9 @@ static constexpr int FIFO_SIZE_IN_U64 = 
1024 * 1024 / 8; struct ALIGN_256 MoeCommFifoConnInfo { - volatile uint64_t head; // write position - volatile uint64_t tail; // read position - volatile uint64_t count; // for counter + volatile uint64_t head; // write position + volatile uint64_t tail; // read position + int volatile values[512]; // for values }; struct MoeCommWorkspace @@ -77,25 +52,11 @@ struct MoeCommWorkspace uint64_t* workspacePtr; size_t rankStrideInU64; #ifdef __CUDACC__ - __inline__ __device__ uint64_t* getFifoBasePtr( - bool isSender, int epRank, int peerRank, int channel, int channelCount) const - { - // fifo itself is in receiver's side. - if (isSender) - { - return workspacePtr + peerRank * rankStrideInU64 + (epRank * channelCount + channel) * FIFO_SIZE_IN_U64; - } - else - { - return workspacePtr + epRank * rankStrideInU64 + (peerRank * channelCount + channel) * FIFO_SIZE_IN_U64; - } - } - __inline__ __device__ MoeCommFifoConnInfo* getFifoConnInfo( bool isSender, int epRank, int peerRank, int channel, int epSize, int channelCount) const { // fifoInfo is in sender's side. - uint64_t* fifoInfoPtrU64 = workspacePtr + FIFO_SIZE_IN_U64 * channelCount * epSize; + uint64_t* fifoInfoPtrU64 = workspacePtr; int strideIndice = isSender ? epRank : peerRank; int fifoInfoIndice = isSender ? peerRank : epRank; fifoInfoPtrU64 += strideIndice * rankStrideInU64; @@ -108,8 +69,9 @@ struct MoeCommWorkspace }; void computeCountAndIndice(int* experts, int* sendCounts, int* recvCounts, int* sendIndiceWorkspace, - int* backwardIndiceWorkspace, int* recvIndiceWorkspace, MoeCommWorkspace workspace, int tokenCount, - int maxTokenCountPerRank, int topK, int expert_count, int rankId, int rankCount, cudaStream_t stream); + int* backwardIndiceWorkspace, int* recvIndiceWorkspace, int* expertStatics, int* gatheredExpertStatics, + MoeCommWorkspace workspace, int tokenCount, int maxTokenCountPerRank, int topK, int slotCount, int expertCount, + int rankId, int rankCount, cudaStream_t stream); void computeCumsum(int* sendCountsCumsum, int* recvCountsCumsum, int rankId, int rankCount, cudaStream_t stream); @@ -117,10 +79,8 @@ void moveIndice(int* sendCountsCumsum, int* recvCountsCumsum, int* sendIndice, i int* backwardIndice, int* gatherBackwardIndice, int* recvIndice, int* gatherRecvIndice, int rankId, int rankCount, int maxTokenCountPerRank, cudaStream_t stream); -void allToAllMetadata(int* sendExperts, int* recvExperts, float* sendScales, float* recvScales, int* localExpertStatics, - int* gatheredExpertStatics, MoeCommWorkspace workspace, int* sendCountsCumsum, int* localSendIndice, - int* recvCountsCumsum, int* localRecvIndice, int tokenCount, int maxTokenCountPerRank, int topK, int expertCount, - int slotCount, int rankId, int rankCount, cudaStream_t stream); +void memsetExpertIds(int* expertIds, int* recvCountsCumsum, int maxTokenCountPerRank, int topK, int slotCount, + int epSize, cudaStream_t stream); size_t getMoePrepareWorkspaceSize(int epSize); diff --git a/cpp/tensorrt_llm/thop/moeCommOp.cpp b/cpp/tensorrt_llm/thop/moeCommOp.cpp index 2d21db13b94..3f76f2e876b 100644 --- a/cpp/tensorrt_llm/thop/moeCommOp.cpp +++ b/cpp/tensorrt_llm/thop/moeCommOp.cpp @@ -16,7 +16,7 @@ */ #include "tensorrt_llm/common/opUtils.h" -#include "tensorrt_llm/kernels/moeCommKernels.h" +#include "tensorrt_llm/kernels/fusedMoeCommKernels.h" #include "tensorrt_llm/kernels/moePrepareKernels.h" #include "tensorrt_llm/runtime/torchUtils.h" #include "tensorrt_llm/thop/thUtils.h" @@ -28,180 +28,104 @@ namespace torch_ext { -std::tuple 
-moeCommPrepareIndicesOp(torch::Tensor gatheredTargetRankIds, c10::optional realRankTokenCountCumSum, - int64_t maxTokenCountPerRank, int64_t expertCount, int64_t topK, int64_t epRank, int64_t epSize) +void setMoeCommFieldInfo(tensorrt_llm::kernels::MoeCommFieldInfo& fieldInfo, torch::Tensor const& tensor) { - CHECK_INPUT(gatheredTargetRankIds, torch::kInt32); - TORCH_CHECK(gatheredTargetRankIds.dim() == 2, "gatheredTargetRankIds must be a 2D tensor"); - TORCH_CHECK(gatheredTargetRankIds.size(1) == topK, "gatheredTargetRankIds must have topK columns"); - - int const* realRankTokenCountCumSumPtr = nullptr; - if (realRankTokenCountCumSum.has_value()) - { - TORCH_CHECK(realRankTokenCountCumSum.value().dim() == 1, "realRankTokenCountCumSum must be a 1D tensor"); - TORCH_CHECK(realRankTokenCountCumSum.value().dtype() == torch::kInt32, - "realRankTokenCountCumSum must be a int32 tensor"); - TORCH_CHECK( - realRankTokenCountCumSum.value().size(0) == epSize, "realRankTokenCountCumSum must have epSize elements"); - realRankTokenCountCumSumPtr = realRankTokenCountCumSum.value().data_ptr(); - } - else - { - TORCH_CHECK(gatheredTargetRankIds.size(0) == epSize * maxTokenCountPerRank, - "gatheredTargetRankIds should have shape (epSize * maxTokenCountPerRank, topK)"); - } - TORCH_CHECK(maxTokenCountPerRank > 0, "maxTokenCountPerRank must be greater than 0"); - TORCH_CHECK(expertCount > 0, "expertCount must be greater than 0"); - TORCH_CHECK(topK > 0, "topK must be greater than 0"); - TORCH_CHECK(topK <= expertCount, "topK must be less than or equal to expertCount"); - TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)"); - - auto stream = at::cuda::getCurrentCUDAStream(); - - int maxSendRanksPerToken = std::max(epSize, topK); - - torch::Tensor localGatherIndices - = torch::empty({maxTokenCountPerRank * epSize}, gatheredTargetRankIds.options().dtype(torch::kInt32)); - torch::Tensor sendRankCountCumSum = torch::empty({epSize}, gatheredTargetRankIds.options().dtype(torch::kInt32)); - torch::Tensor sendRankLocalIndices = torch::empty( - {maxTokenCountPerRank * maxSendRanksPerToken}, gatheredTargetRankIds.options().dtype(torch::kInt32)); - torch::Tensor recvRankCountCumSum = torch::empty({epSize}, gatheredTargetRankIds.options().dtype(torch::kInt32)); - torch::Tensor recvRankLocalIndices - = torch::empty({maxTokenCountPerRank * epSize}, gatheredTargetRankIds.options().dtype(torch::kInt32)); - torch::Tensor backwardRecvRankLocalIndices = torch::empty( - {maxTokenCountPerRank * maxSendRanksPerToken}, gatheredTargetRankIds.options().dtype(torch::kInt32)); - - tensorrt_llm::kernels::MoeExpertParallelInfo expertParallelInfo; - expertParallelInfo.expertCount = expertCount; - expertParallelInfo.topK = topK; - - tensorrt_llm::kernels::MoeEpWorldInfo worldInfo = {static_cast(epSize), static_cast(epRank)}; - tensorrt_llm::kernels::moeAllToAllPrepareIndices(worldInfo, expertParallelInfo, maxTokenCountPerRank, - gatheredTargetRankIds.data_ptr(), realRankTokenCountCumSumPtr, localGatherIndices.data_ptr(), - sendRankCountCumSum.data_ptr(), sendRankLocalIndices.data_ptr(), recvRankCountCumSum.data_ptr(), - recvRankLocalIndices.data_ptr(), backwardRecvRankLocalIndices.data_ptr(), stream); - - return std::make_tuple(localGatherIndices, sendRankCountCumSum, sendRankLocalIndices, recvRankCountCumSum, - recvRankLocalIndices, backwardRecvRankLocalIndices); + TORCH_CHECK(tensor.dim() == 2, "tensor must be a 2D tensor"); + int eltSize = tensor.dtype().itemsize(); + 
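+    // Describe the tensor as a communication field: base data pointer, element size in
+    // bytes, row length (size(1)) and row stride (stride(0)).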
fieldInfo.fillFieldInfo(static_cast(tensor.data_ptr()), eltSize, tensor.size(1), tensor.stride(0)); } -void moeLocalGatherOp(torch::Tensor recvRankCumSum, torch::Tensor localGatherIndices, torch::Tensor gatheredExpertIds, - c10::optional gatheredScales, torch::Tensor localExpertIds, c10::optional localScales, - int64_t maxTokenCountPerRank, int64_t expertCount, int64_t topK, int64_t epRank, int64_t epSize) -{ - CHECK_INPUT(recvRankCumSum, torch::kInt32); - CHECK_INPUT(localGatherIndices, torch::kInt32); - CHECK_INPUT(gatheredExpertIds, torch::kInt32); - CHECK_INPUT(localExpertIds, torch::kInt32); - - TORCH_CHECK(maxTokenCountPerRank > 0, "maxTokenCountPerRank must be greater than 0"); - TORCH_CHECK(expertCount > 0, "expertCount must be greater than 0"); - TORCH_CHECK(topK > 0, "topK must be greater than 0"); - TORCH_CHECK(topK <= expertCount, "topK must be less than or equal to expertCount"); - TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)"); - - TORCH_CHECK(recvRankCumSum.dim() == 1, "recvRankCumSum must be a 1D tensor"); - TORCH_CHECK(recvRankCumSum.size(0) == epSize, "recvRankCumSum must have epSize elements"); - TORCH_CHECK(localGatherIndices.dim() == 1, "localGatherIndices must be a 1D tensor"); - TORCH_CHECK(gatheredExpertIds.dim() == 2, "gatheredExpertIds must be a 2D tensor"); - TORCH_CHECK(localExpertIds.dim() == 2, "localExpertIds must be a 2D tensor"); - TORCH_CHECK(gatheredExpertIds.size(1) == topK, "gatheredExpertIds must have topK columns"); - TORCH_CHECK(localExpertIds.size(1) == topK, "localExpertIds must have topK columns"); - - int localMaxTokenCount = static_cast(localGatherIndices.size(0)); - TORCH_CHECK(localExpertIds.size(0) == localMaxTokenCount, "localExpertIds must have localMaxTokenCount rows"); - - TORCH_CHECK(gatheredScales.has_value() == localScales.has_value(), - "gatheredScales and localScales must be both valid or both invalid"); - float const* gatheredScalesPtr = nullptr; - float* localScalesPtr = nullptr; - if (gatheredScales.has_value()) - { - CHECK_INPUT(gatheredScales.value(), torch::kFloat32); - CHECK_INPUT(localScales.value(), torch::kFloat32); - - TORCH_CHECK(gatheredScales->dim() == 2, "gatheredScales must be a 2D tensor"); - TORCH_CHECK(gatheredScales->size(1) == topK, "gatheredScales must have topK columns"); - TORCH_CHECK(localScales->dim() == 2, "localScales must be a 2D tensor"); - TORCH_CHECK(localScales->size(1) == topK, "localScales must have topK columns"); - TORCH_CHECK(localScales->size(0) == localMaxTokenCount, "localScales must have localMaxTokenCount rows"); - - gatheredScalesPtr = gatheredScales->data_ptr(); - localScalesPtr = localScales->data_ptr(); - } - - auto stream = at::cuda::getCurrentCUDAStream(); - - tensorrt_llm::kernels::MoeExpertParallelInfo expertParallelInfo; - expertParallelInfo.expertCount = expertCount; - expertParallelInfo.topK = topK; - - tensorrt_llm::kernels::MoeEpWorldInfo worldInfo = {static_cast(epSize), static_cast(epRank)}; - tensorrt_llm::kernels::moeLocalGather(worldInfo, expertParallelInfo, maxTokenCountPerRank, localMaxTokenCount, - recvRankCumSum.data_ptr(), localGatherIndices.data_ptr(), gatheredExpertIds.data_ptr(), - gatheredScalesPtr, localExpertIds.data_ptr(), localScalesPtr, stream); -} - -void moeCommOp(torch::Tensor input, torch::Tensor sendRankCumSum, torch::Tensor sendIndices, torch::Tensor output, - torch::Tensor recvRankCumSum, torch::Tensor recvIndices, torch::Tensor allWorkspaces, int64_t epRank, - int64_t epSize) +c10::List moeCommOp(c10::List 
inputs, torch::Tensor sendRankCumSum, + torch::Tensor sendIndiceTensor, torch::Tensor recvRankCumSum, torch::Tensor recvIndiceTensor, + torch::Tensor allWorkspaces, int64_t outputAllocationCount, int64_t epRank, int64_t epSize, + std::optional> needZeroOutput = std::nullopt) { CHECK_INPUT(sendRankCumSum, torch::kInt32); - CHECK_INPUT(sendIndices, torch::kInt32); + CHECK_INPUT(sendIndiceTensor, torch::kInt32); CHECK_INPUT(recvRankCumSum, torch::kInt32); - CHECK_INPUT(recvIndices, torch::kInt32); - // allWorkspaces is a uint64 tensor, but may not be contiguous - TORCH_CHECK(allWorkspaces.dtype() == torch::kUInt64, "allWorkspaces must be a uint64 tensor"); + CHECK_INPUT(recvIndiceTensor, torch::kInt32); - TORCH_CHECK(input.dim() == 2, "input must be a 2D tensor"); - TORCH_CHECK(output.dim() == 2, "output must be a 2D tensor"); TORCH_CHECK(sendRankCumSum.dim() == 1, "sendRankCumSum must be a 1D tensor"); - TORCH_CHECK(sendIndices.dim() == 1, "sendIndices must be a 1D tensor"); + TORCH_CHECK(sendIndiceTensor.dim() == 1, "sendIndices must be a 1D tensor"); TORCH_CHECK(recvRankCumSum.dim() == 1, "recvRankCumSum must be a 1D tensor"); - TORCH_CHECK(recvIndices.dim() == 1, "recvIndices must be a 1D tensor"); + TORCH_CHECK(recvIndiceTensor.dim() == 1, "recvIndices must be a 1D tensor"); TORCH_CHECK(allWorkspaces.dim() == 2, "allWorkspaces must be a 2D tensor"); - TORCH_CHECK(input.size(1) == output.size(1), "input and output must have the same second dimension"); TORCH_CHECK(sendRankCumSum.size(0) == epSize, "sendRankCumSum must have epSize elements"); TORCH_CHECK(recvRankCumSum.size(0) == epSize, "recvRankCumSum must have epSize elements"); TORCH_CHECK(allWorkspaces.size(0) == epSize, "allWorkspaces must have epSize elements"); TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)"); + TORCH_CHECK(!needZeroOutput.has_value() || needZeroOutput.value().size() == inputs.size(), + "needZeroOutput should have same length as inputs"); + c10::List outputs; + + tensorrt_llm::kernels::MoeEpWorldInfo epWorldInfo = {static_cast(epSize), static_cast(epRank)}; + tensorrt_llm::kernels::FusedMoeWorldInfo worldInfo = {epWorldInfo}; + + tensorrt_llm::kernels::SendRecvIndices sendIndices, recvIndices; + sendIndices.rankCountCumSum = sendRankCumSum.data_ptr(); + sendIndices.rankLocalIndices = sendIndiceTensor.data_ptr(); + + recvIndices.rankCountCumSum = recvRankCumSum.data_ptr(); + recvIndices.rankLocalIndices = recvIndiceTensor.data_ptr(); + + int fieldCount = inputs.size(); + TORCH_CHECK(fieldCount <= tensorrt_llm::kernels::MOE_COMM_FIELD_MAX_COUNT, "Number of fields (", fieldCount, + ") exceeds maximum allowed (", tensorrt_llm::kernels::MOE_COMM_FIELD_MAX_COUNT, ")"); + tensorrt_llm::kernels::FusedMoeFieldInfo sendFieldInfo, recvFieldInfo; + sendFieldInfo.isBasicInterleaved = false; + recvFieldInfo.isBasicInterleaved = false; + sendFieldInfo.fieldCount = fieldCount; + recvFieldInfo.fieldCount = fieldCount; + sendFieldInfo.expertScales = nullptr; + recvFieldInfo.expertScales = nullptr; + sendFieldInfo.tokenSelectedSlots = nullptr; + recvFieldInfo.tokenSelectedSlots = nullptr; + + for (int i = 0; i < fieldCount; i++) + { + torch::Tensor const& t = inputs[i]; + setMoeCommFieldInfo(sendFieldInfo.fieldsInfo[i], t); + if (needZeroOutput.has_value() && needZeroOutput.value()[i]) + { + outputs.push_back(torch::zeros({outputAllocationCount, t.size(1)}, t.options())); + } + else + { + outputs.push_back(torch::empty({outputAllocationCount, t.size(1)}, t.options())); + } + 
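+        // Register the output tensor just allocated as the matching receive-side field.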
setMoeCommFieldInfo(recvFieldInfo.fieldsInfo[i], outputs[i]); + } + sendFieldInfo.fillFieldPlacementInfo(0, false); + recvFieldInfo.fillFieldPlacementInfo(0, false); - tensorrt_llm::kernels::MoeEpWorldInfo worldInfo = {static_cast(epSize), static_cast(epRank)}; - tensorrt_llm::kernels::SendRecvDataInfo sendRecvDataInfo; - - size_t eltSize = input.dtype().itemsize(); - size_t eltCountPerU64 = sizeof(uint64_t) / eltSize; - TORCH_CHECK(input.size(1) % (eltCountPerU64 * 2) == 0, "input.size(1) must be aligned to 16 bytes"); - sendRecvDataInfo.vectorSizeInU64 = input.size(1) / eltCountPerU64; - sendRecvDataInfo.DoPreCompute(); - - tensorrt_llm::kernels::SendRecvDispls sendDispls, recvDispls; - sendDispls.dataPtr = static_cast(input.data_ptr()); - sendDispls.rankCountCumSum = sendRankCumSum.data_ptr(); - sendDispls.rankLocalIndices = sendIndices.data_ptr(); - sendDispls.vectorStrideInU64 = input.stride(0) / eltCountPerU64; + tensorrt_llm::kernels::FusedMoeCommKernelParam params; + params.worldInfo = worldInfo; + params.sendIndices = sendIndices; + params.recvIndices = recvIndices; + params.sendFieldInfo = sendFieldInfo; + params.recvFieldInfo = recvFieldInfo; + // Do not need expertParallelInfo for fused moe comm now - recvDispls.dataPtr = static_cast(output.data_ptr()); - recvDispls.rankCountCumSum = recvRankCumSum.data_ptr(); - recvDispls.rankLocalIndices = recvIndices.data_ptr(); - recvDispls.vectorStrideInU64 = output.stride(0) / eltCountPerU64; + params.sendFieldInfo.fillMetaInfo(&(params.sendCommMeta), params.expertParallelInfo.topK, false, false); + params.recvFieldInfo.fillMetaInfo(&(params.recvCommMeta), params.expertParallelInfo.topK, false, false); - tensorrt_llm::kernels::MoeCommWorkspace workspace; - workspace.workspacePtr = allWorkspaces.data_ptr(); - workspace.rankStrideInU64 = allWorkspaces.stride(0); + tensorrt_llm::kernels::FusedMoeWorkspace fusedMoeWorkspace; + tensorrt_llm::kernels::constructWorkspace( + &fusedMoeWorkspace, allWorkspaces.data_ptr(), allWorkspaces.stride(0), epSize); auto stream = at::cuda::getCurrentCUDAStream(); - tensorrt_llm::kernels::moeAllToAll(worldInfo, sendRecvDataInfo, sendDispls, recvDispls, workspace, stream); + tensorrt_llm::kernels::moeAllToAll(params, fusedMoeWorkspace, stream); + + return outputs; } int64_t getWorkspaceSizePerRank(int64_t epSize) { int epSize32 = static_cast(epSize); - return tensorrt_llm::kernels::getMoeCommWorkspaceSize(epSize32); + return tensorrt_llm::kernels::getFusedMoeCommWorkspaceSize(epSize32); } void setMaxUsableSmCount(int64_t maxSmCount) @@ -215,15 +139,29 @@ int64_t getPrepareWorkspaceSizePerRank(int64_t epSize) return tensorrt_llm::kernels::moe_prepare::getMoePrepareWorkspaceSize(epSize32); } -std::tuple, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, - torch::Tensor, c10::optional> -moePrepareOp(torch::Tensor expertsIds, c10::optional scales, c10::optional expertsStatics, - torch::Tensor allWorkspaces, int64_t maxTokenCountPerRank, int64_t epRank, int64_t epSize, int64_t expertCount, - int64_t slotCount, int64_t topK) +void initializeMoeWorkspace(torch::Tensor allWorkspaces, int64_t epRank, int64_t epSize) +{ + TORCH_CHECK(allWorkspaces.dim() == 2, "allWorkspaces must be a 2D tensor"); + TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)"); + + tensorrt_llm::kernels::MoeEpWorldInfo epWorldInfo = {static_cast(epSize), static_cast(epRank)}; + tensorrt_llm::kernels::FusedMoeWorldInfo worldInfo = {epWorldInfo}; + + tensorrt_llm::kernels::FusedMoeWorkspace 
fusedMoeWorkspace; + tensorrt_llm::kernels::constructWorkspace( + &fusedMoeWorkspace, allWorkspaces.data_ptr(), allWorkspaces.stride(0), epSize); + + tensorrt_llm::kernels::initializeFusedMoeLocalWorkspace(&fusedMoeWorkspace, worldInfo); +} + +std::tuple> +moePrepareOp(torch::Tensor expertsIds, c10::optional expertsStatics, torch::Tensor allWorkspaces, + int64_t maxTokenCountPerRank, int64_t epRank, int64_t epSize, int64_t expertCount, int64_t slotCount, int64_t topK) { CHECK_INPUT(expertsIds, torch::kInt32); TORCH_CHECK(expertCount % 4 == 0, "expertCount must be divisible by 4"); TORCH_CHECK(slotCount % 4 == 0, "slotCount must be divisible by 4"); + TORCH_CHECK(expertCount + 1 <= 512, "expertCount + 1 is larger than 512"); int64_t maxSendRanksPerToken = std::max(epSize, topK); int64_t tokenCount = expertsIds.size(0); @@ -249,18 +187,6 @@ moePrepareOp(torch::Tensor expertsIds, c10::optional scales, c10: torch::Tensor sendRankIndices = torch::empty({maxTokenCountPerRank * maxSendRanksPerToken}, expertsIds.options().dtype(torch::kInt32)); - c10::optional preparedLocalScales; - float* scalesPtr = nullptr; - float* preparedLocalScalesPtr = nullptr; - if (scales.has_value()) - { - CHECK_INPUT(scales.value(), torch::kFloat32); - scalesPtr = scales->data_ptr(); - preparedLocalScales - = torch::empty({maxTokenCountPerRank * epSize, topK}, expertsIds.options().dtype(torch::kFloat32)); - preparedLocalScalesPtr = preparedLocalScales->data_ptr(); - } - int* localExpertStaticsPtr = nullptr; int* gatheredExpertStaticsPtr = nullptr; c10::optional gatheredExpertStatics; @@ -279,8 +205,9 @@ moePrepareOp(torch::Tensor expertsIds, c10::optional scales, c10: tensorrt_llm::kernels::moe_prepare::computeCountAndIndice(expertsIds.data_ptr(), sendRankCountCumSum.data_ptr(), RecvRankCountCumSum.data_ptr(), sendRankIndices.data_ptr(), - backwardRecvRankIndices.data_ptr(), recvRankIndices.data_ptr(), workspace, tokenCount, - maxTokenCountPerRank, topK, slotCount, epRank, epSize, stream); + backwardRecvRankIndices.data_ptr(), recvRankIndices.data_ptr(), localExpertStaticsPtr, + gatheredExpertStaticsPtr, workspace, tokenCount, maxTokenCountPerRank, topK, slotCount, expertCount, epRank, + epSize, stream); tensorrt_llm::kernels::moe_prepare::computeCumsum( sendRankCountCumSum.data_ptr(), RecvRankCountCumSum.data_ptr(), epRank, epSize, stream); @@ -291,54 +218,53 @@ moePrepareOp(torch::Tensor expertsIds, c10::optional scales, c10: recvRankIndices.data_ptr(), gatherRecvRankIndices.data_ptr(), epRank, epSize, maxTokenCountPerRank, stream); - tensorrt_llm::kernels::moe_prepare::allToAllMetadata(expertsIds.data_ptr(), - preparedLocalExpertIds.data_ptr(), scalesPtr, preparedLocalScalesPtr, localExpertStaticsPtr, - gatheredExpertStaticsPtr, workspace, sendRankCountCumSum.data_ptr(), sendRankIndices.data_ptr(), - RecvRankCountCumSum.data_ptr(), recvRankIndices.data_ptr(), tokenCount, maxTokenCountPerRank, topK, - expertCount, slotCount, epRank, epSize, stream); - - return std::make_tuple(preparedLocalExpertIds, preparedLocalScales, sendRankCountCumSum, gatherSendRankIndices, - RecvRankCountCumSum, gatherRecvRankIndices, gatherBackwardRecvRankIndices, gatheredExpertStatics); + return std::make_tuple(sendRankCountCumSum, gatherSendRankIndices, RecvRankCountCumSum, gatherRecvRankIndices, + gatherBackwardRecvRankIndices, gatheredExpertStatics); } -} // namespace torch_ext - -TORCH_LIBRARY_FRAGMENT(trtllm, m) +void memsetExpertIds(torch::Tensor expertsIds, torch::Tensor recvRankCountCumSum, int64_t maxTokenCountPerRank, + int64_t 
topK, int64_t slotCount, int64_t epSize) { - m.def( - "moe_comm_prepare_indices(Tensor gathered_target_rank_ids, Tensor? real_rank_token_count_cum_sum, int " - "max_token_count_per_rank, int expert_count, int top_k, int ep_rank, int ep_size) -> (Tensor, Tensor, Tensor, " - "Tensor, Tensor, Tensor)"); -} + CHECK_INPUT(expertsIds, torch::kInt32); + TORCH_CHECK(expertsIds.dim() == 2, "expertsIds must be a 1D tensor"); + TORCH_CHECK( + expertsIds.size(0) == maxTokenCountPerRank * epSize, "expertsIds must have maxTokenCountPerRank * epSize rows"); + TORCH_CHECK(expertsIds.size(1) == topK, "expertsIds must have topK columns"); -TORCH_LIBRARY_IMPL(trtllm, CUDA, m) -{ - m.impl("moe_comm_prepare_indices", &torch_ext::moeCommPrepareIndicesOp); + CHECK_INPUT(recvRankCountCumSum, torch::kInt32); + TORCH_CHECK(recvRankCountCumSum.dim() == 1, "recvRankCountCumSum must be a 1D tensor"); + TORCH_CHECK(recvRankCountCumSum.size(0) == epSize, "recvRankCountCumSum must have epSize elements"); + + auto stream = at::cuda::getCurrentCUDAStream(); + + tensorrt_llm::kernels::moe_prepare::memsetExpertIds(expertsIds.data_ptr(), recvRankCountCumSum.data_ptr(), + static_cast(maxTokenCountPerRank), static_cast(topK), static_cast(slotCount), + static_cast(epSize), stream); } +} // namespace torch_ext + TORCH_LIBRARY_FRAGMENT(trtllm, m) { m.def( - "moe_local_gather(Tensor recv_rank_cum_sum, Tensor local_gather_indices, Tensor gathered_expert_ids, Tensor? " - "gathered_scales, Tensor local_expert_ids, Tensor? local_scales, int max_token_count_per_rank, int " - "expert_count, int top_k, int ep_rank, int ep_size) -> ()"); + "moe_comm(Tensor[] inputs, Tensor send_rank_cum_sum, Tensor send_indices, Tensor " + "recv_rank_cum_sum, Tensor recv_indices, Tensor all_workspaces, int output_allocation_count, int ep_rank, int " + "ep_size, bool[]? need_zero_output=None) -> Tensor[]"); } TORCH_LIBRARY_IMPL(trtllm, CUDA, m) { - m.impl("moe_local_gather", &torch_ext::moeLocalGatherOp); + m.impl("moe_comm", &torch_ext::moeCommOp); } TORCH_LIBRARY_FRAGMENT(trtllm, m) { - m.def( - "moe_comm(Tensor input, Tensor send_rank_cum_sum, Tensor send_indices, Tensor output, Tensor " - "recv_rank_cum_sum, Tensor recv_indices, Tensor all_workspaces, int ep_rank, int ep_size) -> ()"); + m.def("moe_initialize_workspace(Tensor(a!) all_workspaces, int ep_rank, int ep_size) -> ()"); } TORCH_LIBRARY_IMPL(trtllm, CUDA, m) { - m.impl("moe_comm", &torch_ext::moeCommOp); + m.impl("moe_initialize_workspace", &torch_ext::initializeMoeWorkspace); } TORCH_LIBRARY_FRAGMENT(trtllm, m) @@ -364,9 +290,9 @@ TORCH_LIBRARY_IMPL(trtllm, CompositeExplicitAutograd, m) TORCH_LIBRARY_FRAGMENT(trtllm, m) { m.def( - "mnnvl_moe_alltoallv_prepare_without_allgather(Tensor experts_ids, Tensor? scales, Tensor? experts_statics, " + "mnnvl_moe_alltoallv_prepare_without_allgather(Tensor experts_ids, Tensor? experts_statics, " "Tensor allWorkspace, int max_token_count_per_rank, int ep_rank, int ep_size, int expert_count, int " - "slot_count, int top_k) -> (Tensor, Tensor?, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor?)"); + "slot_count, int top_k) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor?)"); } TORCH_LIBRARY_IMPL(trtllm, CUDA, m) @@ -374,6 +300,19 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m) m.impl("mnnvl_moe_alltoallv_prepare_without_allgather", &torch_ext::moePrepareOp); } +TORCH_LIBRARY_FRAGMENT(trtllm, m) +{ + m.def( + "memset_expert_ids(Tensor(a!) 
experts_ids, Tensor recv_rank_count_cumsum, int max_token_count_per_rank, int " + "top_k, " + "int slot_count, int ep_size) -> ()"); +} + +TORCH_LIBRARY_IMPL(trtllm, CUDA, m) +{ + m.impl("memset_expert_ids", &torch_ext::memsetExpertIds); +} + TORCH_LIBRARY_FRAGMENT(trtllm, m) { m.def("get_moe_prepare_workspace_size_per_rank(int ep_size) -> int"); diff --git a/cpp/tensorrt_llm/thop/moeLoadBalanceOp.cpp b/cpp/tensorrt_llm/thop/moeLoadBalanceOp.cpp index 8724a860495..4cc7bbd4b3e 100644 --- a/cpp/tensorrt_llm/thop/moeLoadBalanceOp.cpp +++ b/cpp/tensorrt_llm/thop/moeLoadBalanceOp.cpp @@ -16,7 +16,6 @@ */ #include "tensorrt_llm/common/opUtils.h" -#include "tensorrt_llm/kernels/moeCommKernels.h" #include "tensorrt_llm/runtime/torchUtils.h" #include "tensorrt_llm/thop/thUtils.h" diff --git a/cpp/tests/unit_tests/kernels/CMakeLists.txt b/cpp/tests/unit_tests/kernels/CMakeLists.txt index ae3750597ea..e9c6093942b 100644 --- a/cpp/tests/unit_tests/kernels/CMakeLists.txt +++ b/cpp/tests/unit_tests/kernels/CMakeLists.txt @@ -42,6 +42,8 @@ add_gtest(cudaCoreGemmKernelTest cudaCoreGemm/cudaCoreGemmKernelTest.cpp) add_gtest(mlaChunkedPrefillTest mlaChunkedPrefillTest.cu) +add_gtest(fusedMoeCommKernelTest fusedMoeCommKernelTest.cpp) + if(NOT ENABLE_MULTI_DEVICE EQUAL 0) add_gtest(allReduceKernelTest allReduce/allReduceKernelTest.cu) add_gtest(allReduceFusionTest allReduce/allReduceFusionTest.cu) diff --git a/cpp/tests/unit_tests/kernels/fusedMoeCommKernelTest.cpp b/cpp/tests/unit_tests/kernels/fusedMoeCommKernelTest.cpp new file mode 100644 index 00000000000..2091d3656e2 --- /dev/null +++ b/cpp/tests/unit_tests/kernels/fusedMoeCommKernelTest.cpp @@ -0,0 +1,1410 @@ +/* + * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/kernels/fusedMoeCommKernels.h" + +using namespace tensorrt_llm::kernels; + +class FusedMoeCommTestBase : public ::testing::Test +{ +protected: + static bool shouldSkip() + { + int deviceCount = tensorrt_llm::common::getDeviceCount(); + if (deviceCount <= 0) + { + return true; + } + int sm = tensorrt_llm::common::getSMVersion(); + if (sm < 90) + { + return true; + } + return false; + } + + void SetUp() override + { + if (shouldSkip()) + { + skipped = true; + GTEST_SKIP() << "Skipping due to no/unsupported GPU"; + } + TLLM_CUDA_CHECK(cudaStreamCreate(&stream)); + std::srand(42); // Initialize random seed + } + + void TearDown() override + { + if (!skipped) + { + TLLM_CUDA_CHECK(cudaStreamDestroy(stream)); + } + } + + bool skipped = false; + cudaStream_t stream = nullptr; + + // Helper function to allocate and initialize test data + template + void allocateAndInitializeData( + T** hostPtr, T** devicePtr, size_t count, std::function generator = nullptr) + { + *hostPtr = new T[count]; + TLLM_CUDA_CHECK(cudaMalloc(devicePtr, count * sizeof(T))); + + if (generator) + { + for (size_t i = 0; i < count; i++) + { + (*hostPtr)[i] = generator(i); + } + } + else + { + // Default initialization with random values + for (size_t i = 0; i < count; i++) + { + if constexpr (std::is_same_v) + { + (*hostPtr)[i] = static_cast(rand()) / RAND_MAX * 10.0f; + } + else if constexpr (std::is_same_v) + { + (*hostPtr)[i] = rand() % 1000; + } + else + { + (*hostPtr)[i] = static_cast(rand() % 100); + } + } + } + + TLLM_CUDA_CHECK(cudaMemcpy(*devicePtr, *hostPtr, count * sizeof(T), cudaMemcpyHostToDevice)); + } + + void cleanup(void* hostPtr, void* devicePtr) + { + delete[] static_cast(hostPtr); + TLLM_CUDA_CHECK(cudaFree(devicePtr)); + } + + // Generate a one-to-one mapping, extending with random permutation if needed + std::vector generateOneToOneMapping(std::vector const& partialMapping, int totalSize) + { + std::vector fullMapping(totalSize); + std::vector used(totalSize, false); + + // First, copy the provided mapping and mark used indices + int providedSize = static_cast(partialMapping.size()); + for (int i = 0; i < std::min(providedSize, totalSize); i++) + { + int target = partialMapping[i]; + if (target >= 0 && target < totalSize && !used[target]) + { + fullMapping[i] = target; + used[target] = true; + } + else + { + // Invalid mapping, will be handled later + fullMapping[i] = -1; + } + } + + // Collect unused indices + std::vector unusedIndices; + for (int i = 0; i < totalSize; i++) + { + if (!used[i]) + { + unusedIndices.push_back(i); + } + } + + // Shuffle unused indices for random assignment + std::srand(42); // Fixed seed for reproducible tests + std::random_shuffle(unusedIndices.begin(), unusedIndices.end()); + + // Fill in any invalid mappings and extend with remaining unused indices + int unusedIdx = 0; + for (int i = 0; i < totalSize; i++) + { + if (i < providedSize && fullMapping[i] == -1) + { + // Fix invalid mapping + if (unusedIdx < unusedIndices.size()) + { + fullMapping[i] = unusedIndices[unusedIdx++]; + } + } + else if (i >= providedSize) + { + // Extend mapping + if (unusedIdx < unusedIndices.size()) + { + fullMapping[i] = unusedIndices[unusedIdx++]; + } + else + { + // Fallback: identity mapping for remaining + fullMapping[i] = i; + } + } + } + + return fullMapping; + } +}; + +// Test class for launchSingleG2S function +class 
FusedMoeCommG2STest : public FusedMoeCommTestBase +{ +protected: + void runG2STest(int topK, bool hasScales, bool hasBasicFields, int sendFieldCount, + std::vector const& elementSizes, std::vector const& vectorSizes, int tokenCount = 4, + int warpsPerBlock = 2) + { + // Setup expert parallel info + MoeExpertParallelInfo expertParallelInfo; + expertParallelInfo.topK = topK; + expertParallelInfo.expertCount = 8; + + // Setup send field info + FusedMoeFieldInfo sendFieldInfo = {}; + sendFieldInfo.isBasicInterleaved = false; + sendFieldInfo.fieldCount = sendFieldCount; + + // Allocate token selected slots and expert scales if needed + int* hostTokenSlots = nullptr; + int* deviceTokenSlots = nullptr; + float* hostScales = nullptr; + float* deviceScales = nullptr; + + if (hasBasicFields) + { + allocateAndInitializeData(&hostTokenSlots, &deviceTokenSlots, tokenCount * topK, + [](size_t i) { return static_cast(i % 8); }); + sendFieldInfo.tokenSelectedSlots = deviceTokenSlots; + + if (hasScales) + { + allocateAndInitializeData(&hostScales, &deviceScales, tokenCount * topK, + [](size_t i) -> float { return 1.0f + static_cast(i) * 0.1f; }); + sendFieldInfo.expertScales = deviceScales; + } + } + + // Setup send field info using new fillFieldInfo helper + std::vector hostFieldPtrs(sendFieldCount); + std::vector deviceFieldPtrs(sendFieldCount); + + for (int i = 0; i < sendFieldCount; i++) + { + size_t elementSize = elementSizes[i % elementSizes.size()]; + uint16_t vectorSize = vectorSizes[i % vectorSizes.size()]; + size_t fieldSize = elementSize * vectorSize * tokenCount; + + // Allocate field data + uint8_t* hostField; + uint8_t* deviceField; + allocateAndInitializeData(&hostField, &deviceField, fieldSize, + [i](size_t idx) { return static_cast((i * 100 + idx) % 128); }); + + hostFieldPtrs[i] = hostField; + deviceFieldPtrs[i] = deviceField; + + // Use the new fillFieldInfo helper function + sendFieldInfo.fieldsInfo[i].fillFieldInfo(deviceField, elementSize, vectorSize, vectorSize); + } + + // Fill field placement info + sendFieldInfo.fillFieldPlacementInfo(topK, hasBasicFields); + + // Compute shared memory size and allocate output buffer + int warpShmSize = sendFieldInfo.computeSingleUncompactSize(topK, hasScales, hasBasicFields); + size_t shmDumpSize = tokenCount * warpShmSize; + size_t shmDumpIntCount = shmDumpSize / sizeof(int); + + int* hostShmDump; + int* deviceShmDump; + allocateAndInitializeData(&hostShmDump, &deviceShmDump, shmDumpIntCount, [](size_t) { return 0; }); + + // Launch G2S kernel with new signature + fused_moe_comm_tests::launchSingleG2S( + sendFieldInfo, expertParallelInfo, tokenCount, deviceShmDump, warpsPerBlock, hasBasicFields, stream); + + TLLM_CUDA_CHECK(cudaStreamSynchronize(stream)); + + // Copy back results + int* resultShmDump = new int[shmDumpIntCount]; + TLLM_CUDA_CHECK( + cudaMemcpy(resultShmDump, deviceShmDump, shmDumpIntCount * sizeof(int), cudaMemcpyDeviceToHost)); + + // Verify results + verifyG2SResults(resultShmDump, hostTokenSlots, hostScales, hostFieldPtrs, topK, hasScales, hasBasicFields, + sendFieldCount, elementSizes, vectorSizes, tokenCount, warpsPerBlock, warpShmSize); + + // Cleanup + if (hasBasicFields) + { + cleanup(hostTokenSlots, deviceTokenSlots); + if (hasScales) + { + cleanup(hostScales, deviceScales); + } + } + for (int i = 0; i < sendFieldCount; i++) + { + cleanup(hostFieldPtrs[i], deviceFieldPtrs[i]); + } + cleanup(hostShmDump, deviceShmDump); + delete[] resultShmDump; + } + +private: + void verifyG2SResults(int const* shmDump, int const* 
expectedTokenSlots, float const* expectedScales, + std::vector const& expectedFields, int topK, bool hasScales, bool hasBasicFields, int sendFieldCount, + std::vector const& elementSizes, std::vector const& vectorSizes, int tokenCount, + int warpsPerBlock, int warpShmSize) + { + for (int tokenId = 0; tokenId < tokenCount; tokenId++) + { + int const* warpShmData = shmDump + tokenId * warpShmSize / sizeof(int); + + // Verify token slots and scales only if hasBasicFields is true + if (hasBasicFields) + { + // Verify token slots + if (expectedTokenSlots) + { + for (int k = 0; k < topK; k++) + { + int expected = expectedTokenSlots[tokenId * topK + k]; + int actual = warpShmData[k]; + EXPECT_EQ(expected, actual) << "Token slot mismatch at warp=" << tokenId << ", k=" << k; + } + } + + // Verify scales if present + if (hasScales && expectedScales) + { + for (int k = 0; k < topK; k++) + { + float expected = expectedScales[tokenId * topK + k]; + float actual = reinterpret_cast(warpShmData)[topK + k]; + EXPECT_NEAR(expected, actual, 1e-6f) << "Scale mismatch at warp=" << tokenId << ", k=" << k; + } + } + } + + // Additional field verification can be added here if needed + // For now, we just verify that the operation completed successfully + } + } +}; + +// Test class for launchSingleS2G function +class FusedMoeCommS2GTest : public FusedMoeCommTestBase +{ +protected: + void runS2GTest(int topK, bool hasScales, bool hasBasicFields, int recvFieldCount, + std::vector const& elementSizes, std::vector const& vectorSizes, int tokenCount = 4, + int warpsPerBlock = 2) + { + // Setup expert parallel info + MoeExpertParallelInfo expertParallelInfo; + expertParallelInfo.topK = topK; + expertParallelInfo.expertCount = 8; + + // Setup recv field info + FusedMoeFieldInfo recvFieldInfo = {}; + recvFieldInfo.isBasicInterleaved = false; + recvFieldInfo.fieldCount = recvFieldCount; + + // Allocate token selected slots and expert scales if needed + int* hostTokenSlots = nullptr; + int* deviceTokenSlots = nullptr; + float* hostScales = nullptr; + float* deviceScales = nullptr; + + if (hasBasicFields) + { + allocateAndInitializeData(&hostTokenSlots, &deviceTokenSlots, tokenCount * topK, + [](size_t) { return 0; }); // Initialize to zero, will be filled by S2G + recvFieldInfo.tokenSelectedSlots = deviceTokenSlots; + + if (hasScales) + { + allocateAndInitializeData(&hostScales, &deviceScales, tokenCount * topK, + [](size_t) { return 0.0f; }); // Initialize to zero, will be filled by S2G + recvFieldInfo.expertScales = deviceScales; + } + } + + // Setup recv field info using new fillFieldInfo helper + std::vector hostFieldPtrs(recvFieldCount); + std::vector deviceFieldPtrs(recvFieldCount); + + for (int i = 0; i < recvFieldCount; i++) + { + size_t elementSize = elementSizes[i % elementSizes.size()]; + uint16_t vectorSize = vectorSizes[i % vectorSizes.size()]; + size_t fieldSize = elementSize * vectorSize * tokenCount; + + // Allocate field data (initialize to zero, will be filled by S2G) + uint8_t* hostField; + uint8_t* deviceField; + allocateAndInitializeData( + &hostField, &deviceField, fieldSize, [](size_t) { return static_cast(0); }); + + hostFieldPtrs[i] = hostField; + deviceFieldPtrs[i] = deviceField; + + // Use the new fillFieldInfo helper function + recvFieldInfo.fieldsInfo[i].fillFieldInfo(deviceField, elementSize, vectorSize, vectorSize); + } + + // Fill field placement info + recvFieldInfo.fillFieldPlacementInfo(topK, hasBasicFields); + + // Compute shared memory size and prepare input data + int warpShmSize = 
recvFieldInfo.computeSingleUncompactSize(topK, hasScales, hasBasicFields); + size_t shmPreloadSize = tokenCount * warpShmSize; + size_t shmPreloadIntCount = shmPreloadSize / sizeof(int); + + int* hostShmPreload; + int* deviceShmPreload; + allocateAndInitializeData(&hostShmPreload, &deviceShmPreload, shmPreloadIntCount, + [this, topK, hasScales, hasBasicFields, shmPreloadIntCount](size_t idx) + { return this->generateShmPreloadData(idx, topK, hasScales, hasBasicFields, shmPreloadIntCount); }); + + // Launch S2G kernel with new signature + fused_moe_comm_tests::launchSingleS2G( + recvFieldInfo, expertParallelInfo, tokenCount, deviceShmPreload, warpsPerBlock, hasBasicFields, stream); + + TLLM_CUDA_CHECK(cudaStreamSynchronize(stream)); + + // Copy back results only if hasBasicFields + int* resultTokenSlots = nullptr; + float* resultScales = nullptr; + + if (hasBasicFields) + { + resultTokenSlots = new int[tokenCount * topK]; + TLLM_CUDA_CHECK(cudaMemcpy( + resultTokenSlots, deviceTokenSlots, tokenCount * topK * sizeof(int), cudaMemcpyDeviceToHost)); + + if (hasScales) + { + resultScales = new float[tokenCount * topK]; + TLLM_CUDA_CHECK( + cudaMemcpy(resultScales, deviceScales, tokenCount * topK * sizeof(float), cudaMemcpyDeviceToHost)); + } + } + + // Verify results + verifyS2GResults(resultTokenSlots, resultScales, hostShmPreload, topK, hasScales, hasBasicFields, tokenCount, + warpsPerBlock, warpShmSize); + + // Cleanup + if (hasBasicFields) + { + cleanup(hostTokenSlots, deviceTokenSlots); + if (hasScales) + { + cleanup(hostScales, deviceScales); + } + } + for (int i = 0; i < recvFieldCount; i++) + { + cleanup(hostFieldPtrs[i], deviceFieldPtrs[i]); + } + cleanup(hostShmPreload, deviceShmPreload); + if (resultTokenSlots) + { + delete[] resultTokenSlots; + } + if (resultScales) + { + delete[] resultScales; + } + } + +private: + int generateShmPreloadData(size_t idx, int topK, bool hasScales, bool hasBasicFields, int shmPreloadIntCount) + { + size_t warpIdx = idx / shmPreloadIntCount; + size_t offsetInWarp = idx % shmPreloadIntCount; + + if (hasBasicFields) + { + if (offsetInWarp < topK) + { + // Token slots area + return static_cast(warpIdx * 10 + offsetInWarp); + } + else if (hasScales && offsetInWarp < topK * 2) + { + // Scales area + float scale + = 1.0f + static_cast(warpIdx) * 0.1f + static_cast(offsetInWarp - topK) * 0.01f; + return *reinterpret_cast(&scale); + } + else + { + // Other field data + return static_cast((warpIdx * 1000 + offsetInWarp) % 128); + } + } + else + { + // Only field data when no basic fields + return static_cast((warpIdx * 1000 + offsetInWarp) % 128); + } + } + + void verifyS2GResults(int const* resultTokenSlots, float const* resultScales, int const* shmPreloadData, int topK, + bool hasScales, bool hasBasicFields, int tokenCount, int warpsPerBlock, int warpShmSize) + { + if (!hasBasicFields) + { + // For non-basic fields tests, just verify that the operation completed successfully + // without errors. The actual field data verification would require more complex setup. 
+ return; + } + + for (int tokenId = 0; tokenId < tokenCount; tokenId++) + { + int const* warpShmData = shmPreloadData + tokenId * warpShmSize / sizeof(int); + + // Verify token slots were written correctly + if (resultTokenSlots) + { + for (int k = 0; k < topK; k++) + { + int expected = warpShmData[k]; + int actual = resultTokenSlots[tokenId * topK + k]; + EXPECT_EQ(expected, actual) << "Token slot mismatch at warp=" << tokenId << ", k=" << k; + } + } + + // Verify scales if present + if (hasScales && resultScales) + { + for (int k = 0; k < topK; k++) + { + float expected = reinterpret_cast(warpShmData)[topK + k]; + float actual = resultScales[tokenId * topK + k]; + EXPECT_NEAR(expected, actual, 1e-6f) << "Scale mismatch at warp=" << tokenId << ", k=" << k; + } + } + } + } +}; + +// Test class for launchLoopback function (loopback test) +class FusedMoeCommLoopbackTest : public FusedMoeCommTestBase +{ +protected: + void runLoopbackTest(int topK, bool hasScales, bool hasBasicFields, int fieldCount, + std::vector const& elementSizes, std::vector const& vectorSizes, + std::vector const& recvIndexMappingVec, int tokenCount = 4, int warpsPerBlock = 2) + { + // Setup expert parallel info + MoeExpertParallelInfo expertParallelInfo; + expertParallelInfo.topK = topK; + expertParallelInfo.expertCount = 8; + + // Setup field info - for loopback test, send and recv fields should be identical + FusedMoeFieldInfo sendFieldInfo = {}; + sendFieldInfo.isBasicInterleaved = false; + sendFieldInfo.fieldCount = fieldCount; + + FusedMoeFieldInfo recvFieldInfo = {}; + recvFieldInfo.isBasicInterleaved = false; + recvFieldInfo.fieldCount = fieldCount; + + // Allocate token selected slots and expert scales if needed + int* hostSendTokenSlots = nullptr; + int* deviceSendTokenSlots = nullptr; + float* hostSendScales = nullptr; + float* deviceSendScales = nullptr; + int* hostRecvTokenSlots = nullptr; + int* deviceRecvTokenSlots = nullptr; + float* hostRecvScales = nullptr; + float* deviceRecvScales = nullptr; + + if (hasBasicFields) + { + // Send side basic fields + allocateAndInitializeData(&hostSendTokenSlots, &deviceSendTokenSlots, tokenCount * topK, + [](size_t i) { return static_cast(i % 8); }); + sendFieldInfo.tokenSelectedSlots = deviceSendTokenSlots; + + // Recv side basic fields (initialized to zero, will be filled by loopback) + allocateAndInitializeData( + &hostRecvTokenSlots, &deviceRecvTokenSlots, tokenCount * topK, [](size_t) { return 0; }); + recvFieldInfo.tokenSelectedSlots = deviceRecvTokenSlots; + + if (hasScales) + { + allocateAndInitializeData(&hostSendScales, &deviceSendScales, tokenCount * topK, + [](size_t i) -> float { return 1.0f + static_cast(i) * 0.1f; }); + sendFieldInfo.expertScales = deviceSendScales; + + allocateAndInitializeData( + &hostRecvScales, &deviceRecvScales, tokenCount * topK, [](size_t) { return 0.0f; }); + recvFieldInfo.expertScales = deviceRecvScales; + } + } + + // Setup field info - both send and recv use same layout for loopback + std::vector hostSendFieldPtrs(fieldCount); + std::vector deviceSendFieldPtrs(fieldCount); + std::vector hostRecvFieldPtrs(fieldCount); + std::vector deviceRecvFieldPtrs(fieldCount); + + for (int i = 0; i < fieldCount; i++) + { + size_t elementSize = elementSizes[i % elementSizes.size()]; + uint16_t vectorSize = vectorSizes[i % vectorSizes.size()]; + size_t fieldSize = elementSize * vectorSize * tokenCount; + + // Allocate send field data with specific pattern + uint8_t* hostSendField; + uint8_t* deviceSendField; + 
allocateAndInitializeData(&hostSendField, &deviceSendField, fieldSize, + [i](size_t idx) { return static_cast((i * 100 + idx + 1) % 128); }); + + // Allocate recv field data (initially zero, will be filled by loopback) + uint8_t* hostRecvField; + uint8_t* deviceRecvField; + allocateAndInitializeData( + &hostRecvField, &deviceRecvField, fieldSize, [](size_t) { return static_cast(0); }); + + hostSendFieldPtrs[i] = hostSendField; + deviceSendFieldPtrs[i] = deviceSendField; + hostRecvFieldPtrs[i] = hostRecvField; + deviceRecvFieldPtrs[i] = deviceRecvField; + + // Fill field info for both send and recv + sendFieldInfo.fieldsInfo[i].fillFieldInfo(deviceSendField, elementSize, vectorSize, vectorSize); + recvFieldInfo.fieldsInfo[i].fillFieldInfo(deviceRecvField, elementSize, vectorSize, vectorSize); + } + + // Fill field placement info + sendFieldInfo.fillFieldPlacementInfo(topK, hasBasicFields); + recvFieldInfo.fillFieldPlacementInfo(topK, hasBasicFields); + + // Setup recvIndexMapping - ensure one-to-one mapping + std::vector fullMapping = generateOneToOneMapping(recvIndexMappingVec, tokenCount); + int* hostRecvIndexMapping; + int* deviceRecvIndexMapping; + allocateAndInitializeData(&hostRecvIndexMapping, &deviceRecvIndexMapping, tokenCount, + [&fullMapping](size_t i) { return fullMapping[i]; }); + + // Launch loopback kernel + fused_moe_comm_tests::launchLoopback(sendFieldInfo, recvFieldInfo, expertParallelInfo, deviceRecvIndexMapping, + tokenCount, warpsPerBlock, hasBasicFields, stream); + + TLLM_CUDA_CHECK(cudaStreamSynchronize(stream)); + + // Copy back results and verify + verifyLoopbackResults(hostSendTokenSlots, hostSendScales, hostSendFieldPtrs, hostRecvFieldPtrs, + deviceRecvTokenSlots, deviceRecvScales, deviceRecvFieldPtrs, fullMapping, topK, hasScales, hasBasicFields, + fieldCount, elementSizes, vectorSizes, tokenCount); + + // Cleanup + if (hasBasicFields) + { + cleanup(hostSendTokenSlots, deviceSendTokenSlots); + cleanup(hostRecvTokenSlots, deviceRecvTokenSlots); + if (hasScales) + { + cleanup(hostSendScales, deviceSendScales); + cleanup(hostRecvScales, deviceRecvScales); + } + } + for (int i = 0; i < fieldCount; i++) + { + cleanup(hostSendFieldPtrs[i], deviceSendFieldPtrs[i]); + cleanup(hostRecvFieldPtrs[i], deviceRecvFieldPtrs[i]); + } + cleanup(hostRecvIndexMapping, deviceRecvIndexMapping); + } + +private: + void verifyLoopbackResults(int const* expectedSendTokenSlots, float const* expectedSendScales, + std::vector const& expectedSendFields, std::vector const& hostRecvFields, + int* deviceRecvTokenSlots, float* deviceRecvScales, std::vector const& deviceRecvFields, + std::vector const& fullMapping, int topK, bool hasScales, bool hasBasicFields, int fieldCount, + std::vector const& elementSizes, std::vector const& vectorSizes, int tokenCount) + { + // Copy back device results for verification + int* resultRecvTokenSlots = nullptr; + float* resultRecvScales = nullptr; + + if (hasBasicFields) + { + resultRecvTokenSlots = new int[tokenCount * topK]; + TLLM_CUDA_CHECK(cudaMemcpy( + resultRecvTokenSlots, deviceRecvTokenSlots, tokenCount * topK * sizeof(int), cudaMemcpyDeviceToHost)); + + if (hasScales) + { + resultRecvScales = new float[tokenCount * topK]; + TLLM_CUDA_CHECK(cudaMemcpy( + resultRecvScales, deviceRecvScales, tokenCount * topK * sizeof(float), cudaMemcpyDeviceToHost)); + } + } + + // Copy back field data + std::vector resultRecvFields(fieldCount); + for (int i = 0; i < fieldCount; i++) + { + size_t elementSize = elementSizes[i % elementSizes.size()]; + uint16_t vectorSize 
= vectorSizes[i % vectorSizes.size()]; + size_t fieldSize = elementSize * vectorSize * tokenCount; + + resultRecvFields[i] = new uint8_t[fieldSize]; + TLLM_CUDA_CHECK(cudaMemcpy(resultRecvFields[i], deviceRecvFields[i], fieldSize, cudaMemcpyDeviceToHost)); + } + + // Verify the loopback: recv[fullMapping[sendIndex]] should equal send[sendIndex] + int tokenSlotErrorCount = 0; + int scaleErrorCount = 0; + std::vector fieldErrorCounts(fieldCount, 0); + + for (int sendIndex = 0; sendIndex < tokenCount; sendIndex++) + { + int recvIndex = fullMapping[sendIndex]; + ASSERT_GE(recvIndex, 0) << "Invalid recv index mapping at " << sendIndex; + ASSERT_LT(recvIndex, tokenCount) << "Recv index out of bounds at " << sendIndex; + + // Verify basic fields if present + if (hasBasicFields) + { + // Verify token slots + if (expectedSendTokenSlots && resultRecvTokenSlots) + { + for (int k = 0; k < topK; k++) + { + int expected = expectedSendTokenSlots[sendIndex * topK + k]; + int actual = resultRecvTokenSlots[recvIndex * topK + k]; + EXPECT_EQ(expected, actual) << "Token slot loopback mismatch: send[" << sendIndex << "][" << k + << "] -> recv[" << recvIndex << "][" << k << "]"; + } + } + + // Verify scales if present + if (hasScales && expectedSendScales && resultRecvScales) + { + for (int k = 0; k < topK; k++) + { + float expected = expectedSendScales[sendIndex * topK + k]; + float actual = resultRecvScales[recvIndex * topK + k]; + EXPECT_NEAR(expected, actual, 1e-6f) << "Scale loopback mismatch: send[" << sendIndex << "][" + << k << "] -> recv[" << recvIndex << "][" << k << "]"; + } + } + } + + // Verify field data + for (int fieldIdx = 0; fieldIdx < fieldCount; fieldIdx++) + { + size_t elementSize = elementSizes[fieldIdx % elementSizes.size()]; + uint16_t vectorSize = vectorSizes[fieldIdx % vectorSizes.size()]; + size_t fieldSize = elementSize * vectorSize; + + uint8_t const* expectedSendField = static_cast(expectedSendFields[fieldIdx]); + uint8_t const* actualRecvField = resultRecvFields[fieldIdx]; + + for (size_t byteIdx = 0; byteIdx < fieldSize; byteIdx++) + { + uint8_t expected = expectedSendField[sendIndex * fieldSize + byteIdx]; + uint8_t actual = actualRecvField[recvIndex * fieldSize + byteIdx]; + EXPECT_EQ(expected, actual) + << "Field loopback mismatch: field[" << fieldIdx << "] send[" << sendIndex << "][" << byteIdx + << "] -> recv[" << recvIndex << "][" << byteIdx << "]"; + } + } + } + + // Cleanup temporary arrays + if (resultRecvTokenSlots) + delete[] resultRecvTokenSlots; + if (resultRecvScales) + delete[] resultRecvScales; + for (int i = 0; i < fieldCount; i++) + { + if (resultRecvFields[i]) + delete[] resultRecvFields[i]; + } + } +}; + +// Tests for G2S functionality +TEST_F(FusedMoeCommG2STest, BasicG2SWithoutScales) +{ + runG2STest(2, false, true, 1, {4}, {64}); // topK=2, no scales, has basic fields, 1 field, 4-byte elements, 64 units +} + +TEST_F(FusedMoeCommG2STest, BasicG2SWithScales) +{ + runG2STest( + 4, true, true, 1, {4}, {32}); // topK=4, with scales, has basic fields, 1 field, 4-byte elements, 32 units +} + +TEST_F(FusedMoeCommG2STest, MultipleFieldsVariousAlignments) +{ + runG2STest(2, true, true, 3, {1, 2, 4}, {16, 32, 64}); // Multiple fields with different element sizes +} + +TEST_F(FusedMoeCommG2STest, LargeTopK) +{ + runG2STest(8, true, true, 2, {4, 8}, {128, 256}); // Large topK value +} + +TEST_F(FusedMoeCommG2STest, PerfectAlignmentFields) +{ + runG2STest(4, false, true, 2, {16}, {32}); // 16-byte aligned fields +} + +TEST_F(FusedMoeCommG2STest, MixedAlignmentTypes) +{ + 
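+    // Argument order here mirrors FusedMoeCommS2GTest::runS2GTest above:
+    //   topK, hasScales, hasBasicFields, fieldCount, {elementSizes}, {vectorSizes}[, tokenCount[, warpsPerBlock]]
+    // (runG2STest is assumed to take the same parameters; the second call below overrides tokenCount).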
runG2STest(3, true, true, 4, {8, 4, 2, 1}, {64, 32, 16, 8}); // All alignment types + runG2STest(3, true, true, 4, {1, 2, 4, 8}, {63, 30, 17, 9}, 32); // All alignment types +} + +TEST_F(FusedMoeCommG2STest, SingleByteAlignment) +{ + runG2STest(2, false, true, 2, {1}, {128}); // Single byte alignment +} + +TEST_F(FusedMoeCommG2STest, EdgeCaseTopKOne) +{ + runG2STest(1, false, true, 1, {4}, {16}); // Minimal topK +} + +TEST_F(FusedMoeCommG2STest, EdgeCaseNoExtraFields) +{ + runG2STest(2, true, true, 0, {}, {}); // Only basic fields (token slots + scales) +} + +TEST_F(FusedMoeCommG2STest, LargeTokenCount) +{ + runG2STest(4, true, true, 2, {4, 8}, {64, 128}, 16, 4); // 16 tokens, 4 warps per block +} + +// New tests for no basic fields scenario +TEST_F(FusedMoeCommG2STest, G2SWithoutBasicFields) +{ + runG2STest(0, false, false, 2, {4, 8}, {32, 64}); // No basic fields, only field data +} + +TEST_F(FusedMoeCommG2STest, G2SWithoutBasicFieldsLargeFields) +{ + runG2STest(0, false, false, 3, {1, 4, 16}, {128, 256, 512}); // No basic fields, large field data +} + +// Tests for S2G functionality +TEST_F(FusedMoeCommS2GTest, BasicS2GWithoutScales) +{ + runS2GTest(2, false, true, 1, {4}, {64}); // topK=2, no scales, has basic fields, 1 field, 4-byte elements +} + +TEST_F(FusedMoeCommS2GTest, BasicS2GWithScales) +{ + runS2GTest(4, true, true, 1, {4}, {32}); // topK=4, with scales, has basic fields, 1 field, 4-byte elements +} + +TEST_F(FusedMoeCommS2GTest, MultipleFieldsVariousAlignments) +{ + runS2GTest(2, true, true, 3, {1, 2, 4}, {16, 32, 64}); // Multiple fields with different element sizes +} + +TEST_F(FusedMoeCommS2GTest, LargeTopK) +{ + runS2GTest(8, true, true, 2, {4, 8}, {128, 256}); // Large topK value +} + +TEST_F(FusedMoeCommS2GTest, PerfectAlignmentFields) +{ + runS2GTest(4, false, true, 2, {16}, {32}); // 16-byte aligned fields +} + +TEST_F(FusedMoeCommS2GTest, MixedAlignmentTypes) +{ + runS2GTest(3, true, true, 4, {1, 2, 4, 8}, {8, 16, 32, 64}); // All alignment types + runS2GTest(3, true, true, 4, {1, 2, 4, 8}, {63, 30, 17, 9}, 32); // All alignment types +} + +TEST_F(FusedMoeCommS2GTest, SingleByteAlignment) +{ + runS2GTest(2, false, true, 2, {1}, {128}); // Single byte alignment +} + +TEST_F(FusedMoeCommS2GTest, EdgeCaseTopKOne) +{ + runS2GTest(1, false, true, 1, {4}, {16}); // Minimal topK +} + +TEST_F(FusedMoeCommS2GTest, EdgeCaseNoExtraFields) +{ + runS2GTest(2, true, true, 0, {}, {}); // Only basic fields (token slots + scales) +} + +TEST_F(FusedMoeCommS2GTest, LargeTokenCount) +{ + runS2GTest(4, true, true, 2, {4, 8}, {64, 128}, 16, 4); // 16 tokens, 4 warps per block +} + +// New tests for no basic fields scenario +TEST_F(FusedMoeCommS2GTest, S2GWithoutBasicFields) +{ + runS2GTest(0, false, false, 2, {4, 8}, {32, 64}); // No basic fields, only field data +} + +TEST_F(FusedMoeCommS2GTest, S2GWithoutBasicFieldsLargeFields) +{ + runS2GTest(0, false, false, 3, {1, 4, 16}, {128, 256, 512}); // No basic fields, large field data +} + +// Tests for G2S+Pack+Unpack+S2G loopback functionality +TEST_F(FusedMoeCommLoopbackTest, BasicLoopbackWithoutScales) +{ + std::vector mapping = {0, 1, 2, 3}; // Identity mapping + runLoopbackTest(2, false, true, 1, {4}, {64}, mapping); +} + +TEST_F(FusedMoeCommLoopbackTest, BasicLoopbackWithScales) +{ + std::vector mapping = {0, 1, 2, 3}; // Identity mapping + runLoopbackTest(4, true, true, 1, {4}, {32}, mapping); +} + +TEST_F(FusedMoeCommLoopbackTest, LoopbackWithReordering) +{ + std::vector mapping = {3, 0, 2, 1}; // Reorder mapping + 
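+    // verifyLoopbackResults() checks recv[mapping[sendIndex]] == send[sendIndex], so with
+    // mapping = {3, 0, 2, 1}: send token 0 -> recv slot 3, 1 -> 0, 2 -> 2, 3 -> 1.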
runLoopbackTest(2, true, true, 2, {4, 8}, {32, 64}, mapping); +} + +TEST_F(FusedMoeCommLoopbackTest, LoopbackWithReverseMapping) +{ + std::vector mapping = {3, 2, 1, 0}; // Reverse mapping + runLoopbackTest(3, false, true, 1, {2}, {128}, mapping); +} + +TEST_F(FusedMoeCommLoopbackTest, LoopbackMultipleFieldsVariousAlignments) +{ + std::vector mapping = {1, 3, 0, 2}; // Complex reordering + runLoopbackTest(2, true, true, 3, {1, 2, 4}, {16, 32, 64}, mapping); +} + +TEST_F(FusedMoeCommLoopbackTest, LoopbackLargeTopK) +{ + std::vector mapping = {2, 0, 3, 1}; // Reorder mapping + runLoopbackTest(8, true, true, 2, {4, 8}, {128, 256}, mapping); +} + +TEST_F(FusedMoeCommLoopbackTest, LoopbackPerfectAlignmentFields) +{ + std::vector mapping = {0, 2, 1, 3}; // Partial reordering + runLoopbackTest(4, false, true, 2, {16}, {32}, mapping); +} + +TEST_F(FusedMoeCommLoopbackTest, LoopbackMixedAlignmentTypes) +{ + std::vector mapping = {1, 0, 3, 2}; // Pair swap + runLoopbackTest(3, true, true, 4, {1, 2, 4, 8}, {8, 16, 32, 64}, mapping); +} + +TEST_F(FusedMoeCommLoopbackTest, LoopbackSingleByteAlignment) +{ + std::vector mapping = {2, 3, 0, 1}; // Cyclic shift + runLoopbackTest(2, false, true, 2, {1}, {128}, mapping); +} + +TEST_F(FusedMoeCommLoopbackTest, LoopbackEdgeCaseTopKOne) +{ + std::vector mapping = {1, 0, 3, 2}; // Simple reordering + runLoopbackTest(1, false, true, 1, {4}, {16}, mapping); +} + +TEST_F(FusedMoeCommLoopbackTest, LoopbackEdgeCaseNoExtraFields) +{ + std::vector mapping = {3, 1, 0, 2}; // Random reordering + runLoopbackTest(2, true, true, 0, {}, {}, mapping); // Only basic fields +} + +TEST_F(FusedMoeCommLoopbackTest, LoopbackLargeTokenCount) +{ + std::vector mapping = {7, 0, 5, 2, 3, 6, 1, 4, 15, 8, 11, 10, 9, 14, 13, 12}; // Complex 16-token mapping + runLoopbackTest(4, true, true, 2, {4, 8}, {64, 128}, mapping, 16, 4); +} + +// New tests for no basic fields scenario +TEST_F(FusedMoeCommLoopbackTest, LoopbackWithoutBasicFields) +{ + std::vector mapping = {1, 3, 0, 2}; // Reorder mapping + runLoopbackTest(0, false, false, 2, {4, 8}, {32, 64}, mapping); // No basic fields, only field data +} + +TEST_F(FusedMoeCommLoopbackTest, LoopbackWithoutBasicFieldsLargeFields) +{ + std::vector mapping = {2, 0, 3, 1}; // Reorder mapping + runLoopbackTest(0, false, false, 3, {1, 4, 16}, {128, 256, 512}, mapping); // No basic fields, large field data +} + +// Test class for launchLocalFifoSendRecv function (FIFO-based local send/recv test) +class FusedMoeCommLocalFifoSendRecvTest : public FusedMoeCommTestBase +{ +protected: + void runLocalFifoSendRecvTest(int topK, bool hasScales, bool hasBasicFields, int fieldCount, + std::vector const& elementSizes, std::vector const& vectorSizes, + std::vector const& sendIndexMappingVec, std::vector const& recvIndexMappingVec, int tokenCount = 4, + int warpsPerBlock = 2, int blockChannelCount = 1) + { + // Setup expert parallel info + MoeExpertParallelInfo expertParallelInfo; + expertParallelInfo.topK = topK; + expertParallelInfo.expertCount = 8; + + // Setup field info for send and receive sides + FusedMoeFieldInfo sendFieldInfo = {}; + sendFieldInfo.isBasicInterleaved = false; + sendFieldInfo.fieldCount = fieldCount; + + FusedMoeFieldInfo recvFieldInfo = {}; + recvFieldInfo.isBasicInterleaved = false; + recvFieldInfo.fieldCount = fieldCount; + + // Allocate token selected slots and expert scales if needed + int* hostSendTokenSlots = nullptr; + int* deviceSendTokenSlots = nullptr; + float* hostSendScales = nullptr; + float* deviceSendScales = nullptr; + int* 
hostRecvTokenSlots = nullptr; + int* deviceRecvTokenSlots = nullptr; + float* hostRecvScales = nullptr; + float* deviceRecvScales = nullptr; + + if (hasBasicFields) + { + // Send side basic fields + allocateAndInitializeData(&hostSendTokenSlots, &deviceSendTokenSlots, tokenCount * topK, + [](size_t i) { return static_cast(i % 8); }); + sendFieldInfo.tokenSelectedSlots = deviceSendTokenSlots; + + // Recv side basic fields (initialized to zero, will be filled by communication) + allocateAndInitializeData( + &hostRecvTokenSlots, &deviceRecvTokenSlots, tokenCount * topK, [](size_t) { return 0; }); + recvFieldInfo.tokenSelectedSlots = deviceRecvTokenSlots; + + if (hasScales) + { + allocateAndInitializeData(&hostSendScales, &deviceSendScales, tokenCount * topK, + [](size_t i) -> float { return 1.0f + static_cast(i) * 0.1f; }); + sendFieldInfo.expertScales = deviceSendScales; + + allocateAndInitializeData( + &hostRecvScales, &deviceRecvScales, tokenCount * topK, [](size_t) { return 0.0f; }); + recvFieldInfo.expertScales = deviceRecvScales; + } + } + + // Setup field info for additional fields + std::vector hostSendFieldPtrs(fieldCount); + std::vector deviceSendFieldPtrs(fieldCount); + std::vector hostRecvFieldPtrs(fieldCount); + std::vector deviceRecvFieldPtrs(fieldCount); + + for (int i = 0; i < fieldCount; i++) + { + size_t elementSize = elementSizes[i % elementSizes.size()]; + uint16_t vectorSize = vectorSizes[i % vectorSizes.size()]; + size_t fieldSize = elementSize * vectorSize * tokenCount; + + // Allocate send field data with specific pattern + uint8_t* hostSendField; + uint8_t* deviceSendField; + allocateAndInitializeData(&hostSendField, &deviceSendField, fieldSize, + [i](size_t idx) { return static_cast((i * 100 + idx + 1) % 128); }); + + // Allocate recv field data (initially zero, will be filled by communication) + uint8_t* hostRecvField; + uint8_t* deviceRecvField; + allocateAndInitializeData( + &hostRecvField, &deviceRecvField, fieldSize, [](size_t) { return static_cast(0); }); + + hostSendFieldPtrs[i] = hostSendField; + deviceSendFieldPtrs[i] = deviceSendField; + hostRecvFieldPtrs[i] = hostRecvField; + deviceRecvFieldPtrs[i] = deviceRecvField; + + // Fill field info + sendFieldInfo.fieldsInfo[i].fillFieldInfo(deviceSendField, elementSize, vectorSize, vectorSize); + recvFieldInfo.fieldsInfo[i].fillFieldInfo(deviceRecvField, elementSize, vectorSize, vectorSize); + } + + // Fill field placement info + sendFieldInfo.fillFieldPlacementInfo(topK, hasBasicFields); + recvFieldInfo.fillFieldPlacementInfo(topK, hasBasicFields); + + // Setup sendIndexMapping and recvIndexMapping - ensure one-to-one mappings + std::vector fullSendMapping = generateOneToOneMapping(sendIndexMappingVec, tokenCount); + std::vector fullRecvMapping = generateOneToOneMapping(recvIndexMappingVec, tokenCount); + + int* hostSendIndexMapping; + int* deviceSendIndexMapping; + int* hostRecvIndexMapping; + int* deviceRecvIndexMapping; + + allocateAndInitializeData(&hostSendIndexMapping, &deviceSendIndexMapping, tokenCount, + [&fullSendMapping](size_t i) { return fullSendMapping[i]; }); + allocateAndInitializeData(&hostRecvIndexMapping, &deviceRecvIndexMapping, tokenCount, + [&fullRecvMapping](size_t i) { return fullRecvMapping[i]; }); + + // Setup workspace for FIFO communication + FusedMoeWorkspace fusedMoeWorkspace; + int totalChannelCount = blockChannelCount * warpsPerBlock; + size_t workspaceSizePerRank = FusedMoeWorkspace::computeWorkspaceSizePreRank(1, totalChannelCount); + size_t totalWorkspaceSize = 
workspaceSizePerRank; + fusedMoeWorkspace.rankStrideInU64 = workspaceSizePerRank / sizeof(uint64_t); + fusedMoeWorkspace.channelCount = totalChannelCount; + + TLLM_CUDA_CHECK(cudaMalloc(&fusedMoeWorkspace.workspacePtr, totalWorkspaceSize)); + + // Initialize workspace + FusedMoeWorldInfo worldInfo; + worldInfo.epInfo.epRank = 0; + worldInfo.epInfo.epSize = 1; + fusedMoeWorkspace.initializeLocalWorkspace(worldInfo); + + // Launch FIFO send/recv kernel + fused_moe_comm_tests::launchLocalFifoSendRecv(sendFieldInfo, recvFieldInfo, expertParallelInfo, + deviceSendIndexMapping, deviceRecvIndexMapping, fusedMoeWorkspace, tokenCount, warpsPerBlock, + blockChannelCount, hasBasicFields, stream); + + TLLM_CUDA_CHECK(cudaStreamSynchronize(stream)); + + // Copy back results and verify + verifyLocalFifoSendRecvResults(hostSendTokenSlots, hostSendScales, hostSendFieldPtrs, hostRecvFieldPtrs, + deviceRecvTokenSlots, deviceRecvScales, deviceRecvFieldPtrs, fullSendMapping, fullRecvMapping, topK, + hasScales, hasBasicFields, fieldCount, elementSizes, vectorSizes, tokenCount); + + // Cleanup + if (hasBasicFields) + { + cleanup(hostSendTokenSlots, deviceSendTokenSlots); + cleanup(hostRecvTokenSlots, deviceRecvTokenSlots); + if (hasScales) + { + cleanup(hostSendScales, deviceSendScales); + cleanup(hostRecvScales, deviceRecvScales); + } + } + for (int i = 0; i < fieldCount; i++) + { + cleanup(hostSendFieldPtrs[i], deviceSendFieldPtrs[i]); + cleanup(hostRecvFieldPtrs[i], deviceRecvFieldPtrs[i]); + } + cleanup(hostSendIndexMapping, deviceSendIndexMapping); + cleanup(hostRecvIndexMapping, deviceRecvIndexMapping); + TLLM_CUDA_CHECK(cudaFree(fusedMoeWorkspace.workspacePtr)); + } + +private: + void verifyLocalFifoSendRecvResults(int const* expectedSendTokenSlots, float const* expectedSendScales, + std::vector const& expectedSendFields, std::vector const& hostRecvFields, + int* deviceRecvTokenSlots, float* deviceRecvScales, std::vector const& deviceRecvFields, + std::vector const& fullSendMapping, std::vector const& fullRecvMapping, int topK, bool hasScales, + bool hasBasicFields, int fieldCount, std::vector const& elementSizes, + std::vector const& vectorSizes, int tokenCount) + { + // Copy back device results for verification + int* resultRecvTokenSlots = nullptr; + float* resultRecvScales = nullptr; + + if (hasBasicFields) + { + resultRecvTokenSlots = new int[tokenCount * topK]; + TLLM_CUDA_CHECK(cudaMemcpy( + resultRecvTokenSlots, deviceRecvTokenSlots, tokenCount * topK * sizeof(int), cudaMemcpyDeviceToHost)); + + if (hasScales) + { + resultRecvScales = new float[tokenCount * topK]; + TLLM_CUDA_CHECK(cudaMemcpy( + resultRecvScales, deviceRecvScales, tokenCount * topK * sizeof(float), cudaMemcpyDeviceToHost)); + } + } + + // Copy back field data + std::vector resultRecvFields(fieldCount); + for (int i = 0; i < fieldCount; i++) + { + size_t elementSize = elementSizes[i % elementSizes.size()]; + uint16_t vectorSize = vectorSizes[i % vectorSizes.size()]; + size_t fieldSize = elementSize * vectorSize * tokenCount; + + resultRecvFields[i] = new uint8_t[fieldSize]; + TLLM_CUDA_CHECK(cudaMemcpy(resultRecvFields[i], deviceRecvFields[i], fieldSize, cudaMemcpyDeviceToHost)); + } + + // Verify the FIFO send/recv with independent mappings: + // For logical index i: + // - Send side reads from fullSendMapping[i] + // - Recv side writes to fullRecvMapping[i] + // So we need to verify: recv[fullRecvMapping[i]] should equal send[fullSendMapping[i]] + int tokenSlotErrorCount = 0; + int scaleErrorCount = 0; + std::vector 
fieldErrorCounts(fieldCount, 0); + + for (int logicalIndex = 0; logicalIndex < tokenCount; logicalIndex++) + { + int actualSendIndex = fullSendMapping[logicalIndex]; + int actualRecvIndex = fullRecvMapping[logicalIndex]; + if (actualSendIndex < 0 || actualSendIndex >= tokenCount || actualRecvIndex < 0 + || actualRecvIndex >= tokenCount) + continue; + + // Verify token selected slots + if (hasBasicFields) + { + for (int k = 0; k < topK; k++) + { + int expectedSlot = expectedSendTokenSlots[actualSendIndex * topK + k]; + int actualSlot = resultRecvTokenSlots[actualRecvIndex * topK + k]; + if (expectedSlot != actualSlot) + { + tokenSlotErrorCount++; + if (tokenSlotErrorCount <= 16) + { + EXPECT_EQ(expectedSlot, actualSlot) + << "Token slot mismatch at logicalIndex=" << logicalIndex + << ", actualSendIndex=" << actualSendIndex << ", actualRecvIndex=" << actualRecvIndex + << ", k=" << k; + } + } + } + + // Verify expert scales + if (hasScales) + { + for (int k = 0; k < topK; k++) + { + float expectedScale = expectedSendScales[actualSendIndex * topK + k]; + float actualScale = resultRecvScales[actualRecvIndex * topK + k]; + if (std::abs(expectedScale - actualScale) > 1e-6f) + { + scaleErrorCount++; + if (scaleErrorCount <= 16) + { + EXPECT_NEAR(expectedScale, actualScale, 1e-6f) + << "Scale mismatch at logicalIndex=" << logicalIndex + << ", actualSendIndex=" << actualSendIndex + << ", actualRecvIndex=" << actualRecvIndex << ", k=" << k; + } + } + } + } + } + + // Verify additional fields + for (int fieldIdx = 0; fieldIdx < fieldCount; fieldIdx++) + { + size_t elementSize = elementSizes[fieldIdx % elementSizes.size()]; + uint16_t vectorSize = vectorSizes[fieldIdx % vectorSizes.size()]; + size_t fieldSizePerToken = elementSize * vectorSize; + + uint8_t const* expectedFieldData = static_cast(expectedSendFields[fieldIdx]); + uint8_t const* actualFieldData = resultRecvFields[fieldIdx]; + + for (size_t byteIdx = 0; byteIdx < fieldSizePerToken; byteIdx++) + { + uint8_t expected = expectedFieldData[actualSendIndex * fieldSizePerToken + byteIdx]; + uint8_t actual = actualFieldData[actualRecvIndex * fieldSizePerToken + byteIdx]; + if (expected != actual) + { + fieldErrorCounts[fieldIdx]++; + if (fieldErrorCounts[fieldIdx] <= 16) + { + EXPECT_EQ(static_cast(expected), static_cast(actual)) + << "Field[" << fieldIdx << "] mismatch at logicalIndex=" << logicalIndex + << ", actualSendIndex=" << actualSendIndex << ", actualRecvIndex=" << actualRecvIndex + << ", byteIdx=" << byteIdx; + } + } + } + } + } + + // Print error summary for counts exceeding 16 + if (tokenSlotErrorCount > 16) + { + ADD_FAILURE() << "Token slot errors: Showed first 16 of " << tokenSlotErrorCount << " total mismatches."; + } + if (scaleErrorCount > 16) + { + ADD_FAILURE() << "Scale errors: Showed first 16 of " << scaleErrorCount << " total mismatches."; + } + for (int fieldIdx = 0; fieldIdx < fieldCount; fieldIdx++) + { + if (fieldErrorCounts[fieldIdx] > 16) + { + ADD_FAILURE() << "Field[" << fieldIdx << "] errors: Showed first 16 of " << fieldErrorCounts[fieldIdx] + << " total mismatches."; + } + } + + // Cleanup temporary arrays + if (resultRecvTokenSlots) + delete[] resultRecvTokenSlots; + if (resultRecvScales) + delete[] resultRecvScales; + for (int i = 0; i < fieldCount; i++) + { + if (resultRecvFields[i]) + delete[] resultRecvFields[i]; + } + } +}; + +// Tests for Local FIFO Send/Recv functionality with Packed Protocol +TEST_F(FusedMoeCommLocalFifoSendRecvTest, BasicFifoSendRecvPackedProtocol) +{ + std::vector sendMapping = {0, 1, 2, 
3}; // Identity mapping for send + std::vector recvMapping = {2, 3, 0, 1}; // Rotate mapping for recv + runLocalFifoSendRecvTest(2, false, true, 1, {4}, {64}, sendMapping, recvMapping, 4, 1, 1); // Packed protocol +} + +TEST_F(FusedMoeCommLocalFifoSendRecvTest, BasicFifoSendRecvWithScalesPackedProtocol) +{ + std::vector sendMapping = {1, 2, 3, 0}; // Rotate send mapping + std::vector recvMapping = {3, 0, 1, 2}; // Opposite rotation for recv + runLocalFifoSendRecvTest( + 4, true, true, 1, {4}, {32}, sendMapping, recvMapping, 4, 2, 1); // With scales, Packed protocol +} + +TEST_F(FusedMoeCommLocalFifoSendRecvTest, FifoSendRecvWithReorderingPackedProtocol) +{ + std::vector sendMapping = {3, 0, 2, 1}; // Random send reorder + std::vector recvMapping = {0, 3, 1, 2}; // Different recv reorder + runLocalFifoSendRecvTest(2, true, true, 2, {4, 8}, {32, 64}, sendMapping, recvMapping, 256, 2, 2); +} + +TEST_F(FusedMoeCommLocalFifoSendRecvTest, FifoSendRecvMultipleFieldsPackedProtocol) +{ + std::vector mapping = {1, 3, 0, 2}; // Complex reordering + runLocalFifoSendRecvTest(2, true, true, 3, {1, 2, 4}, {16, 32, 64}, mapping, mapping, 256, 2, 2); +} + +TEST_F(FusedMoeCommLocalFifoSendRecvTest, FifoSendRecvLargeTopKPackedProtocol) +{ + std::vector mapping = {2, 0, 3, 1}; // Reorder mapping + runLocalFifoSendRecvTest(8, true, true, 2, {4, 8}, {128, 256}, mapping, mapping, 512, 3, 2); +} + +TEST_F(FusedMoeCommLocalFifoSendRecvTest, FifoSendRecvWithoutBasicFieldsPackedProtocol) +{ + std::vector sendMapping = {1, 3, 0, 2}; // Send reorder mapping + std::vector recvMapping = {3, 2, 1, 0}; // Reverse recv mapping + runLocalFifoSendRecvTest( + 0, false, false, 2, {4, 8}, {32, 64}, sendMapping, recvMapping, 256, 2, 2); // No basic fields, Packed protocol +} + +// Mixed alignment tests +TEST_F(FusedMoeCommLocalFifoSendRecvTest, FifoSendRecvMixedAlignmentsPackedProtocol) +{ + std::vector mapping = {1, 0, 3, 2}; // Pair swap + runLocalFifoSendRecvTest(3, true, true, 4, {1, 2, 4, 8}, {8, 16, 32, 64}, mapping, mapping, 512, 2, 2); +} + +// Edge cases +TEST_F(FusedMoeCommLocalFifoSendRecvTest, FifoSendRecvEdgeCaseTopKOnePackedProtocol) +{ + std::vector mapping = {1, 0, 3, 2}; // Simple reordering + runLocalFifoSendRecvTest(1, false, true, 1, {4}, {16}, mapping, mapping, 128, 2, 1); +} + +// Only basic fields cases +TEST_F(FusedMoeCommLocalFifoSendRecvTest, FifoSendRecvEdgeCaseNoExtraFieldsPackedProtocol) +{ + std::vector mapping = {3, 1, 0, 2}; // Random reordering + runLocalFifoSendRecvTest(2, true, true, 0, {}, {}, mapping, mapping, 256, 2, 2); // Only basic fields +} + +// Large scale tests +TEST_F(FusedMoeCommLocalFifoSendRecvTest, FifoSendRecvLargeTokenCountPackedProtocol) +{ + std::vector sendMapping = {7, 0, 5, 2, 3, 6, 1, 4, 15, 8, 11, 10, 9, 14, 13, 12}; // Complex send mapping + std::vector recvMapping = {15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; // Reverse recv mapping + runLocalFifoSendRecvTest( + 4, true, true, 2, {4, 8}, {64, 128}, sendMapping, recvMapping, 1024, 3, 3); // Large scale test +} + +// Perfect alignment tests +TEST_F(FusedMoeCommLocalFifoSendRecvTest, FifoSendRecvPerfectAlignmentPackedProtocol) +{ + std::vector sendMapping = {2, 0, 3, 1}; // Different send reordering + std::vector recvMapping = {1, 3, 0, 2}; // Different recv reordering + runLocalFifoSendRecvTest(4, false, true, 2, {16}, {32}, sendMapping, recvMapping, 256, 2, 3); +} + +// Single byte alignment tests +TEST_F(FusedMoeCommLocalFifoSendRecvTest, FifoSendRecvSmallSingleByteAlignmentPackedProtocol) +{ + std::vector 
mapping = {2, 3, 0, 1}; // Cyclic shift + runLocalFifoSendRecvTest(2, false, true, 1, {1}, {127}, mapping, mapping, 4, 1, 1); +} + +TEST_F(FusedMoeCommLocalFifoSendRecvTest, FifoSendRecvSingleByteAlignmentPackedProtocol) +{ + std::vector mapping = {2, 3, 0, 1}; // Cyclic shift + runLocalFifoSendRecvTest(2, false, true, 2, {1}, {127}, mapping, mapping, 256, 3, 1); +} + +// Stress tests +TEST_F(FusedMoeCommLocalFifoSendRecvTest, FifoSendRecvStressTestManyChannelsPackedProtocol) +{ + std::vector mapping = {7, 2, 5, 0, 3, 6, 1, 4}; // Complex mapping + runLocalFifoSendRecvTest(4, true, true, 2, {8, 16}, {128, 256}, mapping, mapping, 512, 3, 4); // Many channels +} + +TEST_F(FusedMoeCommLocalFifoSendRecvTest, FifoSendRecvStressTest2ManyChannelsPackedProtocol) +{ + std::vector mapping = {7, 2, 5, 0, 3, 6, 1, 4}; // Complex mapping + runLocalFifoSendRecvTest( + 4, true, true, 2, {2, 4, 8, 16}, {7, 15, 31, 255}, mapping, mapping, 4096, 1, 2); // Many channels +} + +TEST_F(FusedMoeCommLocalFifoSendRecvTest, FifoSendRecvStressTestManyWarpsPackedProtocol) +{ + std::vector mapping = {1, 0, 3, 2}; // Simple reordering + runLocalFifoSendRecvTest(2, false, true, 1, {4}, {64}, mapping, mapping, 256, 4, 2); // Many warps per block +} diff --git a/tensorrt_llm/_mnnvl_utils.py b/tensorrt_llm/_mnnvl_utils.py index 39f6deac4c5..d30b7316c39 100644 --- a/tensorrt_llm/_mnnvl_utils.py +++ b/tensorrt_llm/_mnnvl_utils.py @@ -17,7 +17,7 @@ import platform import sys from dataclasses import dataclass -from typing import Optional +from typing import List, Optional, Union import pynvml import torch @@ -366,6 +366,10 @@ def get_moe_workspaces(mapping: Mapping): ) MnnvlMoe.moe_workspace = MnnvlMemory(mapping, workspace_size_per_rank) MnnvlMoe.moe_workspace_tensor = MnnvlMoe.moe_workspace.as_torch_strided_tensor(torch.uint64) + torch.ops.trtllm.moe_initialize_workspace( + MnnvlMoe.moe_workspace_tensor, mapping.tp_rank, mapping.tp_size + ) + MnnvlMoe.moe_workspace.comm.barrier() return MnnvlMoe.moe_workspace_tensor @staticmethod @@ -394,7 +398,6 @@ def compute_target_rank_id( @staticmethod def mnnvl_moe_alltoallv_prepare_without_allgather( expert_ids: torch.Tensor, - scales: torch.Tensor, expert_statics: Optional[torch.Tensor], workspace: torch.Tensor, max_token_count_per_rank: int, @@ -405,8 +408,6 @@ def mnnvl_moe_alltoallv_prepare_without_allgather( top_k: int, ): ( - prepared_local_experts, - prepared_local_scales, local_send_rank_count_cumsum, local_send_rank_indices, local_recv_rank_count_cumsum, @@ -415,7 +416,6 @@ def mnnvl_moe_alltoallv_prepare_without_allgather( gathered_expert_statics, ) = torch.ops.trtllm.mnnvl_moe_alltoallv_prepare_without_allgather( expert_ids, - scales, expert_statics, workspace, max_token_count_per_rank, @@ -440,7 +440,7 @@ def mnnvl_moe_alltoallv_prepare_without_allgather( local_token_allocation_count, ) - return alltoall_info, prepared_local_experts, prepared_local_scales, gathered_expert_statics + return alltoall_info, gathered_expert_statics @staticmethod def mnnvl_moe_expert_static_allgather( @@ -526,31 +526,67 @@ def mnnvl_moe_alltoallv_prepare( @staticmethod def mnnvl_moe_alltoallv( - x: torch.Tensor, + x: Union[torch.Tensor, List[Optional[torch.Tensor]]], alltoall_info: MoEAlltoallInfo, workspace: torch.Tensor, ep_rank: int, ep_size: int, - ): - assert x.dim() == 2, "only 2D tensor supported, please reshape." 
- output_tensor = torch.empty( - alltoall_info.local_token_allocation_count, - x.shape[1], - dtype=x.dtype, - device=torch.device("cuda"), - ) - torch.ops.trtllm.moe_comm( - x, - alltoall_info.send_rank_count_cumsum, - alltoall_info.send_rank_local_indices, - output_tensor, - alltoall_info.recv_rank_count_cumsum, - alltoall_info.recv_rank_local_indices, - workspace, - ep_rank, - ep_size, - ) - return output_tensor + ) -> Union[torch.Tensor, List[Optional[torch.Tensor]]]: + # Convert single tensor to list for unified handling + is_single_tensor = not isinstance(x, list) + if is_single_tensor: + assert x.dim() == 2, "only 2D tensor supported, please reshape." + x = [x] + + assert len(x) > 0, "Empty tensor list not supported" + + # Filter out None values + valid_list = [tensor is not None for tensor in x] + valid_tensors = [tensor for tensor in x if tensor is not None] + + if len(valid_tensors) == 0: + # All tensors are None, return list of None + result = [None] * len(x) + else: + first_dim = None + for tensor in valid_tensors: + # Validate dimensions of valid tensors + assert tensor.dim() == 2, "only 2D tensor supported, please reshape." + if first_dim is None: + first_dim = tensor.shape[0] + else: + assert tensor.shape[0] == first_dim, ( + f"All tensors must have the same first dimension, got {tensor.shape[0]} vs {first_dim}" + ) + + # Process only valid tensors + output_tensors = torch.ops.trtllm.moe_comm( + valid_tensors, + alltoall_info.send_rank_count_cumsum, + alltoall_info.send_rank_local_indices, + alltoall_info.recv_rank_count_cumsum, + alltoall_info.recv_rank_local_indices, + workspace, + alltoall_info.local_token_allocation_count, + ep_rank, + ep_size, + ) + + # Restore None positions in output + idx = 0 + result = [] + for is_valid in valid_list: + if is_valid: + result.append(output_tensors[idx]) + idx += 1 + else: + result.append(None) + + # If input was a single tensor, return a single tensor + if is_single_tensor: + result = result[0] + + return result @staticmethod def mnnvl_moe_alltoallv_combine( @@ -563,20 +599,19 @@ def mnnvl_moe_alltoallv_combine( token_count: int, ): assert x.dim() == 2, "2D tensor supported, please reshape." 
- output_tensor = torch.zeros( - token_count * top_k, x.shape[1], dtype=x.dtype, device=torch.device("cuda") - ) - torch.ops.trtllm.moe_comm( - x, + output_tensors = torch.ops.trtllm.moe_comm( + [x], alltoall_info.recv_rank_count_cumsum, alltoall_info.recv_rank_local_indices, - output_tensor, alltoall_info.send_rank_count_cumsum, alltoall_info.backward_recv_rank_local_indices, workspace, + token_count * top_k, ep_rank, ep_size, + [True], ) + output_tensor = output_tensors[0] return torch.sum( output_tensor.reshape(token_count, top_k, x.shape[1]), dim=1, keepdim=False ) diff --git a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py index 098af11fc85..a1ce45b26e2 100644 --- a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py @@ -179,71 +179,27 @@ def _( return (input.new_empty(output_shape, dtype=torch.uint8), global_scale.new_empty(scale_shape, dtype=torch.uint8)) - @torch.library.register_fake("trtllm::moe_comm_prepare_indices") - def _( - gathered_target_rank_ids: torch.Tensor, - real_rank_token_count_cum_sum: Optional[torch.Tensor], - max_token_count_per_rank: int, - expert_count: int, - top_k: int, - ep_rank: int, - ep_size: int, - ): - max_send_ranks_per_token = max(ep_size, top_k) - local_gather_indices_shape = (max_token_count_per_rank * ep_size, ) - rank_count_cum_sum_shape = (ep_size, ) - send_rank_local_indices_shape = (max_token_count_per_rank * - max_send_ranks_per_token, ) - recv_rank_local_indices_shape = (max_token_count_per_rank * ep_size, ) - backward_recv_rank_local_indices_shape = (max_token_count_per_rank * - max_send_ranks_per_token, ) - - local_gather_indices = gathered_target_rank_ids.new_empty( - local_gather_indices_shape, dtype=torch.int32) - send_rank_count_cum_sum = gathered_target_rank_ids.new_empty( - rank_count_cum_sum_shape, dtype=torch.int32) - send_rank_local_indices = gathered_target_rank_ids.new_empty( - send_rank_local_indices_shape, dtype=torch.int32) - recv_rank_count_cum_sum = gathered_target_rank_ids.new_empty( - rank_count_cum_sum_shape, dtype=torch.int32) - recv_rank_local_indices = gathered_target_rank_ids.new_empty( - recv_rank_local_indices_shape, dtype=torch.int32) - backward_recv_rank_local_indices = gathered_target_rank_ids.new_empty( - backward_recv_rank_local_indices_shape, dtype=torch.int32) - - return (local_gather_indices, send_rank_count_cum_sum, - send_rank_local_indices, recv_rank_count_cum_sum, - recv_rank_local_indices, backward_recv_rank_local_indices) - - @torch.library.register_fake("trtllm::moe_local_gather") - def _( - recv_rank_cum_sum: torch.Tensor, - local_gather_indices: torch.Tensor, - gathered_expert_ids: torch.Tensor, - gathered_scales: Optional[torch.Tensor], - local_expert_ids: torch.Tensor, - local_scales: Optional[torch.Tensor], - max_token_count_per_rank: int, - expert_count: int, - top_k: int, - ep_rank: int, - ep_size: int, - ): - pass - @torch.library.register_fake("trtllm::moe_comm") def _( - input: torch.Tensor, + inputs: List[torch.Tensor], send_rank_cum_sum: torch.Tensor, send_indices: torch.Tensor, - output: torch.Tensor, recv_rank_cum_sum: torch.Tensor, recv_indices: torch.Tensor, all_workspaces: torch.Tensor, + output_allocation_count: int, ep_rank: int, ep_size: int, + need_zero_output: Optional[List[bool]], ): - pass + outputs = [] + for input_tensor in inputs: + output_tensor = torch.empty( + (output_allocation_count, input_tensor.shape[1]), + dtype=input_tensor.dtype, + device=input_tensor.device) + 
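+            # The fake (meta) registration only needs to reproduce the real op's output metadata
+            # for tracing: one (output_allocation_count, input.shape[1]) tensor per input, with the
+            # input's dtype and device; need_zero_output does not affect shapes, so it is unused here.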
outputs.append(output_tensor) + return outputs @torch.library.register_fake("trtllm::get_moe_commworkspace_size_per_rank") def _(ep_size: int): @@ -287,6 +243,12 @@ def _(single_layer_load_balancer_ptr: int, token_selected_experts: torch.Tensor, offset_by_ep_rank: bool): return torch.empty_like(token_selected_experts) + @torch.library.register_fake("trtllm::memset_expert_ids") + def _(experts_ids: torch.Tensor, recv_rank_count_cumsum: torch.Tensor, + max_token_count_per_rank: int, top_k: int, slot_count: int, + ep_size: int): + pass + @torch.library.custom_op("trtllm::group_rms_norm_base", mutates_args=("outputs", )) def group_rms_norm_base( diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index c9b9fa979fe..cf404042497 100644 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -58,8 +58,7 @@ from ..modules.embedding import Embedding from ..modules.fused_moe import (DeepSeekV3MoeRoutingMethod, MoEWeightLoadingMode, TRTLLMGenFusedMoE, - create_moe, - moe_load_balancer_set_repeated_for_next_layer) + create_moe) from ..modules.gated_mlp import GatedMLP from ..modules.linear import Linear, TensorParallelMode, WeightsLoadingConfig from ..modules.multi_stream_utils import maybe_execute_in_parallel @@ -1159,7 +1158,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]): self.num_hidden_layers = self.config.num_hidden_layers assert ckpt_nextn > 0, "There is not MTP modules in the checkpoint." if ckpt_nextn == 1 and not model_config.spec_config.use_mtp_vanilla: - moe_load_balancer_set_repeated_for_next_layer(model_nextn) + pass else: # modify the QuantConfig to support duplicated mtp layers if model_config.quant_config.exclude_modules is not None: diff --git a/tensorrt_llm/_torch/models/modeling_speculative.py b/tensorrt_llm/_torch/models/modeling_speculative.py index 56a489c9635..d3d0f0a3042 100644 --- a/tensorrt_llm/_torch/models/modeling_speculative.py +++ b/tensorrt_llm/_torch/models/modeling_speculative.py @@ -14,6 +14,7 @@ from ..modules.attention import Attention from ..modules.decoder_layer import DecoderLayer from ..modules.embedding import Embedding +from ..modules.fused_moe import moe_load_balancer_set_repeated_for_next_layer from ..modules.gated_mlp import GatedMLP from ..modules.linear import (Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig) @@ -340,6 +341,9 @@ def __init__( mtp_num_layers = 1 if spec_dec_mode.is_mtp_eagle( ) else model_config.spec_config.num_nextn_predict_layers + moe_load_balancer_set_repeated_for_next_layer( + model_config.spec_config.num_nextn_predict_layers // mtp_num_layers) + self.mtp_layers = nn.ModuleList([ DeepseekV3MTP(model_config, layer_idx + start_layer_idx, model.aux_stream_dict) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py index 34bb61a7ab0..da90df16bde 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py @@ -5,7 +5,6 @@ import torch from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe -from tensorrt_llm.math_utils import pad_up from ...distributed import allgather from ...model_config import ModelConfig @@ -190,8 +189,6 @@ def has_int8_woq_per_channel(self): @cached_property def enable_alltoall(self): return (self.mapping.moe_ep_size > self.routing_method.experts_per_token - and self.routing_method.experts_per_token % 4 == 
- 0 # alltoall without allgather only supports top_k % 4 == 0 and self.mapping.enable_attention_dp and self.mapping.tp_size > 1 and os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") != "1" @@ -353,39 +350,28 @@ def forward_chunk( token_final_scales = torch.ones_like(token_selected_experts, dtype=torch.float32) - # TODO: support alltoall without allgather for top_k % 4 != 0 - assert top_k % 4 == 0, "alltoall without allgather only supports top_k % 4 == 0" assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized" - alltoall_info, token_selected_experts, token_final_scales, _ = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather( - token_selected_experts, token_final_scales, None, - self.alltoall_prepare_workspace, max_num_token, self.ep_rank, - self.ep_size, self.num_experts, self.num_experts, top_k) - - # Dispatch alltoall (common for both paths) - x = MnnvlMoe.mnnvl_moe_alltoallv(x, alltoall_info, - self.alltoall_workspace, - self.ep_rank, self.ep_size) + alltoall_info, _ = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather( + token_selected_experts, None, self.alltoall_prepare_workspace, + max_num_token, self.ep_rank, self.ep_size, self.num_experts, + self.num_experts, top_k) + if x_sf is not None: x_sf = x_sf.view(x_row, ceil_div(x_col, self.scaling_vector_size)) - # Pad dim[1] to 16 bytes alignment for alltoall - # TODO: Remove this padding if possible - sf_per_16bytes = 16 // x_sf.element_size() - x_sf_col_orig = x_sf.shape[1] - x_sf_col = pad_up(x_sf_col_orig, sf_per_16bytes) - if x_sf_col > x_sf_col_orig: - x_sf = torch.nn.functional.pad( - x_sf, (0, x_sf_col - x_sf_col_orig)) - - x_sf = MnnvlMoe.mnnvl_moe_alltoallv(x_sf, alltoall_info, - self.alltoall_workspace, - self.ep_rank, self.ep_size) - x_row = x_sf.shape[0] + # Dispatch x, x_sf, token_selected_experts, token_final_scales in one alltoall kernel + x, x_sf, token_selected_experts, token_final_scales = MnnvlMoe.mnnvl_moe_alltoallv( + [x, x_sf, token_selected_experts, token_final_scales], + alltoall_info, self.alltoall_workspace, self.ep_rank, + self.ep_size) - # TODO: Remove this slicing required by padding if possible - x_sf = x_sf[:, :x_sf_col_orig].contiguous() + torch.ops.trtllm.memset_expert_ids( + token_selected_experts, alltoall_info.recv_rank_count_cumsum, + max_num_token, top_k, self.num_experts, self.ep_size) + if x_sf is not None: + x_row = x_sf.shape[0] x_sf = swizzle_sf(x_sf, x_row, x_col, self.scaling_vector_size) elif run_post_quant_allgather: diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py index 9fee27e6c93..22d14a83b55 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py @@ -192,18 +192,13 @@ def __init__( self.use_low_precision_combine = (os.environ.get( "TRTLLM_MOE_USE_LOW_PRECISION_COMBINE", "0") == "1") and qm.has_nvfp4() - # TODO: support alltoall without allgather for top_k % 4 != 0 - self.enable_alltoall_without_allgather = ( - os.environ.get("TRTLLM_MOE_ENABLE_ALLTOALL_WITHOUT_ALLGATHER", - "1") == "1" - ) and self.alltoall_method_type == AlltoallMethodType.MNNVL and routing_method.experts_per_token % 4 == 0 + if self.alltoall_method_type == AlltoallMethodType.MNNVL: MnnvlMemory.initialize() self.alltoall_workspace = MnnvlMoe.get_moe_workspaces( model_config.mapping) - if self.enable_alltoall_without_allgather: - self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace( - 
model_config.mapping) + self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace( + model_config.mapping) elif self.alltoall_method_type == AlltoallMethodType.DeepEP: self.deep_ep_buffer = buffer_pool.get_buffer( model_config.mapping) @@ -301,6 +296,9 @@ def calculate_num_chunks(self, all_rank_num_tokens: List[int]) -> int: 1) // self.moe_max_num_tokens def can_use_alltoall(self, all_rank_num_tokens, all_rank_max_num_tokens): + if self.alltoall_method_type == AlltoallMethodType.MNNVL: + return True + # Disable alltoall when chunking is used if self.calculate_num_chunks(all_rank_num_tokens) > 1: return False @@ -458,24 +456,23 @@ def forward_chunk( else: tuner_num_tokens = None tuner_top_k = None + alltoall_info = None if use_all_to_all: if self.alltoall_method_type == AlltoallMethodType.MNNVL: if self.enable_dummy_allreduce: self.dummy_allreduce() token_count = x.shape[0] - alltoall_info = None - if is_last_call: + if is_last_call and self.layer_load_balancer is not None and not self.layer_load_balancer.is_static_routing( + ): loadbalancer_local_statistic_info = self.layer_load_balancer.get_local_statistic_tensor( ) else: loadbalancer_local_statistic_info = None - x, token_selected_slots, token_final_scales, gathered_loadbalancer_local_statistic_info, alltoall_info = \ - self.alltoall_prepare_maybe_dispatch(all_rank_max_num_tokens, - x, - token_selected_slots, - token_final_scales, - use_postquant_alltoall, - loadbalancer_local_statistic_info) + token_selected_slots, gathered_loadbalancer_local_statistic_info, alltoall_info = \ + self.alltoall_prepare(all_rank_max_num_tokens, + token_selected_slots, + loadbalancer_local_statistic_info) + if gathered_loadbalancer_local_statistic_info is not None: gathered_loadbalancer_local_statistic_info = gathered_loadbalancer_local_statistic_info.view( (self.mapping.moe_ep_size, self.num_experts)) @@ -580,10 +577,15 @@ def forward_chunk( cluster_rank = self.cluster_rank quant_scales = self.quant_scales + if self.alltoall_method_type == AlltoallMethodType.MNNVL: + top_k = self.routing_method.experts_per_token + x, x_sf, token_selected_slots, token_final_scales = self.alltoall_dispatch( + x, x_sf, token_selected_slots, token_final_scales, + all_rank_max_num_tokens, top_k, alltoall_info) + if use_postquant_alltoall: if self.alltoall_method_type == AlltoallMethodType.MNNVL: - x, x_sf = self.alltoall_postquant_dispatch( - x, x_sf, alltoall_info) + pass elif self.alltoall_method_type == AlltoallMethodType.DeepEP: if x_sf is not None: # Adapter between `x_sf` and DeepEP @@ -862,77 +864,34 @@ def split_chunk(split_token_num: int, split_num_chunks: int): self.repeat_idx = 0 if self.repeat_idx == self.repeat_count - 1 else self.repeat_idx + 1 return outputs - def alltoall_prepare_maybe_dispatch( - self, all_rank_max_num_tokens: int, x: torch.Tensor, - token_selected_slots: torch.Tensor, - token_final_scales: torch.Tensor, use_postquant_alltoall: bool, - local_statistic_tensor: Optional[torch.Tensor]): + def alltoall_prepare(self, all_rank_max_num_tokens: int, + token_selected_slots: torch.Tensor, + local_statistic_tensor: Optional[torch.Tensor]): top_k = self.routing_method.experts_per_token - if self.enable_alltoall_without_allgather: - alltoall_info, token_selected_slots, token_final_scales, gathered_local_statistic_tensor = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather( - token_selected_slots, token_final_scales, - local_statistic_tensor, self.alltoall_prepare_workspace, - all_rank_max_num_tokens, self.ep_rank, self.ep_size, - 
self.num_experts, self.num_slots, top_k) - else: - if all_rank_max_num_tokens > token_selected_slots.shape[0]: - token_selected_slots = torch.nn.functional.pad( - token_selected_slots, - (0, 0, 0, - all_rank_max_num_tokens - token_selected_slots.shape[0]), - 'constant', self.num_slots) - if token_final_scales is not None and all_rank_max_num_tokens > token_final_scales.shape[ - 0]: - token_final_scales = torch.nn.functional.pad( - token_final_scales, - (0, 0, 0, - all_rank_max_num_tokens - token_final_scales.shape[0])) - gathered_token_selected_slots, gathered_token_final_scales, gathered_local_statistic_tensor = allgather( - [ - token_selected_slots, token_final_scales, - local_statistic_tensor - ], - self.mapping, - dim=0) - gathered_token_selected_slots = torch.flatten( - gathered_token_selected_slots.contiguous(), - start_dim=0, - end_dim=-2) - if gathered_token_final_scales is not None: - gathered_token_final_scales = torch.flatten( - gathered_token_final_scales.contiguous(), - start_dim=0, - end_dim=-2) - gathered_target_rank_ids = MnnvlMoe.compute_target_rank_id( - gathered_token_selected_slots, self.num_slots, self.ep_size) - alltoall_info, token_selected_slots, token_final_scales = MnnvlMoe.mnnvl_moe_alltoallv_prepare( - gathered_target_rank_ids, None, gathered_token_selected_slots, - gathered_token_final_scales, all_rank_max_num_tokens, - self.num_slots, top_k, self.ep_rank, self.ep_size) - - if not use_postquant_alltoall: - assert not isinstance( - x, Fp4QuantizedTensor - ), "pre-quant alltoall doesn't support fp4 tensor" - x = MnnvlMoe.mnnvl_moe_alltoallv(x, alltoall_info, - self.alltoall_workspace, - self.ep_rank, self.ep_size) - - return x, token_selected_slots, token_final_scales, gathered_local_statistic_tensor, alltoall_info - - def alltoall_postquant_dispatch(self, x: torch.Tensor, x_sf: torch.Tensor, - alltoall_info: MoEAlltoallInfo): - x = MnnvlMoe.mnnvl_moe_alltoallv(x, alltoall_info, - self.alltoall_workspace, self.ep_rank, - self.ep_size) - - if x_sf is not None: - x_sf = MnnvlMoe.mnnvl_moe_alltoallv(x_sf, alltoall_info, - self.alltoall_workspace, - self.ep_rank, self.ep_size) - - return x, x_sf + alltoall_info, gathered_local_statistic_tensor = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather( + token_selected_slots, local_statistic_tensor, + self.alltoall_prepare_workspace, all_rank_max_num_tokens, + self.ep_rank, self.ep_size, self.num_experts, self.num_slots, top_k) + + return token_selected_slots, gathered_local_statistic_tensor, alltoall_info + + def alltoall_dispatch(self, x: torch.Tensor, x_sf: Optional[torch.Tensor], + token_selected_slots: torch.Tensor, + token_final_scales: Optional[torch.Tensor], + all_rank_max_num_tokens: int, top_k: int, + alltoall_info: MoEAlltoallInfo): + + x, x_sf, token_selected_slots, token_final_scales = MnnvlMoe.mnnvl_moe_alltoallv( + [x, x_sf, token_selected_slots, token_final_scales], alltoall_info, + self.alltoall_workspace, self.ep_rank, self.ep_size) + + torch.ops.trtllm.memset_expert_ids(token_selected_slots, + alltoall_info.recv_rank_count_cumsum, + all_rank_max_num_tokens, top_k, + self.num_slots, self.ep_size) + + return x, x_sf, token_selected_slots, token_final_scales def alltoall_combine(self, final_hidden_states: torch.Tensor, alltoall_info: MoEAlltoallInfo, token_count: int): diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index cc970b452f1..6696198a45e 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -269,8 
+269,6 @@ test_e2e.py::test_ptp_quickstart_multimodal[llava-v1.6-mistral-7b-llava-v1.6-mis accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=0-overlap_scheduler=True] SKIP (https://nvbugs/5433545) examples/test_nemotron_nas.py::test_nemotron_nas_summary_1gpu[DeciLM-7B] SKIP (https://nvbugs/5444636) accuracy/test_cli_flow.py::TestLongAlpaca7B::test_multiblock_aggressive SKIP (https://nvbugs/5444627) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2] SKIP (https://nvbugs/5444687) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True] SKIP (https://nvbugs/5444687) examples/test_qwen2audio.py::test_llm_qwen2audio_single_gpu[qwen2_audio_7b_instruct] SKIP (https://nvbugs/5447530) examples/test_nemotron_nas.py::test_nemotron_nas_summary_2gpu[DeciLM-7B] SKIP (https://nvbugs/5444636) examples/test_multimodal.py::test_llm_multimodal_general[Qwen2-VL-7B-Instruct-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:4] SKIP (https://nvbugs/5453709) diff --git a/tests/unittest/_torch/thop/test_moe_alltoall.py b/tests/unittest/_torch/thop/test_moe_alltoall.py index e795b68f9e6..e3758f05e9a 100644 --- a/tests/unittest/_torch/thop/test_moe_alltoall.py +++ b/tests/unittest/_torch/thop/test_moe_alltoall.py @@ -52,10 +52,6 @@ def test_moe_alltoall_single_gpu(self, input_entry_count, vector_dim, dtype=dtype, device=torch.device('cuda')) - output_tensor = torch.zeros(output_entry_count, - vector_dim, - dtype=dtype, - device=torch.device('cuda')) send_cumsum = torch.ones( (1, ), dtype=torch.int32, @@ -78,13 +74,18 @@ def test_moe_alltoall_single_gpu(self, input_entry_count, workspace_size = torch.ops.trtllm.get_moe_commworkspace_size_per_rank(1) all_workspaces = torch.zeros(1, - workspace_size, + workspace_size // 8, dtype=torch.uint64, device=torch.device('cuda')) + torch.ops.trtllm.moe_initialize_workspace(all_workspaces, 0, 1) + + output_tensors = torch.ops.trtllm.moe_comm([input_tensor], send_cumsum, + send_indices, recv_cumsum, + recv_indices, all_workspaces, + output_entry_count, 0, 1, + [True]) - torch.ops.trtllm.moe_comm(input_tensor, send_cumsum, send_indices, - output_tensor, recv_cumsum, recv_indices, - all_workspaces, 0, 1) + output_tensor = output_tensors[0] torch.testing.assert_close(output_tensor, ref_output_tensor, @@ -103,40 +104,43 @@ def do_warmup(self): send_indices = torch.zeros(1, dtype=torch.int32, device=torch.device('cuda')) - output_tensor = torch.zeros(1, - 8, - dtype=torch.float16, - device=torch.device('cuda')) recv_cumsum = torch.ones(1, dtype=torch.int32, device=torch.device('cuda')) recv_indices = torch.zeros(1, dtype=torch.int32, device=torch.device('cuda')) + input_tensors = [input_tensor] workspace_size = torch.ops.trtllm.get_moe_commworkspace_size_per_rank(1) all_workspaces = torch.zeros(1, - workspace_size, + workspace_size // 8, dtype=torch.uint64, device=torch.device('cuda')) - torch.ops.trtllm.moe_comm(input_tensor, send_cumsum, send_indices, - output_tensor, recv_cumsum, recv_indices, - all_workspaces, 0, 1) + _ = torch.ops.trtllm.moe_comm(input_tensors, send_cumsum, send_indices, + recv_cumsum, recv_indices, all_workspaces, + 1, 0, 1, [True]) torch.cuda.synchronize() @parameterized.expand([ - (2, 5, 8, torch.float16), # small input as smoke test - (2, 1, 8, torch.float16), # some ranks have no data to send/recv - (4, 5, 8, torch.float16), # small input with larger world size - (4, 901, 32768, torch.bfloat16), # large input that reuses workspace - (8, 901, 
32768, + (2, 5, [4, 4], torch.float16), # small input as smoke test + (2, 1, [8], torch.float16), # some ranks have no data to send/recv + (4, 5, [8], torch.float16), # small input with larger world size + (4, 901, [1472, 46, 4, + 4], torch.float16), # large input that reuses workspace + (4, 5, [2944], torch.bfloat16), # large input that reuses workspace + (8, 901, [ + 32768, + ], torch.float16), # large input that reuses workspace, larger world size ( - 8, 16384, 128, torch.float16 + 8, 16384, [ + 128, + ], torch.float16 ), # large input count with small vector dim that requires more indices per fifo ]) def test_moe_alltoall_multi_rank_single_gpu(self, world_size, input_entry_per_rank, - vector_dim, dtype): + vector_dims, dtype): torch.cuda.set_device(0) max_world_size = 8 assert world_size <= max_world_size, f"should run with world_size at most {max_world_size}" @@ -148,27 +152,32 @@ def test_moe_alltoall_multi_rank_single_gpu(self, world_size, torch.ops.trtllm.set_moe_max_usable_sm_count(max_sm_count) has_setup_max_sm_count = True - # Create a random input tensor - input_tensor = torch.randn(input_entry_per_rank * world_size, - vector_dim, - dtype=dtype, - device=torch.device('cuda')) - output_tensor = torch.zeros(input_entry_per_rank * world_size, - vector_dim, - dtype=dtype, - device=torch.device('cuda')) - ref_output_tensor = torch.zeros(input_entry_per_rank * world_size, - vector_dim, - dtype=dtype, - device=torch.device('cuda')) + tensor_count = len(vector_dims) + input_tensors = [] + ref_output_tensors = [] + for vector_dim in vector_dims: + input_tensors.append( + torch.randn(input_entry_per_rank * world_size, + vector_dim, + dtype=dtype, + device=torch.device('cuda'))) + ref_output_tensors.append( + torch.zeros(input_entry_per_rank * world_size, + vector_dim, + dtype=dtype, + device=torch.device('cuda'))) + target_rank_ids = torch.randint(0, world_size, (input_entry_per_rank * world_size, ), dtype=torch.int32, device=torch.device('cuda')) - input_tensors_all_ranks = list( - torch.split(input_tensor, input_entry_per_rank)) + input_tensors_all_ranks = [] + for i in range(tensor_count): + input_tensors_all_ranks.append( + list(torch.split(input_tensors[i], input_entry_per_rank))) + target_rank_ids_all_ranks = list( torch.split(target_rank_ids, input_entry_per_rank)) @@ -210,12 +219,9 @@ def test_moe_alltoall_multi_rank_single_gpu(self, world_size, recv_ids_all_ranks = [] recv_cumsum_all_ranks = [] - output_tensors_all_ranks = [] - total_recv_all_ranks_cpu = [] output_indice_offset = 0 - output_start_current_rank = 0 # each rank do compute based on other ranks' send counts to get how to receive data from other ranks. 
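A condensed, self-contained sketch (not part of the diff) of the single-rank call sequence exercised by test_moe_alltoall_single_gpu above: the value returned by get_moe_commworkspace_size_per_rank is in bytes, so the uint64 workspace tensor is sized in 8-byte words, it is initialized once per rank with moe_initialize_workspace, and moe_comm now takes a list of payload tensors and returns a list of received tensors. The trailing [True] list mirrors the test (one flag per payload tensor); its exact semantics are not spelled out in this diff.

    import torch

    world_size, rank = 1, 0
    x = torch.randn(5, 8, dtype=torch.float16, device='cuda')          # 5 entries to send
    send_cumsum = torch.tensor([5], dtype=torch.int32, device='cuda')  # cumulative send count per rank
    send_indices = torch.arange(5, dtype=torch.int32, device='cuda')
    recv_cumsum = torch.tensor([5], dtype=torch.int32, device='cuda')  # cumulative recv count per rank
    recv_indices = torch.arange(5, dtype=torch.int32, device='cuda')

    ws_bytes = torch.ops.trtllm.get_moe_commworkspace_size_per_rank(world_size)
    all_workspaces = torch.zeros(world_size, ws_bytes // 8,
                                 dtype=torch.uint64, device='cuda')
    torch.ops.trtllm.moe_initialize_workspace(all_workspaces, rank, world_size)

    outputs = torch.ops.trtllm.moe_comm([x], send_cumsum, send_indices,
                                        recv_cumsum, recv_indices, all_workspaces,
                                        5, rank, world_size, [True])   # 5 = output entry count
    received = outputs[0]   # single payload tensor, looped back to the same rank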
for rank in range(world_size): local_recv_counts = torch.zeros(world_size, @@ -227,18 +233,15 @@ def test_moe_alltoall_multi_rank_single_gpu(self, world_size, local_recv_count_pair = local_recv_counts[other_rank].cpu( ).item() send_rank_start_end = send_start_end_all_ranks[other_rank][rank] - ref_output_tensor[output_indice_offset:output_indice_offset + local_recv_count_pair] = \ - input_tensors_all_ranks[other_rank][send_ids_all_ranks[other_rank][send_rank_start_end[0]:send_rank_start_end[1]]] + for i in range(tensor_count): + ref_output_tensors[i][output_indice_offset:output_indice_offset + local_recv_count_pair] = \ + input_tensors_all_ranks[i][other_rank][send_ids_all_ranks[other_rank][send_rank_start_end[0]:send_rank_start_end[1]]] output_indice_offset += local_recv_count_pair local_recv_cumsum = torch.cumsum(local_recv_counts, dim=0).to(torch.int32) recv_cumsum_all_ranks.append(local_recv_cumsum) total_recv_count = local_recv_cumsum[-1].cpu() total_recv_all_ranks_cpu.append(total_recv_count) - output_tensors_all_ranks.append(output_tensor[ - output_start_current_rank:output_start_current_rank + - total_recv_count]) - output_start_current_rank += total_recv_count local_recv_ids = torch.arange(total_recv_count, dtype=torch.int32, device=torch.device('cuda')) @@ -251,9 +254,12 @@ def test_moe_alltoall_multi_rank_single_gpu(self, world_size, workspace_size = torch.ops.trtllm.get_moe_commworkspace_size_per_rank( world_size) all_workspaces = torch.zeros(world_size, - workspace_size, + workspace_size // 8, dtype=torch.uint64, device=torch.device('cuda')) + for i in range(world_size): + torch.ops.trtllm.moe_initialize_workspace(all_workspaces, i, + world_size) # do one warmup for each rank to avoid possible synchronization at first launch. for rank in range(world_size): @@ -262,212 +268,141 @@ def test_moe_alltoall_multi_rank_single_gpu(self, world_size, torch.cuda.synchronize() + # Store output tensors from each rank + output_tensors_all_ranks = [] + # do alltoall in parallel for rank in range(world_size): + input_tensors_this_rank = [ + input_tensors_all_ranks[i][rank] for i in range(tensor_count) + ] with torch.cuda.stream(cuda_streams_all_ranks[rank]): - torch.ops.trtllm.moe_comm( - input_tensors_all_ranks[rank], send_cumsum_all_ranks[rank], - send_ids_all_ranks[rank], output_tensors_all_ranks[rank], - recv_cumsum_all_ranks[rank], recv_ids_all_ranks[rank], - all_workspaces, rank, world_size) + output_tensors_this_rank = torch.ops.trtllm.moe_comm( + input_tensors_this_rank, send_cumsum_all_ranks[rank], + send_ids_all_ranks[rank], recv_cumsum_all_ranks[rank], + recv_ids_all_ranks[rank], all_workspaces, + input_entry_per_rank * world_size, rank, world_size) + output_tensors_all_ranks.append(output_tensors_this_rank) + for rank in range(world_size): cuda_streams_all_ranks[rank].synchronize() - torch.testing.assert_close(output_tensor, - ref_output_tensor, - atol=1e-5, - rtol=1e-5) + # Reconstruct the full output tensors by concatenating results from all ranks + for i in range(tensor_count): + # Collect the actual received data from each rank (trim to actual recv count) + actual_output_parts = [] + for rank in range(world_size): + total_recv_count = total_recv_all_ranks_cpu[rank].item() + # Each rank returns tensor with size [input_entry_per_rank * world_size, vector_dim] + # but only the first total_recv_count entries are valid + actual_output_parts.append( + output_tensors_all_ranks[rank][i][:total_recv_count]) - @parameterized.expand([ - (0, 8, 256, 4, 3, False), - (0, 8, 256, 4, 3, True), 
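The pattern used above to emulate an expert-parallel group on one GPU, summarized as a sketch (not part of the diff): each simulated rank gets its own CUDA stream, shares the all_workspaces tensor (one row per rank), and launches moe_comm concurrently; each returned tensor is allocated for the worst case, so only the first recv_cumsum[-1] rows per rank are valid and are trimmed before concatenation against the reference. `build_rank_args` below is a hypothetical helper standing in for the per-rank cumsum/index construction the test performs explicitly.

    streams = [torch.cuda.Stream() for _ in range(world_size)]
    outputs_per_rank = []
    for rank in range(world_size):
        with torch.cuda.stream(streams[rank]):
            a = build_rank_args(rank)   # hypothetical: payload list, send/recv cumsums and indices
            outputs_per_rank.append(
                torch.ops.trtllm.moe_comm(a.inputs, a.send_cumsum, a.send_indices,
                                          a.recv_cumsum, a.recv_indices, all_workspaces,
                                          a.max_output_entries, rank, world_size))
    for s in streams:
        s.synchronize()
    # valid rows for rank r, payload i: outputs_per_rank[r][i][:recv_cumsum_of_rank_r[-1]]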
- (1, 8, 256, 4, 3, False), - (1, 8, 256, 4, 3, True), - (1, 4, 256, 8, 3, False), - (1, 4, 256, 8, 3, True), - (7, 8, 256, 8, 1025, False), - (7, 8, 256, 8, 1025, True), - (7, 64, 1024, 32, 1029, False), - (7, 64, 1024, 32, 1029, True), - ]) - def test_moe_alltoall_prepare_indices( - self, ep_rank: int, ep_size: int, expert_count: int, top_k: int, - max_token_count_per_rank: int, - use_real_rank_token_count_cumsum: bool): - torch.cuda.set_device(0) - gathered_target_rank_ids = torch.randint( - 0, - ep_size, (ep_size * max_token_count_per_rank, top_k), - dtype=torch.int32, - device=torch.device('cuda')) - real_rank_token_count_cumsum = None - if use_real_rank_token_count_cumsum: - real_rank_token_count_cumsum = torch.randint( - 0, - max_token_count_per_rank + 1, (ep_size, ), - dtype=torch.int32, - device=torch.device('cuda')) - real_rank_token_count_cumsum = torch.cumsum( - real_rank_token_count_cumsum, dim=0).to(torch.int32) + # Concatenate all ranks' outputs to form the complete result + actual_output = torch.cat(actual_output_parts, dim=0) + torch.testing.assert_close(actual_output, + ref_output_tensors[i], + atol=1e-5, + rtol=1e-5) - def generate_references(): - gathered_target_rank_ids_cpu_lists = gathered_target_rank_ids.cpu( - ).tolist() - if use_real_rank_token_count_cumsum: - real_rank_token_count_cumsum_cpu_lists = real_rank_token_count_cumsum.cpu( - ).tolist() - else: - real_rank_token_count_cumsum_cpu_lists = [ - (i + 1) * max_token_count_per_rank for i in range(ep_size) - ] - rank_token_start = 0 - ref_local_gather_indices_cpu_lists = [] - ref_recv_rank_count_cumsum_cpu_lists = [0] * ep_size - ref_recv_rank_local_indices_cpu_lists = [] - ref_send_rank_count_cumsum_cpu_lists = [0] * ep_size - ref_send_rank_local_indices_cpu_lists = [] - ref_backward_recv_rank_local_indices_cpu_lists = [] - total_recv_count = 0 - for rank in range(ep_size): - rank_token_end = real_rank_token_count_cumsum_cpu_lists[rank] - for token_id in range(rank_token_start, rank_token_end): - if ep_rank in gathered_target_rank_ids_cpu_lists[token_id]: - ref_local_gather_indices_cpu_lists.append(token_id) - ref_recv_rank_local_indices_cpu_lists.append( - total_recv_count) - total_recv_count += 1 - ref_recv_rank_count_cumsum_cpu_lists[rank] = total_recv_count - if rank == ep_rank: - total_send_count = 0 - for target_rank in range(ep_size): - for token_id in range(rank_token_start, rank_token_end): - local_token_id = token_id - rank_token_start - if target_rank in gathered_target_rank_ids_cpu_lists[ - token_id]: - pos = gathered_target_rank_ids_cpu_lists[ - token_id].index(target_rank) - ref_send_rank_local_indices_cpu_lists.append( - local_token_id) - ref_backward_recv_rank_local_indices_cpu_lists.append( - local_token_id * top_k + pos) - total_send_count += 1 - ref_send_rank_count_cumsum_cpu_lists[ - target_rank] = total_send_count - rank_token_start = rank_token_end - ref_local_gather_indices = torch.IntTensor( - ref_local_gather_indices_cpu_lists).cuda() - ref_send_rank_count_cumsum = torch.IntTensor( - ref_send_rank_count_cumsum_cpu_lists).cuda() - ref_send_rank_local_indices = torch.IntTensor( - ref_send_rank_local_indices_cpu_lists).cuda() - ref_recv_rank_count_cumsum = torch.IntTensor( - ref_recv_rank_count_cumsum_cpu_lists).cuda() - ref_recv_rank_local_indices = torch.IntTensor( - ref_recv_rank_local_indices_cpu_lists).cuda() - ref_backward_recv_rank_local_indices = torch.IntTensor( - ref_backward_recv_rank_local_indices_cpu_lists).cuda() - return ref_local_gather_indices, ref_send_rank_count_cumsum, 
ref_send_rank_local_indices, ref_recv_rank_count_cumsum, ref_recv_rank_local_indices, ref_backward_recv_rank_local_indices - - ref_local_gather_indices, ref_send_rank_count_cumsum, ref_send_rank_local_indices, ref_recv_rank_count_cumsum, ref_recv_rank_local_indices, ref_backward_recv_rank_local_indices = generate_references( - ) - local_gather_indices, send_rank_count_cumsum, send_rank_local_indices, recv_rank_count_cumsum, recv_rank_local_indices, backward_recv_rank_local_indices = \ - torch.ops.trtllm.moe_comm_prepare_indices(gathered_target_rank_ids, real_rank_token_count_cumsum, max_token_count_per_rank, expert_count, top_k, ep_rank, ep_size) +class TestMoeAlltoAllFP8SingleGPU(unittest.TestCase): - assert torch.equal( - local_gather_indices[:torch.numel(ref_local_gather_indices)], - ref_local_gather_indices) - assert torch.equal( - send_rank_count_cumsum[:torch.numel(ref_send_rank_count_cumsum)], - ref_send_rank_count_cumsum) - assert torch.equal( - send_rank_local_indices[:torch.numel(ref_send_rank_local_indices)], - ref_send_rank_local_indices) - assert torch.equal( - recv_rank_count_cumsum[:torch.numel(ref_recv_rank_count_cumsum)], - ref_recv_rank_count_cumsum) - assert torch.equal( - recv_rank_local_indices[:torch.numel(ref_recv_rank_local_indices)], - ref_recv_rank_local_indices) - assert torch.equal( - backward_recv_rank_local_indices[:torch.numel( - ref_backward_recv_rank_local_indices)], - ref_backward_recv_rank_local_indices) + def setUp(self): + torch.manual_seed(0x1234) + tllm.logger.set_level('error') - @parameterized.expand([ - (0, 8, 256, 4, 3), - (1, 8, 256, 4, 3), - (7, 8, 256, 4, 3), - (7, 8, 256, 8, 32), - (7, 8, 256, 32, 10), - (7, 8, 1024, 32, 127), - (7, 64, 1024, 32, 1029), - (9, 64, 1024, 3, 1029), - ]) - def test_moe_local_gather(self, ep_rank: int, ep_size: int, - expert_count: int, top_k: int, - max_token_count_per_rank: int): + def test_moe_alltoall_fp8_with_indices(self): + """Test fp8 alltoall with properly constructed indices""" torch.cuda.set_device(0) - rank_token_count_cumsum = torch.randint(0, - max_token_count_per_rank + 1, - (ep_size, ), - dtype=torch.int32, - device=torch.device('cuda')) - rank_token_count_cumsum = torch.cumsum(rank_token_count_cumsum, - dim=0).to(torch.int32) - local_token_count = rank_token_count_cumsum[ep_size - 1].cpu().item() - local_max_token_count = max_token_count_per_rank * ep_size - local_gather_indices = torch.randint(0, - max_token_count_per_rank * ep_size, - (local_max_token_count, ), - dtype=torch.int32, - device=torch.device('cuda')) - - gathered_expert_ids = torch.randint( - 0, - expert_count, (max_token_count_per_rank * ep_size, top_k), - dtype=torch.int32, - device=torch.device('cuda')) - gathered_scales = torch.rand( - (max_token_count_per_rank * ep_size, top_k), - dtype=torch.float32, - device=torch.device('cuda')) - - ref_local_expert_ids = torch.zeros(local_max_token_count, - top_k, - dtype=torch.int32, - device=torch.device('cuda')) - ref_local_scales = torch.zeros(local_max_token_count, - top_k, - dtype=torch.float32, - device=torch.device('cuda')) - - # compute reference - ref_local_expert_ids += expert_count - valid_local_gather_indices = local_gather_indices[:local_token_count] - ref_local_expert_ids[:local_token_count] = gathered_expert_ids[ - valid_local_gather_indices] - ref_local_scales[:local_token_count] = gathered_scales[ - valid_local_gather_indices] - - local_expert_ids = torch.empty(local_max_token_count, - top_k, - dtype=torch.int32, - device=torch.device('cuda')) - local_scales = 
torch.empty(local_max_token_count, - top_k, - dtype=torch.float32, - device=torch.device('cuda')) - torch.ops.trtllm.moe_local_gather(rank_token_count_cumsum, - local_gather_indices, - gathered_expert_ids, gathered_scales, - local_expert_ids, local_scales, - max_token_count_per_rank, - expert_count, top_k, ep_rank, ep_size) + # Match dimensions from the error + input_entry_count = 16384 + output_entry_count = 16384 + vector_dim = 2944 + sf_vector_dim = 92 # Scaling factor dimension from error + send_recv_count = 1000 # Number of entries to send/receive - assert torch.equal(local_expert_ids, ref_local_expert_ids) - assert torch.equal(local_scales, ref_local_scales) + # Create input tensors - first as float16, then convert + input_tensor_fp16 = torch.randn(input_entry_count, + vector_dim, + dtype=torch.float16, + device='cuda') + input_tensor_fp8 = input_tensor_fp16.to(torch.float8_e4m3fn) + + # Scaling factor tensor + input_sf_tensor = torch.randint(1, + 255, (input_entry_count, sf_vector_dim), + dtype=torch.uint8, + device='cuda') + + # Expert selection tensors + input_experts = torch.randint(0, + 64, (input_entry_count, 4), + dtype=torch.int32, + device='cuda') + input_scales = torch.rand(input_entry_count, + 4, + dtype=torch.float32, + device='cuda') + + # Construct send/recv indices + send_cumsum = torch.tensor([send_recv_count], + dtype=torch.int32, + device='cuda') + recv_cumsum = torch.tensor([send_recv_count], + dtype=torch.int32, + device='cuda') + + # Random indices for sending + send_indices = torch.randperm(input_entry_count, + dtype=torch.int32, + device='cuda')[:send_recv_count] + recv_indices = torch.randperm(output_entry_count, + dtype=torch.int32, + device='cuda')[:send_recv_count] + + # Create workspace + workspace_size = torch.ops.trtllm.get_moe_commworkspace_size_per_rank(1) + all_workspaces = torch.zeros(1, + workspace_size // 8, + dtype=torch.uint64, + device='cuda') + torch.ops.trtllm.moe_initialize_workspace(all_workspaces, 0, 1) + + print(f"Test configuration:") + print(f" Input entries: {input_entry_count}") + print(f" Vector dim: {vector_dim}") + print(f" SF vector dim: {sf_vector_dim}") + print(f" Send/recv count: {send_recv_count}") + print(f" FP8 tensor shape: {input_tensor_fp8.shape}") + print(f" SF tensor shape: {input_sf_tensor.shape}") + + try: + # Test with all 4 tensors + output_tensor_fp8, output_sf_tensor, output_experts, output_scales = \ + torch.ops.trtllm.moe_comm([ + input_tensor_fp8, input_sf_tensor, input_experts, input_scales + ], send_cumsum, send_indices, recv_cumsum, recv_indices, all_workspaces, output_entry_count, 0, 1) + + torch.cuda.synchronize() + print("FP8 alltoall test PASSED!") + + # Verify outputs + print(f"\nOutput verification:") + print(f" Output FP8 shape: {output_tensor_fp8.shape}") + print(f" Output SF shape: {output_sf_tensor.shape}") + print( + f" Non-zero FP8 elements: {(output_tensor_fp8 != 0).sum().item()}" + ) + print( + f" Non-zero SF elements: {(output_sf_tensor != 0).sum().item()}" + ) + + except Exception as e: + print(f"FP8 alltoall test FAILED: {e}") + print(f"Error type: {type(e)}") + raise @parameterized.expand([ (0, 2, 16, 20, 8, 512), @@ -489,7 +424,6 @@ def test_moe_alltoall_prepare(self, ep_rank: int, ep_size: int, cpu_expert_ids_all_ranks_lists = [] cpu_token_count_lists = [] - cpu_scales_all_ranks_lists = [] for _ in range(ep_size): token_count = torch.randint(max_token_count_per_rank // 2, max_token_count_per_rank + 1, (1, ), @@ -505,12 +439,6 @@ def test_moe_alltoall_prepare(self, ep_rank: int, ep_size: 
int, dtype=torch.int32, device=torch.device('cpu'))) - cpu_scales_all_ranks_lists.append( - torch.zeros(token_count, - top_k, - dtype=torch.float32, - device=torch.device('cpu')) + 0.5) - cpu_token_count_lists.append(token_count) def compute_target_rank(expert_id): @@ -519,7 +447,6 @@ def compute_target_rank(expert_id): def generate_references(): ref_prepared_local_expert_ids = [] - ref_prepared_local_scales = [] ref_local_send_rank_count_cumsum = [0] * ep_size ref_local_recv_rank_count_cumsum = [0] * ep_size ref_local_recv_rank_indices = [] @@ -580,16 +507,13 @@ def generate_references(): for pos in range(top_k): expert_id = int( cpu_expert_ids_all_ranks_lists[rank][token_id][pos]) - sf = cpu_scales_all_ranks_lists[rank][token_id][pos] target_rank_id = compute_target_rank(expert_id) if target_rank_id == ep_rank: if not token_is_received: token_is_received = True ref_prepared_local_expert_ids.append( [slot_count] * top_k) - ref_prepared_local_scales.append([0.0] * top_k) ref_prepared_local_expert_ids[-1][pos] = expert_id - ref_prepared_local_scales[-1][pos] = sf if token_is_received: ref_local_recv_rank_indices.append( total_recv_token_count) @@ -599,9 +523,9 @@ def generate_references(): rank] = current_recv_token_count if rank == 0 else ref_local_recv_rank_count_cumsum[ rank - 1] + current_recv_token_count - return ref_prepared_local_expert_ids, ref_prepared_local_scales, ref_local_send_rank_count_cumsum, ref_local_send_rank_indices, ref_local_recv_rank_count_cumsum, ref_local_recv_rank_indices, ref_local_backward_send_rank_indices, total_recv_token_count + return ref_prepared_local_expert_ids, ref_local_send_rank_count_cumsum, ref_local_send_rank_indices, ref_local_recv_rank_count_cumsum, ref_local_recv_rank_indices, ref_local_backward_send_rank_indices, total_recv_token_count - ref_prepared_local_expert_ids, ref_prepared_local_scales, ref_local_send_rank_count_cumsum, ref_local_send_rank_indices, ref_local_recv_rank_count_cumsum, ref_local_recv_rank_indices, ref_local_backward_send_rank_indices, total_recv_token_count = generate_references( + ref_prepared_local_expert_ids, ref_local_send_rank_count_cumsum, ref_local_send_rank_indices, ref_local_recv_rank_count_cumsum, ref_local_recv_rank_indices, ref_local_backward_send_rank_indices, total_recv_token_count = generate_references( ) cpu_experter_count_lists = [] @@ -615,10 +539,6 @@ def generate_references(): expert_ids_all_ranks = [ cpu_expert_ids_all_ranks_lists[i].cuda() for i in range(ep_size) ] - #scales_all_ranks = torch.FloatTensor(cpu_scales_all_ranks_lists).cuda() - scales_all_ranks = [ - cpu_scales_all_ranks_lists[i].cuda() for i in range(ep_size) - ] experter_count_lists = [ cpu_experter_count_lists[i].cuda() for i in range(ep_size) @@ -637,30 +557,18 @@ def generate_references(): stream = torch.cuda.Stream() with torch.cuda.stream(stream): torch.ops.trtllm.mnnvl_moe_alltoallv_prepare_without_allgather( - expert_ids_all_ranks[0], scales_all_ranks[0], - experter_count_lists[0], all_workspaces, - max_token_count_per_rank, 0, 1, expert_count, slot_count, top_k) + expert_ids_all_ranks[0], experter_count_lists[0], + all_workspaces, max_token_count_per_rank, 0, 1, expert_count, + slot_count, top_k) stream.wait_stream(torch.cuda.current_stream()) # Make torch alloc tensor to avoid cuda sync - prepared_local_experts = [] - prepared_local_scales = [] local_send_rank_count_cumsum = [] local_send_rank_indices = [] local_recv_rank_count_cumsum = [] local_recv_rank_indices = [] backward_local_recv_rank_indices = [] for _ in range(ep_size): 
- prepared_local_experts.append( - torch.empty(max_token_count_per_rank * ep_size, - top_k, - dtype=torch.int32, - device=torch.device('cuda'))) - prepared_local_scales.append( - torch.empty(max_token_count_per_rank * ep_size, - top_k, - dtype=torch.float32, - device=torch.device('cuda'))) local_send_rank_count_cumsum.append( torch.empty(ep_size, dtype=torch.int32, @@ -676,8 +584,6 @@ def generate_references(): backward_local_recv_rank_indices.append( torch.empty(0, dtype=torch.int32, device=torch.device('cuda'))) - prepared_local_experts = [] - prepared_local_scales = [] local_send_rank_count_cumsum = [] local_send_rank_indices = [] local_recv_rank_count_cumsum = [] @@ -694,35 +600,19 @@ def generate_references(): for rank in range(ep_size): with torch.cuda.stream(cuda_streams_all_ranks[rank]): if rank == ep_rank: - prepared_local_experts, prepared_local_scales, local_send_rank_count_cumsum, \ + local_send_rank_count_cumsum, \ local_send_rank_indices, local_recv_rank_count_cumsum, local_recv_rank_indices, \ backward_local_recv_rank_indices, gathered_expert_statics\ - = torch.ops.trtllm.mnnvl_moe_alltoallv_prepare_without_allgather(expert_ids_all_ranks[rank], scales_all_ranks[rank], experter_count_lists[rank], all_workspaces, max_token_count_per_rank, + = torch.ops.trtllm.mnnvl_moe_alltoallv_prepare_without_allgather(expert_ids_all_ranks[rank], experter_count_lists[rank], all_workspaces, max_token_count_per_rank, rank, ep_size, expert_count, slot_count, top_k) else: torch.ops.trtllm.mnnvl_moe_alltoallv_prepare_without_allgather( - expert_ids_all_ranks[rank], scales_all_ranks[rank], - experter_count_lists[rank], all_workspaces, - max_token_count_per_rank, rank, ep_size, expert_count, - slot_count, top_k) + expert_ids_all_ranks[rank], experter_count_lists[rank], + all_workspaces, max_token_count_per_rank, rank, ep_size, + expert_count, slot_count, top_k) for rank in range(ep_size): cuda_streams_all_ranks[rank].synchronize() - prepared_local_experts_cpu = prepared_local_experts[: - total_recv_token_count].cpu( - ) - prepared_local_scales_cpu = prepared_local_scales[: - total_recv_token_count].cpu( - ) - for i in range(total_recv_token_count): - for j in range(top_k): - expert_id = int(prepared_local_experts_cpu[i][j]) - assert 0 <= expert_id and expert_id <= slot_count - if expert_id < slot_count: - assert compute_target_rank(expert_id) == ep_rank - scale = float(prepared_local_scales_cpu[i][j]) - assert scale > 1e-6 - gathered_expert_statics_cpu = gathered_expert_statics.cpu() for rank in range(ep_size): for i in range(expert_count):
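A sketch (not part of the diff) of the reduced mnnvl_moe_alltoallv_prepare_without_allgather interface exercised above: per-token scales are no longer passed through the op, which now takes only the per-token expert/slot ids and the local expert-count statistics and returns the send/recv bookkeeping tensors plus the gathered expert statistics. The argument order and the six return values follow the test; shapes not visible in the test (for example, the statistics tensor) are assumptions.

    # names follow the surrounding test; token_count, slot_count, etc. come from its setup
    expert_ids = torch.randint(0, slot_count, (token_count, top_k),
                               dtype=torch.int32, device='cuda')
    expert_counts = torch.zeros(expert_count, dtype=torch.int32, device='cuda')  # assumed shape

    (send_rank_count_cumsum, send_rank_indices,
     recv_rank_count_cumsum, recv_rank_indices,
     backward_recv_rank_indices, gathered_expert_statics) = \
        torch.ops.trtllm.mnnvl_moe_alltoallv_prepare_without_allgather(
            expert_ids, expert_counts, all_workspaces, max_token_count_per_rank,
            ep_rank, ep_size, expert_count, slot_count, top_k)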