@@ -43,12 +43,7 @@ struct __align__(16) RankData { const void* ptrs[8]; };
 struct __align__(16) RankData { const void* __restrict__ ptrs[8]; };
 #endif
 
-struct __align__(16) RankSignals {
-#ifndef USE_ROCM
-  volatile
-#endif
-      Signal* signals[8];
-};
+struct __align__(16) RankSignals { volatile Signal* signals[8]; };
 
 // like std::array, but aligned
 template <typename T, int sz>
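
Reviewer note: the `Signal` struct itself is outside these hunks, but the synchronization code below indexes it as `start[blockIdx.x][rank]`, `end[blockIdx.x][rank]`, and `_flag[blockIdx.x]`. A rough sketch of the layout that implies is given here; the exact field types, alignment, and `kMaxBlocks` value are assumptions, not taken from this diff.

```cpp
#include <cstdint>

constexpr int kMaxBlocks = 64;  // assumed value; the real constant lives elsewhere in this header

struct __align__(16) Signal {
  // Per-block, per-rank counters; peers write these over P2P mappings.
  alignas(128) uint32_t start[kMaxBlocks][8];
  alignas(128) uint32_t end[kMaxBlocks][8];
  // Per-block epoch counter, only ever read/written by the owning rank.
  alignas(128) uint32_t _flag[kMaxBlocks];
};
```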
@@ -141,28 +136,25 @@ DINLINE O downcast(array_t<float, O::size> val) {
 // This function is meant to be used as the first synchronization in the all
 // reduce kernel. Thus, it doesn't need to make any visibility guarantees for
 // prior memory accesses. Note: volatile writes will not be reordered against
-// other volatile writes (CUDA-only).
+// other volatile writes.
 template <int ngpus>
+DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
+                        int rank) {
 #ifdef USE_ROCM
-DINLINE void start_sync(const RankSignals& sg, Signal* self_sg, int rank) {
+  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
   if (threadIdx.x < ngpus) {
-    __atomic_store_n(&self_sg->end[blockIdx.x][threadIdx.x], 0,
-                     __ATOMIC_RELAXED);
     // simultaneously write to the corresponding flag of all ranks.
     // Latency = 1 p2p write
-    __atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], 1,
+    __atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag,
                      __ATOMIC_RELAXED);
-    __atomic_thread_fence(__ATOMIC_ACQ_REL);
     // wait until we got true from all ranks
-    while (!__atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
-                            __ATOMIC_RELAXED));
+    while (__atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
+                           __ATOMIC_RELAXED) < flag);
   }
   __syncthreads();
-}
+  // use one thread to update flag
+  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
 #else
-DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
-                        int rank) {
-  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
   if (threadIdx.x < ngpus) {
     // reset flag for next time
     self_sg->end[blockIdx.x][threadIdx.x] = 0;
@@ -173,38 +165,36 @@ DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
     while (!self_sg->start[blockIdx.x][threadIdx.x]);
   }
   __syncthreads();
-}
 #endif
+}
 
 // This function is meant to be used as the second or the final synchronization
 // barrier in the all reduce kernel. If it's the final synchronization barrier,
 // we don't need to make any visibility guarantees for prior memory accesses.
 template <int ngpus, bool final_sync = false>
+DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
+                      int rank) {
 #ifdef USE_ROCM
-DINLINE void end_sync(const RankSignals& sg, Signal* self_sg, int rank) {
   __syncthreads();
   // eliminate the case that prior writes are not visible after signals become
   // visible. Note that I did not managed to make this happen through a lot of
   // testing. Might be the case that hardware provides stronger guarantee than
   // the memory model.
+  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
   if (threadIdx.x < ngpus) {
-    // reset flag for next time
-    __atomic_store_n(&self_sg->start[blockIdx.x][threadIdx.x], 0,
-                     __ATOMIC_RELAXED);
     // simultaneously write to the corresponding flag of all ranks.
     // Latency = 1 p2p write
-    __atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], 1,
-                     __ATOMIC_RELAXED);
-    __atomic_thread_fence(__ATOMIC_ACQ_REL);
+    __atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], flag,
+                     final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE);
     // wait until we got true from all ranks
-    while (!__atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
-                            __ATOMIC_RELAXED));
+    while (__atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
+                           final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE) <
+           flag);
   }
-  if constexpr (!final_sync) __syncthreads();
-}
+  __syncthreads();
+  // use one thread to update flag
+  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
 #else
-DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
-                      int rank) {
   __syncthreads();
   // eliminate the case that prior writes are not visible after signals become
   // visible. Note that I did not managed to make this happen through a lot of
@@ -221,8 +211,8 @@ DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
     while (!self_sg->end[blockIdx.x][threadIdx.x]);
   }
   if constexpr (!final_sync) __syncthreads();
-}
 #endif
+}
 
 template <typename P, int ngpus, typename A>
 DINLINE P packed_reduce(const P* ptrs[], int idx) {
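
Reviewer note: the net effect of the two hunks above is that the ROCm path now uses a monotonically increasing per-block counter instead of the write-1/reset-to-0 handshake: each barrier round writes `_flag + 1` to every peer and spins until its own counters reach that value, so the reset stores and the `__ATOMIC_ACQ_REL` fence can be dropped. A minimal, self-contained sketch of that pattern follows; `counter_barrier`, `peer_counters`, `my_counters`, and `local_epoch` are illustrative names, not identifiers from this file.

```cpp
#include <cstdint>

// Counter-based barrier sketch: one slot per peer, counters only ever grow.
// peer_counters[i] points at rank i's slot for *this* rank (P2P-mapped);
// my_counters are this rank's own slots, one per peer; local_epoch is private.
template <int ngpus>
__device__ void counter_barrier(uint32_t* peer_counters[ngpus],
                                uint32_t* my_counters, uint32_t* local_epoch,
                                int rank) {
  uint32_t flag = *local_epoch + 1;  // next epoch; monotonic, so no cleanup store
  if (threadIdx.x < ngpus) {
    // announce our arrival at this epoch to every peer (1 p2p write per thread)
    __atomic_store_n(&peer_counters[threadIdx.x][rank], flag, __ATOMIC_RELAXED);
    // wait until the corresponding peer has also reached this epoch
    while (__atomic_load_n(&my_counters[threadIdx.x], __ATOMIC_RELAXED) < flag);
  }
  __syncthreads();
  if (threadIdx.x == 0) *local_epoch = flag;  // one thread bumps the epoch
}
```

Because every wait is a `<` comparison against a counter that only increases, no rank can observe a stale "reset" value from a previous round, which is presumably why the old reset stores and the extra fence were safe to remove.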
@@ -237,11 +227,8 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) {
 template <typename T, int ngpus>
 __global__ void __launch_bounds__(512, 1)
     cross_device_reduce_1stage(RankData* _dp, RankSignals sg,
-#ifndef USE_ROCM
-                               volatile
-#endif
-                               Signal* self_sg,
-                               T* __restrict__ result, int rank, int size) {
+                               volatile Signal* self_sg, T* __restrict__ result,
+                               int rank, int size) {
   using P = typename packed_t<T>::P;
   using A = typename packed_t<T>::A;
   // note: we don't reorder the address so the accumulation order is the same
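
Reviewer note: the kernel body and `packed_reduce` are unchanged and therefore not shown. Conceptually, each thread loads the `idx`-th packed element from every rank's buffer (in fixed rank order, per the comment above) and accumulates in the wider type `A` before converting back to `P`. A standalone illustration of the accumulation-order point, using plain `float` instead of the packed types:

```cpp
// Every rank walks the pointer array in the same (rank 0..ngpus-1) order, so
// all ranks perform the identical sequence of additions and produce
// bitwise-identical results for the same element.
template <int ngpus>
__device__ float reduce_one_element(const float* ptrs[ngpus], int idx) {
  float acc = ptrs[0][idx];
#pragma unroll
  for (int i = 1; i < ngpus; i++) acc += ptrs[i][idx];
  return acc;
}
```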
@@ -257,22 +244,15 @@ __global__ void __launch_bounds__(512, 1)
 }
 
 template <typename P>
-DINLINE P* get_tmp_buf(
-#ifndef USE_ROCM
-    volatile
-#endif
-    Signal* sg) {
+DINLINE P* get_tmp_buf(volatile Signal* sg) {
   return (P*)(((Signal*)sg) + 1);
 }
 
 template <typename T, int ngpus>
 __global__ void __launch_bounds__(512, 1)
     cross_device_reduce_2stage(RankData* _dp, RankSignals sg,
-#ifndef USE_ROCM
-                               volatile
-#endif
-                               Signal* self_sg,
-                               T* __restrict__ result, int rank, int size) {
+                               volatile Signal* self_sg, T* __restrict__ result,
+                               int rank, int size) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int stride = gridDim.x * blockDim.x;
   using P = typename packed_t<T>::P;
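
Reviewer note: `get_tmp_buf` only changes signature here, but it encodes a layout assumption worth spelling out: each rank's signal allocation is `[ Signal | scratch ]`, so the scratch area used by the two-stage kernel starts one `Signal` past the base pointer, which is exactly what `(P*)(((Signal*)sg) + 1)` computes. A small sketch, reusing the assumed `Signal` layout from the earlier note:

```cpp
// Hypothetical helper mirroring get_tmp_buf(): step one Signal past the base
// of the rank's allocation to reach the scratch region.
template <typename P>
P* tmp_buf_of(Signal* sg) {
  return reinterpret_cast<P*>(sg + 1);
}
// Layout of the per-rank allocation this assumes:
//   offset 0              : Signal            (sync counters)
//   offset sizeof(Signal) : P scratch[...]    (used by cross_device_reduce_2stage)
```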
@@ -475,41 +455,37 @@ class CustomAllreduce {
    */
   template <typename T>
   void allreduce(cudaStream_t stream, T* input, T* output, int size,
-#ifdef USE_ROCM
-                 int threads = 512, int block_limit = 18){
-#else
                  int threads = 512, int block_limit = 36) {
-#endif
-  auto d = packed_t<T>::P::size;
-  if (size % d != 0)
-    throw std::runtime_error(
-        "custom allreduce currently requires input length to be multiple "
-        "of " +
-        std::to_string(d));
-  if (block_limit > kMaxBlocks)
-    throw std::runtime_error("max supported block limit is " +
-                             std::to_string(kMaxBlocks) + ". Got " +
-                             std::to_string(block_limit));
-
-  RankData* ptrs;
-  cudaStreamCaptureStatus status;
-  CUDACHECK(cudaStreamIsCapturing(stream, &status));
-  if (status == cudaStreamCaptureStatusActive) {
-    ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
-    graph_unreg_buffers_.push_back(input);
-  } else {
-    auto it = buffers_.find(input);
-    if (it == buffers_.end())
+    auto d = packed_t<T>::P::size;
+    if (size % d != 0)
       throw std::runtime_error(
-          "buffer address " +
-          std::to_string(reinterpret_cast<uint64_t>(input)) +
-          " is not registered!");
-    ptrs = it->second;
-  }
+          "custom allreduce currently requires input length to be multiple "
+          "of " +
+          std::to_string(d));
+    if (block_limit > kMaxBlocks)
+      throw std::runtime_error("max supported block limit is " +
+                               std::to_string(kMaxBlocks) + ". Got " +
+                               std::to_string(block_limit));
+
+    RankData* ptrs;
+    cudaStreamCaptureStatus status;
+    CUDACHECK(cudaStreamIsCapturing(stream, &status));
+    if (status == cudaStreamCaptureStatusActive) {
+      ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
+      graph_unreg_buffers_.push_back(input);
+    } else {
+      auto it = buffers_.find(input);
+      if (it == buffers_.end())
+        throw std::runtime_error(
+            "buffer address " +
+            std::to_string(reinterpret_cast<uint64_t>(input)) +
+            " is not registered!");
+      ptrs = it->second;
+    }
 
-  size /= d;
-  auto bytes = size * sizeof(typename packed_t<T>::P);
-  int blocks = std::min(block_limit, (size + threads - 1) / threads);
+    size /= d;
+    auto bytes = size * sizeof(typename packed_t<T>::P);
+    int blocks = std::min(block_limit, (size + threads - 1) / threads);
 #define KL(ngpus, name)                                                       \
   name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
                                                  rank_, size);
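
Reviewer note: for context, a hypothetical call site for the dispatch above; the constructor and buffer registration are outside this diff, so the object is taken as already set up. `run_allreduce` and the variable names are illustrative, not part of the file being changed.

```cpp
#include <cuda_fp16.h>

void run_allreduce(vllm::CustomAllreduce& ca, cudaStream_t stream,
                   half* d_in, half* d_out, int num_elems) {
  // num_elems must be a multiple of packed_t<half>::P::size, and d_in must
  // either be a registered buffer or be captured inside a CUDA graph,
  // otherwise the checks above throw.
  ca.allreduce<half>(stream, d_in, d_out, num_elems);
  // defaults: threads = 512, block_limit = 36 (block_limit may not exceed kMaxBlocks)
}
```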
@@ -528,27 +504,27 @@ class CustomAllreduce {
     break;                                           \
   }
 
-  switch (world_size_) {
-    REDUCE_CASE(2)
-    REDUCE_CASE(4)
-    REDUCE_CASE(6)
-    REDUCE_CASE(8)
-    default:
-      throw std::runtime_error(
-          "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
-          "gpus = " +
-          std::to_string(world_size_));
-  }
+    switch (world_size_) {
+      REDUCE_CASE(2)
+      REDUCE_CASE(4)
+      REDUCE_CASE(6)
+      REDUCE_CASE(8)
+      default:
+        throw std::runtime_error(
+            "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
+            "gpus = " +
+            std::to_string(world_size_));
+    }
 #undef REDUCE_CASE
 #undef KL
-}
+  }
 
-~CustomAllreduce() {
-  for (auto [_, ptr] : ipc_handles_) {
-    CUDACHECK(cudaIpcCloseMemHandle(ptr));
+  ~CustomAllreduce() {
+    for (auto [_, ptr] : ipc_handles_) {
+      CUDACHECK(cudaIpcCloseMemHandle(ptr));
+    }
   }
-}
-};  // namespace vllm
+};
 /**
  * To inspect PTX/SASS, copy paste this header file to compiler explorer and add
  a template instantiation:
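
Reviewer note: the diff view is truncated here, so the example instantiation the comment refers to is cut off. It would look something like the following (the exact line in the file may differ):

```cpp
template void vllm::CustomAllreduce::allreduce<half>(cudaStream_t, half*, half*,
                                                     int, int, int);
```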