@@ -279,9 +279,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     }
 
     std::vector<Stmt> seq;
-    std::vector<Var> shared_buffer_vars(size);
-    std::vector<Buffer> shared_bufs(size);
-    std::vector<Buffer> local_bufs;
+    std::vector<Buffer> new_alloc_bufs;
     //
     // This is an optimization. For small reduction sizes, it may be beneficial
     // for a single warp to perform the entire reduction. No trips to shared
@@ -300,130 +298,75 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     // the final reduction result to the proper location.
     //
     if (is_warp_reduction(types, group_extent, reduce_extent, contiguous_reduce_extent)) {
-      ICHECK_LE(reduce_extent, warp_size_) << "not a warp reduction";
-      //
-      // This is the index to the reduction variable, one reduction
-      // variable per warp. Local scope seems easier to reason about without
-      // relying on a pattern match pass to fix it later.
-      Array<PrimExpr> zero_indices = {0};
-
-      for (size_t idx = 0; idx < size; ++idx) {
-        Array<PrimExpr> shape = {1};
-
-        Buffer buffer = decl_buffer(shape, types[idx], "red_buf" + std::to_string(idx));
-        Var buffer_var = buffer->data;
-
-        shared_buffer_vars[idx] = buffer_var;
-        shared_bufs[idx] = buffer;
-
-        PrimExpr pred = const_true(types[idx].lanes());
-        seq.emplace_back(BufferStore(shared_bufs[idx], values[idx], zero_indices));
-
-        // Uses a local variable to store the shuffled data. Later
-        // on, an allocation will be built for this local variable.
-        local_bufs.push_back(decl_buffer(shape, types[idx], "t" + std::to_string(idx)));
-      }
-
-      // The mask for this reducer, as this reducer may sit inside
-      // divergent control flow. A variable caches the currently active channels.
-      //
+      std::vector<PrimExpr> reduce_results;
       DataType mask_dtype = DataType::UInt(32);
-      Buffer mask_buffer = decl_buffer({1}, mask_dtype, "mask");
-      {
-        PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
-        if (group_extent > 1) {
-          mask = mask & (make_const(mask_dtype, (1ll << reduce_extent) - 1)
-                         << (reduce_extent * cast(mask_dtype, group_index)));
+      PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
+
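+      // If the reduction fits in a single warp, one shuffle-based round
+      // suffices; otherwise each warp reduces locally and the partial results
+      // are combined through a staging buffer in shared memory (steps 1-4).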
+      if (reduce_extent <= warp_size_) {
+        if (group_extent > 1 && reduce_extent < warp_size_) {
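+          // Narrow the mask to this group's lanes; e.g. with reduce_extent = 8
+          // and group_index = 2 this keeps bits 16..23 of the active mask.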
+          mask = mask &
+                 (((1 << reduce_extent) - 1) << (reduce_extent * cast(mask_dtype, group_index)));
         }
-        seq.emplace_back(BufferStore(mask_buffer, mask, zero_indices));
-        // Push the buffer description. Later this will have an
-        // allocation built for it.
-        local_bufs.push_back(mask_buffer);
-      }
+        std::tie(reduce_results, new_alloc_bufs) = MakeWarpAllreduce(
+            values, types, combiner, reduce_index, reduce_extent, group_index, mask, NullOpt, &seq);
+      } else {
+        int n_warps = reduce_extent / warp_size_;
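+        // Note: is_warp_reduction admits this path only when reduce_extent is
+        // an exact multiple of warp_size_, so this division has no remainder.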
+        std::vector<Buffer> local_bufs;
 
-      // Emit reductions within a warp.
-      int start_offset = 1;
-      while (start_offset * 2 < reduce_extent) {
-        start_offset *= 2;
-      }
-      for (int offset = start_offset; offset > 0; offset /= 2) {
-        // Load reduction values, no synchronization needed.
-        Array<PrimExpr> a, b;
+        // 1. Create the staging buffer in shared memory.
+        std::vector<Buffer> staging_shared_bufs;
+        staging_shared_bufs.reserve(size);
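+        // Layout: slot (group_index * n_warps + warp_id) holds one warp's
+        // partial result, matching the store indices in step 3 below.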
         for (size_t i = 0; i < size; ++i) {
-          Buffer shared_buf = shared_bufs[i];
-          BufferLoad val(shared_buf, zero_indices);
-          ICHECK_EQ(val->dtype, types[i]);
-          a.push_back(val);
-
-          // __shfl_*sync calls shall not appear in if_then_else expressions
-          // as this causes extra divergence. E.g.
-          //
-          // v1 = (v2 < v3) ? v3 : __shfl_sync(mask, v1, 0);
-          //
-          // behaves differently from
-          //
-          // int t = __shfl_sync(mask, v1, 0);
-          // v1 = (v2 < v3) ? v3 : t;
-          //
-          // The former may cause a deadlock as there is a divergent
-          // branch with a warp sync call inside.
-          //
-          PrimExpr other = WarpShuffle(builtin::tvm_warp_shuffle_down(), mask_buffer, val, offset);
-          Buffer local_buf = local_bufs[i];
-          Stmt s = BufferStore(local_buf, other, zero_indices);
-          seq.push_back(s);
-
-          BufferLoad load = BufferLoad(local_buf, zero_indices);
-          ICHECK_EQ(load->dtype, types[i]);
-          b.push_back(load);
+          Buffer staging_shared_buf = decl_buffer(
+              /*shape=*/{make_const(reduce_index->dtype, n_warps * group_extent)},
+              /*dtype=*/buffers[i]->dtype, /*name=*/"red_buf_staging", /*storage_scope=*/"shared");
+          staging_shared_bufs.push_back(staging_shared_buf);
+          new_alloc_bufs.push_back(staging_shared_buf);
         }
 
-        // Do reductions.
-        Array<PrimExpr> ret = (*combiner)(a, b);
+        // 2. First round of allreduce.
+        std::tie(reduce_results, local_bufs) = MakeWarpAllreduce(
+            values, types, combiner, reduce_index, warp_size_, group_index, mask, NullOpt, &seq);
+        new_alloc_bufs.insert(new_alloc_bufs.end(), local_bufs.begin(), local_bufs.end());
 
-        // Store the reduction result to itself.
-        std::vector<Stmt> stores(size);
+        // 3. Write allreduce results to staging buffer.
+        std::vector<Stmt> write_staging_buf;
+        write_staging_buf.reserve(size);
         for (size_t i = 0; i < size; ++i) {
-          Buffer buf = shared_bufs[i];
-          stores[i] = BufferStore(buf, ret[i], zero_indices);
+          new_alloc_bufs.push_back(Downcast<BufferLoad>(reduce_results[i])->buffer);
+          write_staging_buf.push_back(BufferStore(
+              /*buffer=*/staging_shared_bufs[i],
+              /*value=*/reduce_results[i],
+              /*indices=*/{group_index * n_warps + floordiv(reduce_index, warp_size_)}));
         }
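+        // MakeWarpAllreduce broadcasts the result to every lane, so one lane
+        // per warp (lane 0 here) suffices to publish it to the staging buffer.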
+        PrimExpr cond = floormod(reduce_index, warp_size_) == make_const(reduce_index->dtype, 0);
+        seq.push_back(IfThenElse(cond, SeqStmt::Flatten(write_staging_buf)));
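+        // Barrier: all warps must finish writing their staging slots before
+        // the cross-warp round below reads them.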
+        seq.push_back(SyncThread("shared"));
 
-        // During the sub-warp reduction, values from inactive threads could be read,
-        // which is undefined behavior according to the CUDA documentation.
-        //
-        // In practice, the returned values are usually 0, which does no harm to sum reduction.
-        // However, the result can be incorrect in max or prod reduction.
-        // Therefore an additional range check has to be performed to ensure correctness.
-        if (offset * 2 > reduce_extent) {
-          PrimExpr cond = reduce_index + offset < reduce_extent;
-          seq.push_back(IfThenElse(cond, SeqStmt::Flatten(stores)));
-        } else {
-          seq.push_back(SeqStmt::Flatten(stores));
+        // 4. Load staging buffer.
+        //    Second round of allreduce.
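+        // Each thread loads slot reduce_index; the staging buffer only has
+        // group_extent * n_warps slots, so lanes past that range are masked
+        // off by the predicate below.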
+        for (size_t i = 0; i < size; ++i) {
+          values[i] = BufferLoad(/*buffer=*/staging_shared_bufs[i], /*indices=*/{reduce_index});
         }
+        if (n_warps < warp_size_) {
+          mask = mask & (((1 << n_warps) - 1) << group_index);
+        }
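+        // The predicate stops lanes at or beyond group_extent * n_warps from
+        // seeding their reduction buffers with out-of-range staging data.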
+        std::tie(reduce_results, local_bufs) = MakeWarpAllreduce(
+            values, types, combiner, reduce_index, n_warps, group_index,
+            /*mask=*/mask,
+            /*predicate=*/reduce_index < make_const(reduce_index->dtype, group_extent * n_warps),
+            &seq);
+        new_alloc_bufs.insert(new_alloc_bufs.end(), local_bufs.begin(), local_bufs.end());
       }
 
-      // Broadcast the reduction result from lane 0 to all other lanes.
-      // This avoids emitting predicated stores, as all threads are
-      // uniformly writing the same result.
-      //
-      for (size_t i = 0; i < size; ++i) {
-        Buffer buf = shared_bufs[i];
-        PrimExpr val = BufferLoad(buf, zero_indices);
-        ICHECK_EQ(val->dtype, types[i]);
-        PrimExpr splat =
-            WarpShuffle(builtin::tvm_warp_shuffle(), mask_buffer, val, reduce_extent * group_index);
-        seq.push_back(BufferStore(buf, splat, zero_indices));
-      }
-
-      // Update existing allocations.
+      // Write back allreduce results and update existing allocations.
       for (size_t i = 0; i < size; ++i) {
         ICHECK(!load_remap_.count(buffers[i]->data.get()));
         PrimExpr pred = const_true(types[i].lanes());
-        Buffer buf = shared_bufs[i];
-        PrimExpr val = BufferLoad(buf, zero_indices);
-        ICHECK_EQ(val->dtype, types[i]);
-        load_remap_[buffers[i]->data.get()] = val;
+        Buffer buf = Downcast<BufferLoad>(reduce_results[i])->buffer;
+        ICHECK_EQ(reduce_results[i]->dtype, types[i]);
+        load_remap_[buffers[i]->data.get()] = reduce_results[i];
+
         Array<PrimExpr> extents{PrimExpr(1)};
         auto node = Allocate(buf->data, types[i], extents, pred, Evaluate(0));
         alloc_remap_[buffers[i]->data.get()] = node;
@@ -432,6 +375,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
         warp_allocs_.insert(node.get());
       }
     } else {
+      std::vector<Buffer> shared_bufs(size);
       if (reduce_extent == 1) {
         // special case, no reduction is needed.
         std::vector<Stmt> stores;
@@ -447,7 +391,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
         Buffer buffer = decl_buffer({1}, types[idx], "red_buf" + std::to_string(idx));
 
         shared_bufs[idx] = buffer;
-        shared_buffer_vars[idx] = buffer->data;
 
         PrimExpr pred = const_true(types[idx].lanes());
         seq.emplace_back(BufferStore(shared_bufs[idx], values[idx],
@@ -473,14 +416,153 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
 
     // Fix all local allocations now that all statements are built.
     Stmt body = SeqStmt::Flatten(seq);
-    for (Buffer buf : local_bufs) {
+    for (Buffer buf : new_alloc_bufs) {
       body = Allocate(buf->data, buf->dtype, buf->shape, const_true(buf->dtype.lanes()), body);
-      new_storage_scopes_[buf->data.get()] = "local";
+      String scope = buf.scope();
+      if (scope != "shared") {
+        new_storage_scopes_[buf->data.get()] = "local";
+      }
     }
 
     return body;
   }
 
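+  // Emit a shuffle-based allreduce across one warp (or a sub-warp group).
+  // Returns the per-buffer reduction results as BufferLoads, together with
+  // the locally declared buffers that still need Allocate nodes built.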
+  std::pair<std::vector<PrimExpr>, std::vector<Buffer>> MakeWarpAllreduce(
+      std::vector<PrimExpr> src_values,             //
+      std::vector<DataType> dtypes,                 //
+      const CommReducerNode* combiner,              //
+      PrimExpr reduce_index, int reduce_extent,     //
+      PrimExpr group_index,                         //
+      PrimExpr mask, Optional<PrimExpr> predicate,  //
+      std::vector<Stmt>* seq) {
+    int n_buffers = src_values.size();
+
+    std::vector<Buffer> shared_bufs;
+    std::vector<Buffer> local_bufs;
+    shared_bufs.reserve(n_buffers);
+
+    // This is the index to the reduction variable, one reduction
+    // variable per warp. Local scope seems easier to reason about without
+    // relying on a pattern match pass to fix it later.
+    Array<PrimExpr> zero_indices = {0};
+
+    std::vector<Stmt> load_values;
+    load_values.reserve(n_buffers);
+    for (int idx = 0; idx < n_buffers; ++idx) {
+      Array<PrimExpr> shape = {1};
+
+      Buffer buffer = decl_buffer(shape, dtypes[idx], "red_buf" + std::to_string(idx));
+      Var buffer_var = buffer->data;
+
+      shared_bufs.push_back(buffer);
+
+      PrimExpr pred = const_true(dtypes[idx].lanes());
+      load_values.push_back(BufferStore(shared_bufs[idx], src_values[idx], zero_indices));
+
+      // Uses a local variable to store the shuffled data. Later
+      // on, an allocation will be built for this local variable.
+      local_bufs.push_back(decl_buffer(shape, dtypes[idx], "t" + std::to_string(idx)));
+    }
+
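+    // When a predicate is given, only the lanes it selects initialize their
+    // reduction buffers; the other lanes skip the loads entirely.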
+    if (predicate.defined()) {
+      seq->push_back(IfThenElse(predicate.value(), SeqStmt::Flatten(load_values)));
+    } else {
+      seq->insert(seq->end(), load_values.begin(), load_values.end());
+    }
+
+    // The mask for this reducer, as this reducer may sit inside
+    // divergent control flow. A variable caches the currently active channels.
+    Buffer mask_buffer = decl_buffer({1}, mask->dtype, "mask");
+    {
+      seq->emplace_back(BufferStore(mask_buffer, mask, zero_indices));
+      // Push the buffer description. Later this will have an
+      // allocation built for it.
+      local_bufs.push_back(mask_buffer);
+    }
+
+    // Emit reductions within a warp.
+    int start_offset = 1;
+    while (start_offset * 2 < reduce_extent) {
+      start_offset *= 2;
+    }
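+    // e.g. reduce_extent = 32 yields offsets 16, 8, 4, 2, 1; for a
+    // non-power-of-2 extent the first offset over-reaches and is guarded by
+    // the range check below.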
+    for (int offset = start_offset; offset > 0; offset /= 2) {
+      // Load reduction values, no synchronization needed.
+      Array<PrimExpr> a, b;
+      for (int i = 0; i < n_buffers; ++i) {
+        Buffer shared_buf = shared_bufs[i];
+        BufferLoad val(shared_buf, zero_indices);
+        ICHECK_EQ(val->dtype, dtypes[i]);
+        a.push_back(val);
+
+        // __shfl_*sync calls shall not appear in if_then_else expressions
+        // as this causes extra divergence. E.g.
+        //
+        // v1 = (v2 < v3) ? v3 : __shfl_sync(mask, v1, 0);
+        //
+        // behaves differently from
+        //
+        // int t = __shfl_sync(mask, v1, 0);
+        // v1 = (v2 < v3) ? v3 : t;
+        //
+        // The former may cause a deadlock as there is a divergent
+        // branch with a warp sync call inside.
+        PrimExpr other = WarpShuffle(builtin::tvm_warp_shuffle_down(), mask_buffer, val, offset);
+        Buffer local_buf = local_bufs[i];
+        Stmt s = BufferStore(local_buf, other, zero_indices);
+        seq->push_back(s);
+
+        BufferLoad load = BufferLoad(local_buf, zero_indices);
+        ICHECK_EQ(load->dtype, dtypes[i]);
+        b.push_back(load);
+      }
+
+      // Do reductions.
+      Array<PrimExpr> ret = (*combiner)(a, b);
+
+      // Store the reduction result to itself.
+      std::vector<Stmt> stores;
+      stores.reserve(n_buffers);
+      for (int i = 0; i < n_buffers; ++i) {
+        Buffer buf = shared_bufs[i];
+        stores.push_back(BufferStore(buf, ret[i], zero_indices));
+      }
+
+      // During the sub-warp reduction, values from inactive threads could be read,
+      // which is undefined behavior according to the CUDA documentation.
+      //
+      // In practice, the returned values are usually 0, which does no harm to sum reduction.
+      // However, the result can be incorrect in max or prod reduction.
+      // Therefore an additional range check has to be performed to ensure correctness.
+      if (offset * 2 > reduce_extent) {
+        PrimExpr cond = reduce_index + offset < reduce_extent;
+        seq->push_back(IfThenElse(cond, SeqStmt::Flatten(stores)));
+      } else {
+        seq->push_back(SeqStmt::Flatten(stores));
+      }
+    }
+
+    // Broadcast the reduction result from the first lane of each group
+    // (lane reduce_extent * group_index) to all other lanes. This avoids
+    // emitting predicated stores, as all threads uniformly write the same result.
+    for (int i = 0; i < n_buffers; ++i) {
+      Buffer buf = shared_bufs[i];
+      PrimExpr val = BufferLoad(buf, zero_indices);
+      ICHECK_EQ(val->dtype, dtypes[i]);
+      PrimExpr splat =
+          WarpShuffle(builtin::tvm_warp_shuffle(), mask_buffer, val, reduce_extent * group_index);
+      seq->push_back(BufferStore(buf, splat, zero_indices));
+    }
+
+    std::vector<PrimExpr> reduce_results;
+    reduce_results.reserve(n_buffers);
+    for (int i = 0; i < n_buffers; ++i) {
+      reduce_results.push_back(BufferLoad(shared_bufs[i], zero_indices));
+    }
+
+    return {reduce_results, local_bufs};
+  }
+
   // make allreduce.
   Stmt MakeBufAllreduce(const CommReducerNode* combiner, const std::vector<DataType>& types,
                         const Array<Buffer>& shared_bufs, PrimExpr reduce_index,
@@ -676,8 +758,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     if (reduce_extent == 1) {
       return false;  // no need to warp reduce
     } else {
-      if (warp_size_ % reduce_extent == 0) {
-        return true;  // warp size is multiple of reduce extent
+      if (warp_size_ % reduce_extent == 0 || reduce_extent % warp_size_ == 0) {
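+        // e.g. reduce_extent = 16 is a factor of a 32-thread warp (single
+        // intra-warp round), while reduce_extent = 128 spans 4 warps and
+        // takes the two-round staging path.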
+        return true;  // warp size is a multiple or a factor of the reduce extent
       } else {
         return group_extent == 1 && reduce_extent <= warp_size_;
       }