@@ -145,18 +145,17 @@ DINLINE O downcast(array_t<float, O::size> val) {
145145template <int ngpus>
146146#ifdef USE_ROCM
147147DINLINE void start_sync (const RankSignals& sg, Signal* self_sg, int rank) {
148- uint32_t flag = self_sg->_flag [blockIdx .x ] + 1 ;
149148 if (threadIdx .x < ngpus) {
150- __scoped_atomic_store_n (&self_sg->end [blockIdx .x ][threadIdx .x ], 0 ,
151- __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE );
149+ __atomic_store_n (&self_sg->end [blockIdx .x ][threadIdx .x ], 0 ,
150+ __ATOMIC_RELAXED);
152151 // simultaneously write to the corresponding flag of all ranks.
153152 // Latency = 1 p2p write
154- __scoped_atomic_store_n (&sg.signals [threadIdx .x ]->start [blockIdx .x ][rank],
155- 1 , __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM );
153+ __atomic_store_n (&sg.signals [threadIdx .x ]->start [blockIdx .x ][rank], 1 ,
154+ __ATOMIC_RELAXED);
156155 __atomic_thread_fence (__ATOMIC_ACQ_REL);
157156 // wait until we got true from all ranks
158- while (!__scoped_atomic_load_n (&self_sg->start [blockIdx .x ][threadIdx .x ],
159- __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE) );
157+ while (!__atomic_load_n (&self_sg->start [blockIdx .x ][threadIdx .x ],
158+ __ATOMIC_RELAXED);
160159 }
161160 __syncthreads ();
162161}
@@ -190,16 +189,16 @@ DINLINE void end_sync(const RankSignals& sg, Signal* self_sg, int rank) {
190189 // the memory model.
191190 if (threadIdx .x < ngpus) {
192191 // reset flag for next time
193- __scoped_atomic_store_n (&self_sg->start [blockIdx .x ][threadIdx .x ], 0 ,
194- __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE );
192+ __atomic_store_n (&self_sg->start [blockIdx .x ][threadIdx .x ], 0 ,
193+ __ATOMIC_RELAXED);
195194 // simultaneously write to the corresponding flag of all ranks.
196195 // Latency = 1 p2p write
197- __scoped_atomic_store_n (&sg.signals [threadIdx .x ]->end [blockIdx .x ][rank], 1 ,
198- __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM );
196+ __atomic_store_n (&sg.signals [threadIdx .x ]->end [blockIdx .x ][rank], 1 ,
197+ __ATOMIC_RELAXED);
199198 __atomic_thread_fence (__ATOMIC_ACQ_REL);
200199 // wait until we got true from all ranks
201- while (!__scoped_atomic_load_n (&self_sg->end [blockIdx .x ][threadIdx .x ],
202- __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE ));
200+ while (!__atomic_load_n (&self_sg->end [blockIdx .x ][threadIdx .x ],
201+ __ATOMIC_RELAXED));
203202 }
204203 if constexpr (!final_sync) __syncthreads ();
205204}
0 commit comments