@@ -319,19 +319,19 @@ __global__ void computeCumsumDevice(int* sendCountsCumsum, int* recvCountsCumsum
     }
 }
 
-template <typename STEP_COMMUNICATOR_TYPE>
+template <typename PipelineConfig>
 class PacketPipeline
 {
 public:
     __device__ __inline__ PacketPipeline(
-        void* bufferBase, STEP_COMMUNICATOR_TYPE* stepCommunicator, int* sharedNewStepPtr, bool isSender)
+        void* bufferBase, StepCommunicatorBase* stepCommunicator, int* sharedNewStepPtr, bool isSender)
         : bufferBase(bufferBase)
         , stepCommunicator(stepCommunicator)
         , shared_new_step(sharedNewStepPtr)
     {
         step = 0;
         needRelease = false;
-        packetId = isSender ? 0 : PACKET_PER_STEP - 1;
+        packetId = isSender ? 0 : PipelineConfig::PACKET_PER_STEP - 1;
     }
 
     __device__ __forceinline__ void* getFirstSendPacket()
@@ -343,9 +343,10 @@ public:
     {
 
         packetId++;
-        if (packetId < PACKET_PER_STEP)
+        if (packetId < PipelineConfig::PACKET_PER_STEP)
         {
-            return acquireNewStep ? bufferBase + step * PACKET_PER_STEP * PACKET_SIZE + packetId * PACKET_SIZE
+            return acquireNewStep ? bufferBase + step * PipelineConfig::PACKET_PER_STEP * PipelineConfig::PACKET_SIZE
+                    + packetId * PipelineConfig::PACKET_SIZE
                 : nullptr;
         }
 
@@ -365,7 +366,7 @@ public:
         {
             step = *(shared_new_step);
             packetId = 0;
-            return bufferBase + step * PACKET_SIZE * PACKET_PER_STEP;
+            return bufferBase + step * PipelineConfig::PACKET_SIZE * PipelineConfig::PACKET_PER_STEP;
         }
 
         return nullptr;
@@ -382,9 +383,10 @@ public:
     __device__ __inline__ void* getNewRecvPacket()
     {
         packetId++;
-        if (packetId < PACKET_PER_STEP)
+        if (packetId < PipelineConfig::PACKET_PER_STEP)
         {
-            return bufferBase + step * PACKET_PER_STEP * PACKET_SIZE + packetId * PACKET_SIZE;
+            return bufferBase + step * PipelineConfig::PACKET_PER_STEP * PipelineConfig::PACKET_SIZE
+                + packetId * PipelineConfig::PACKET_SIZE;
         }
 
         __syncthreads();
@@ -401,7 +403,7 @@ public:
         __syncthreads();
         packetId = 0;
         step = *(shared_new_step);
-        void* packetPtr = bufferBase + step * PACKET_SIZE * PACKET_PER_STEP;
+        void* packetPtr = bufferBase + step * PipelineConfig::PACKET_SIZE * PipelineConfig::PACKET_PER_STEP;
 
         return packetPtr;
     }
@@ -415,14 +417,14 @@ public:
     }
 
     void* bufferBase;
-    STEP_COMMUNICATOR_TYPE* stepCommunicator;
+    StepCommunicatorBase* stepCommunicator;
     int step;
     int packetId;
     bool needRelease;
     int* shared_new_step;
 };
 
-template <typename STEP_COMMUNICATOR_TYPE>
+template <typename PipelineConfig, typename ExpertType, typename ScaleType>
 __global__ void allToAllMetadataDevice(int* sendExperts, int* recvExperts, float* sendScales, float* recvScales,
     int* localExpertStatics, int* gatheredExpertStatics, MoeCommWorkspace workspace, int* sendCountsCumsum,
     int* localSendIndice, int* recvCountsCumsum, int* localRecvIndice, int tokenCount, int maxTokenCountPerRank,
@@ -431,22 +433,21 @@ __global__ void allToAllMetadataDevice(int* sendExperts, int* recvExperts, float
     bool isSender = (blockIdx.y == 0);
     int targetRankId = blockIdx.x;
     int slotCountPerRank = slotCount / rankCount;
-    int groupSize = topK / UNIT_SIZE;
-    int groupId = threadIdx.x % groupSize;
+    int groupSize = topK / PipelineConfig::UNIT_SIZE;
 
     __shared__ int sharedNewStep;
-    __align__(16) int experts[UNIT_SIZE];
-    __align__(16) float scales[UNIT_SIZE];
+    __align__(16) int experts[PipelineConfig::UNIT_SIZE];
+    __align__(16) float scales[PipelineConfig::UNIT_SIZE];
 
     uint8_t* bufferBase = (uint8_t*) (workspace.getFifoBasePtr(isSender, rankId, targetRankId, 0, 1));
-    STEP_COMMUNICATOR_TYPE stepCommunicator(workspace.getFifoConnInfo(isSender, rankId, targetRankId, 0, rankCount, 1));
-    PacketPipeline<STEP_COMMUNICATOR_TYPE> pipeline(bufferBase, &stepCommunicator, &sharedNewStep, isSender);
+    StepCommunicatorBase stepCommunicator(workspace.getFifoConnInfo(isSender, rankId, targetRankId, 0, rankCount, 1));
+    PacketPipeline<PipelineConfig> pipeline(bufferBase, &stepCommunicator, &sharedNewStep, isSender);
 
     if (isSender)
     {
         int baseCumsum = targetRankId == 0 ? 0 : *(sendCountsCumsum + targetRankId - 1);
         int sendTokenCount = *(sendCountsCumsum + targetRankId) - baseCumsum;
-        int unitCount = sendTokenCount * topK / UNIT_SIZE;
+        int unitCount = sendTokenCount * topK / PipelineConfig::UNIT_SIZE;
 
         void* packPtr = pipeline.getFirstSendPacket();
         int indexBase = 0;
@@ -457,13 +458,15 @@ __global__ void allToAllMetadataDevice(int* sendExperts, int* recvExperts, float
             if (threadIdx.x < UNIT_PER_ITER)
             {
                 int index = indexBase + threadIdx.x;
+                int groupId = index % groupSize;
                 if (index < unitCount)
                 {
                     int tokenId = *(localSendIndice + maxTokenCountPerRank * targetRankId + (index / groupSize));
-                    *((int4*) (experts)) = *(int4*) (sendExperts + tokenId * topK + groupId * UNIT_SIZE);
+                    *((ExpertType*) (experts))
+                        = *(ExpertType*) (sendExperts + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE);
 
 #pragma unroll
-                    for (int j = 0; j < UNIT_SIZE; j++)
+                    for (int j = 0; j < PipelineConfig::UNIT_SIZE; j++)
                     {
                         int expertId = experts[j];
                         if (expertId / slotCountPerRank != targetRankId)
@@ -472,14 +475,15 @@ __global__ void allToAllMetadataDevice(int* sendExperts, int* recvExperts, float
                         }
                     }
 
-                    int* expertsPtr = (int*) (packPtr) + threadIdx.x * UNIT_SIZE;
-                    *((int4*) (expertsPtr)) = *((int4*) (experts));
+                    int* expertsPtr = (int*) (packPtr) + threadIdx.x * PipelineConfig::UNIT_SIZE;
+                    *((ExpertType*) (expertsPtr)) = *((ExpertType*) (experts));
                     if (sendScales != nullptr)
                     {
-                        *((float4*) (scales)) = *(float4*) (sendScales + tokenId * topK + groupId * UNIT_SIZE);
-                        float* scaleBasePtr = (float*) (packPtr + SCALE_OFFSET);
-                        float* scalesPtr = (float*) (scaleBasePtr) + threadIdx.x * UNIT_SIZE;
-                        *((float4*) (scalesPtr)) = *((float4*) (scales));
+                        *((ScaleType*) (scales))
+                            = *(ScaleType*) (sendScales + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE);
+                        float* scaleBasePtr = (float*) (packPtr + PipelineConfig::SCALE_OFFSET);
+                        float* scalesPtr = (float*) (scaleBasePtr) + threadIdx.x * PipelineConfig::UNIT_SIZE;
+                        *((ScaleType*) (scalesPtr)) = *((ScaleType*) (scales));
                     }
                 }
             }
@@ -488,7 +492,7 @@ __global__ void allToAllMetadataDevice(int* sendExperts, int* recvExperts, float
                 int staticCopyIdx = threadIdx.x - UNIT_PER_ITER;
                 if (staticCopyBase + staticCopyIdx * 4 < expertCount)
                 {
-                    int4* staticBasePtr = (int4*) (packPtr + STATIC_COPY_OFFSET);
+                    int4* staticBasePtr = (int4*) (packPtr + PipelineConfig::STATIC_COPY_OFFSET);
                     int4 staticData = *(int4*) (localExpertStatics + staticCopyBase + staticCopyIdx * 4);
                     *(staticBasePtr + staticCopyIdx) = staticData;
                 }
@@ -521,18 +525,21 @@ __global__ void allToAllMetadataDevice(int* sendExperts, int* recvExperts, float
             if (threadIdx.x < packetUnitCount)
             {
                 int tokenId = baseCumsum + (unitIdBase + threadIdx.x) / groupSize;
-                int* expertsPtr = (int*) (packetPtr) + threadIdx.x * UNIT_SIZE;
-                *((int4*) (experts)) = *((int4*) (expertsPtr));
-                int4* dstExpertsPtr = (int4*) (recvExperts + tokenId * topK + groupId * UNIT_SIZE);
-                *dstExpertsPtr = *((int4*) (experts));
+                int groupId = (unitIdBase + threadIdx.x) % groupSize;
+                int* expertsPtr = (int*) (packetPtr) + threadIdx.x * PipelineConfig::UNIT_SIZE;
+                *((ExpertType*) (experts)) = *((ExpertType*) (expertsPtr));
+                ExpertType* dstExpertsPtr
+                    = (ExpertType*) (recvExperts + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE);
+                *dstExpertsPtr = *((ExpertType*) (experts));
 
                 if (recvScales != nullptr)
                 {
-                    float* scaleBasePtr = (float*) (packetPtr + SCALE_OFFSET);
-                    float* scalesPtr = scaleBasePtr + threadIdx.x * UNIT_SIZE;
-                    *((float4*) (scales)) = *((float4*) (scalesPtr));
-                    float4* dstScalesPtr = (float4*) (recvScales + tokenId * topK + groupId * UNIT_SIZE);
-                    *dstScalesPtr = *((float4*) (scales));
+                    float* scaleBasePtr = (float*) (packetPtr + PipelineConfig::SCALE_OFFSET);
+                    float* scalesPtr = scaleBasePtr + threadIdx.x * PipelineConfig::UNIT_SIZE;
+                    *((ScaleType*) (scales)) = *((ScaleType*) (scalesPtr));
+                    ScaleType* dstScalesPtr
+                        = (ScaleType*) (recvScales + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE);
+                    *dstScalesPtr = *((ScaleType*) (scales));
                 }
             }
         }
@@ -541,7 +548,7 @@ __global__ void allToAllMetadataDevice(int* sendExperts, int* recvExperts, float
                 int staticCopyIdx = threadIdx.x - UNIT_PER_ITER;
                 if (staticCopyBase + staticCopyIdx * 4 < expertCount)
                 {
-                    int4* staticBasePtr = (int4*) (packetPtr + STATIC_COPY_OFFSET);
+                    int4* staticBasePtr = (int4*) (packetPtr + PipelineConfig::STATIC_COPY_OFFSET);
                     int4 staticData = *(staticBasePtr + staticCopyIdx);
                     *(int4*) (gatheredExpertStatics + targetRankId * expertCount + staticCopyBase + staticCopyIdx * 4)
                         = staticData;
@@ -630,10 +637,28 @@ void allToAllMetadata(int* sendExperts, int* recvExperts, float* sendScales, flo
     dim3 block(block_size);
     dim3 grid(rankCount, 2);
 
-    allToAllMetadataDevice<StepCommunicatorBase><<<grid, block, 0, stream>>>(sendExperts, recvExperts, sendScales,
-        recvScales, localExpertStatics, gatheredExpertStatics, workspace, sendCountsCumsum, localSendIndice,
-        recvCountsCumsum, localRecvIndice, tokenCount, maxTokenCountPerRank, topK, expertCount, slotCount, rankId,
-        rankCount);
+    if (topK % 4 == 0)
+    {
+        using PipelineConfig = PipelineConfig<4, 16>;
+        static_assert(
+            PipelineConfig::PACKET_SIZE_IN_U64 * PipelineConfig::PACKET_PER_STEP * STEP_DEPTH <= FIFO_SIZE_IN_U64,
+            "FIFO size is too small");
+        allToAllMetadataDevice<PipelineConfig, int4, float4><<<grid, block, 0, stream>>>(sendExperts, recvExperts,
+            sendScales, recvScales, localExpertStatics, gatheredExpertStatics, workspace, sendCountsCumsum,
+            localSendIndice, recvCountsCumsum, localRecvIndice, tokenCount, maxTokenCountPerRank, topK, expertCount,
+            slotCount, rankId, rankCount);
+    }
+    else
+    {
+        using PipelineConfig = PipelineConfig<1, 64>;
+        static_assert(
+            PipelineConfig::PACKET_SIZE_IN_U64 * PipelineConfig::PACKET_PER_STEP * STEP_DEPTH <= FIFO_SIZE_IN_U64,
+            "FIFO size is too small");
+        allToAllMetadataDevice<PipelineConfig, int, float><<<grid, block, 0, stream>>>(sendExperts, recvExperts,
+            sendScales, recvScales, localExpertStatics, gatheredExpertStatics, workspace, sendCountsCumsum,
+            localSendIndice, recvCountsCumsum, localRecvIndice, tokenCount, maxTokenCountPerRank, topK, expertCount,
+            slotCount, rankId, rankCount);
+    }
 
     int smCount = tensorrt_llm::common::getMultiProcessorCount();
     memsetExpertIdsDevice<<<smCount, 256, 0, stream>>>(
@@ -642,7 +667,7 @@ void allToAllMetadata(int* sendExperts, int* recvExperts, float* sendScales, flo
 
 size_t getMoePrepareWorkspaceSize(int epSize)
 {
-    return (STEP_DEPTH * PACKET_PER_STEP * PACKET_SIZE + StepCommunicatorBase::META_SIZE) * epSize;
+    return (FIFO_SIZE_IN_U64 * 8 + StepCommunicatorBase::META_SIZE) * epSize;
 }
 
 } // namespace moe_prepare
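
Note: the PipelineConfig class template referenced by these hunks (providing UNIT_SIZE, PACKET_PER_STEP, PACKET_SIZE, PACKET_SIZE_IN_U64, SCALE_OFFSET, and STATIC_COPY_OFFSET) is defined outside the shown range. The sketch below is only a hypothetical illustration of how those compile-time constants could fit together with the static_assert and workspace-size checks; the placeholder values for UNIT_PER_ITER, STEP_DEPTH, and FIFO_SIZE_IN_U64, the 256-byte counter block, and the exact packet layout are assumptions, not the actual definitions from this file.

// Hypothetical sketch only -- not the real definition; member names match what the kernels use,
// but all constant values and the packet layout are assumed for illustration.
#include <cstdint>

constexpr int UNIT_PER_ITER = 64;        // assumed: units copied per pipeline iteration
constexpr int STEP_DEPTH = 4;            // assumed: FIFO depth in steps
constexpr int FIFO_SIZE_IN_U64 = 65536;  // assumed: per-connection FIFO size in 8-byte words

template <int kUnitSize, int kPacketPerStep>
struct PipelineConfig
{
    static constexpr int UNIT_SIZE = kUnitSize;             // expert ids / scales handled per unit
    static constexpr int PACKET_PER_STEP = kPacketPerStep;  // packets that make up one FIFO step

    // Assumed layout of one packet: expert ids, then scales, then an int4-aligned counter block.
    static constexpr int SCALE_OFFSET = UNIT_PER_ITER * UNIT_SIZE * sizeof(int);
    static constexpr int STATIC_COPY_OFFSET = SCALE_OFFSET + UNIT_PER_ITER * UNIT_SIZE * sizeof(float);
    static constexpr int PACKET_SIZE = STATIC_COPY_OFFSET + 256;  // 256-byte counter block, assumed
    static constexpr int PACKET_SIZE_IN_U64 = PACKET_SIZE / sizeof(uint64_t);
};

// The host-side dispatch then selects between a vectorized and a scalar instantiation:
using VectorizedConfig = PipelineConfig<4, 16>;  // topK % 4 == 0: int4 / float4 copies
using ScalarConfig = PipelineConfig<1, 64>;      // otherwise: scalar int / float copies

static_assert(VectorizedConfig::PACKET_SIZE_IN_U64 * VectorizedConfig::PACKET_PER_STEP * STEP_DEPTH
        <= FIFO_SIZE_IN_U64,
    "FIFO size is too small");

Under this reading, getMoePrepareWorkspaceSize now derives the FIFO footprint from FIFO_SIZE_IN_U64 * 8 bytes (8 bytes per u64 word) rather than from STEP_DEPTH * PACKET_PER_STEP * PACKET_SIZE, so the workspace size no longer depends on which PipelineConfig instantiation a given topK selects.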