Skip to content

Commit fd8f821

Browse files
iotamudeltawunhuang
authored andcommitted
Make CAR ROCm 6.1 compatible. (ROCm#137)
* remove scoping * while there fix a typo * while there remove unused variable
1 parent 2b7a776 commit fd8f821

File tree

2 files changed

+13
-14
lines changed

2 files changed

+13
-14
lines changed

csrc/custom_all_reduce.cuh

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -145,18 +145,17 @@ DINLINE O downcast(array_t<float, O::size> val) {
145145
template <int ngpus>
146146
#ifdef USE_ROCM
147147
DINLINE void start_sync(const RankSignals& sg, Signal* self_sg, int rank) {
148-
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
149148
if (threadIdx.x < ngpus) {
150-
__scoped_atomic_store_n(&self_sg->end[blockIdx.x][threadIdx.x], 0,
151-
__ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE);
149+
__atomic_store_n(&self_sg->end[blockIdx.x][threadIdx.x], 0,
150+
__ATOMIC_RELAXED);
152151
// simultaneously write to the corresponding flag of all ranks.
153152
// Latency = 1 p2p write
154-
__scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank],
155-
1, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
153+
__atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], 1,
154+
__ATOMIC_RELAXED);
156155
__atomic_thread_fence(__ATOMIC_ACQ_REL);
157156
// wait until we got true from all ranks
158-
while (!__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
159-
__ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE));
157+
while (!__atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
158+
__ATOMIC_RELAXED);
160159
}
161160
__syncthreads();
162161
}
@@ -190,16 +189,16 @@ DINLINE void end_sync(const RankSignals& sg, Signal* self_sg, int rank) {
190189
// the memory model.
191190
if (threadIdx.x < ngpus) {
192191
// reset flag for next time
193-
__scoped_atomic_store_n(&self_sg->start[blockIdx.x][threadIdx.x], 0,
194-
__ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE);
192+
__atomic_store_n(&self_sg->start[blockIdx.x][threadIdx.x], 0,
193+
__ATOMIC_RELAXED);
195194
// simultaneously write to the corresponding flag of all ranks.
196195
// Latency = 1 p2p write
197-
__scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], 1,
198-
__ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
196+
__atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], 1,
197+
__ATOMIC_RELAXED);
199198
__atomic_thread_fence(__ATOMIC_ACQ_REL);
200199
// wait until we got true from all ranks
201-
while (!__scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
202-
__ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE));
200+
while (!__atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
201+
__ATOMIC_RELAXED));
203202
}
204203
if constexpr (!final_sync) __syncthreads();
205204
}

csrc/custom_all_reduce_test.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ int main(int argc, char** argv) {
330330
// run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024);
331331
// }
332332
// }
333-
#ifdef USE _ROCM
333+
#ifdef USE_ROCM
334334
for (int sz = 512; sz <= (8 << 22); sz *= 2) {
335335
run<half>(myRank, nRanks, comm, 512, 18, sz + 8 * 47, performance_test);
336336
}

0 commit comments

Comments
 (0)