Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 12 additions & 13 deletions csrc/custom_all_reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -145,18 +145,17 @@ DINLINE O downcast(array_t<float, O::size> val) {
template <int ngpus>
#ifdef USE_ROCM
DINLINE void start_sync(const RankSignals& sg, Signal* self_sg, int rank) {
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
if (threadIdx.x < ngpus) {
__scoped_atomic_store_n(&self_sg->end[blockIdx.x][threadIdx.x], 0,
__ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE);
__atomic_store_n(&self_sg->end[blockIdx.x][threadIdx.x], 0,
__ATOMIC_RELAXED);
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
__scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank],
1, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
__atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], 1,
__ATOMIC_RELAXED);
__atomic_thread_fence(__ATOMIC_ACQ_REL);
// wait until we got true from all ranks
while (!__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
__ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE));
while (!__atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
__ATOMIC_RELAXED);
}
__syncthreads();
}
Expand Down Expand Up @@ -190,16 +189,16 @@ DINLINE void end_sync(const RankSignals& sg, Signal* self_sg, int rank) {
// the memory model.
if (threadIdx.x < ngpus) {
// reset flag for next time
__scoped_atomic_store_n(&self_sg->start[blockIdx.x][threadIdx.x], 0,
__ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE);
__atomic_store_n(&self_sg->start[blockIdx.x][threadIdx.x], 0,
__ATOMIC_RELAXED);
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
__scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], 1,
__ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
__atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], 1,
__ATOMIC_RELAXED);
__atomic_thread_fence(__ATOMIC_ACQ_REL);
// wait until we got true from all ranks
while (!__scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
__ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE));
while (!__atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
__ATOMIC_RELAXED));
}
if constexpr (!final_sync) __syncthreads();
}
Expand Down
2 changes: 1 addition & 1 deletion csrc/custom_all_reduce_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ int main(int argc, char** argv) {
// run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024);
// }
// }
#ifdef USE _ROCM
#ifdef USE_ROCM
for (int sz = 512; sz <= (8 << 22); sz *= 2) {
run<half>(myRank, nRanks, comm, 512, 18, sz + 8 * 47, performance_test);
}
Expand Down