
Commit 3f16761

[TIR] ThreadAllreduce warp-level primitive support with multi-warp
This PR enhances the implementation of the LowerThreadAllreduce pass.

Prior to this PR, for the CUDA backend we leverage warp-level primitives only when

* the reducing threads form a sub-warp (i.e., of size 16, 8, 4 or 2), or
* the number of reducing threads is less than 32 and equals the reduction extent.

Under the requirements above, reductions over a large number of threads (e.g., 128, 256 or more) produce inefficient generated code.

This PR improves the LowerThreadAllreduce pass so that, whenever the number of reducing threads is a multiple of the warp size, we generate more efficient CUDA code with the help of warp-level primitives. Specifically, in such cases we first reduce 32 elements within each warp and store each warp's result in shared memory. We then trigger a second round of warp-level reduction within the first warp to obtain the final reduction result.

In addition to using warp-level primitives, this also shrinks the shared memory footprint. For example, even when reducing over 1024 threads, we now only need a shared buffer of size 32, compared with 1024 prior to this PR.

Tests are added to ensure correctness.
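For illustration, below is a minimal hand-written CUDA sketch of the two-stage scheme described above. It is an analogue of the code the pass is meant to generate, not the pass output itself; the kernel name, the block size of 256 threads (8 warps), and the use of a sum reduction are assumptions made for this example only.

// Assumed example: block-wide all-reduce (sum) over 256 threads, warp size 32.
// Stage 1: each warp reduces its 32 values with __shfl_down_sync.
// Stage 2: the 8 per-warp partial sums are staged in shared memory and
// reduced by the first warp, then the result is read back by every thread.
__global__ void block_allreduce_sum(const float* in, float* out) {
  const unsigned FULL_MASK = 0xffffffffu;
  int tid = threadIdx.x;   // 0..255, blockDim.x == 256 assumed
  int lane = tid % 32;
  int warp = tid / 32;     // 8 warps per block

  float val = in[blockIdx.x * blockDim.x + tid];

  // Stage 1: intra-warp reduction via shuffles.
  for (int offset = 16; offset > 0; offset /= 2) {
    val += __shfl_down_sync(FULL_MASK, val, offset);
  }

  // One staging slot per warp (8 here; 32 for 1024 threads),
  // instead of one slot per reducing thread as before this PR.
  __shared__ float staging[8];
  if (lane == 0) staging[warp] = val;
  __syncthreads();

  // Stage 2: the first warp reduces the 8 per-warp partial sums.
  if (warp == 0) {
    val = (lane < 8) ? staging[lane] : 0.0f;
    for (int offset = 4; offset > 0; offset /= 2) {
      val += __shfl_down_sync(FULL_MASK, val, offset);
    }
    if (lane == 0) staging[0] = val;
  }
  __syncthreads();

  // Every thread reads back the same value (all-reduce semantics).
  out[blockIdx.x * blockDim.x + tid] = staging[0];
}

With 1024 reducing threads the staging array would hold 32 entries (one per warp), which is the shared-memory saving mentioned above; the pass itself emits the analogous TIR through MakeWarpAllreduce rather than a hand-written kernel.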
1 parent 9af8efc commit 3f16761

3 files changed (+590 additions, -122 deletions)


python/tvm/tir/op.py

Lines changed: 1 addition & 1 deletion
@@ -616,7 +616,7 @@ def tvm_storage_sync(storage_scope):
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("handle", "tir.tvm_storage_sync", storage_scope)
+    return call_intrin("int32", "tir.tvm_storage_sync", storage_scope)
 
 
 def tvm_warp_shuffle(mask, value, warp_id, width, warp_size):

src/tir/transforms/lower_thread_allreduce.cc

Lines changed: 198 additions & 116 deletions
@@ -279,9 +279,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     }
 
     std::vector<Stmt> seq;
-    std::vector<Var> shared_buffer_vars(size);
-    std::vector<Buffer> shared_bufs(size);
-    std::vector<Buffer> local_bufs;
+    std::vector<Buffer> new_alloc_bufs;
     //
     // This is an optimization. For small reduction sizes, it may be beneficial
     // for a single warp to performance the entire reduction. No trips to shared
@@ -300,130 +298,75 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     // the final reduction result to the proper location.
     //
     if (is_warp_reduction(types, group_extent, reduce_extent, contiguous_reduce_extent)) {
-      ICHECK_LE(reduce_extent, warp_size_) << "not a warp reduction";
-      //
-      // This is the index to the reduction variable, one reduction
-      // variable per warp. Local scope seems easier to reason without
-      // relying on a pattern match pass to fix it later.
-      Array<PrimExpr> zero_indices = {0};
-
-      for (size_t idx = 0; idx < size; ++idx) {
-        Array<PrimExpr> shape = {1};
-
-        Buffer buffer = decl_buffer(shape, types[idx], "red_buf" + std::to_string(idx));
-        Var buffer_var = buffer->data;
-
-        shared_buffer_vars[idx] = buffer_var;
-        shared_bufs[idx] = buffer;
-
-        PrimExpr pred = const_true(types[idx].lanes());
-        seq.emplace_back(BufferStore(shared_bufs[idx], values[idx], zero_indices));
-
-        // Uses a local variable to store the shuffled data. Later
-        // on, an allocation will be built for this local variable.
-        local_bufs.push_back(decl_buffer(shape, types[idx], "t" + std::to_string(idx)));
-      }
-
-      // The mask for this reducer, as this reducer may sit inside
-      // a divergent control flow. Here it uses a variable to cache the current
-      // active channels.
-      //
+      std::vector<PrimExpr> reduce_results;
       DataType mask_dtype = DataType::UInt(32);
-      Buffer mask_buffer = decl_buffer({1}, mask_dtype, "mask");
-      {
-        PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
-        if (group_extent > 1) {
-          mask = mask & (make_const(mask_dtype, (1ll << reduce_extent) - 1)
-                         << (reduce_extent * cast(mask_dtype, group_index)));
+      PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
+
+      if (reduce_extent <= warp_size_) {
+        if (group_extent > 1 && reduce_extent < warp_size_) {
+          mask = mask &
+                 (((1 << reduce_extent) - 1) << (reduce_extent * cast(mask_dtype, group_index)));
         }
-        seq.emplace_back(BufferStore(mask_buffer, mask, zero_indices));
-        // Push the buffer description. Later this will have an
-        // allocation built for it.
-        local_bufs.push_back(mask_buffer);
-      }
+        std::tie(reduce_results, new_alloc_bufs) = MakeWarpAllreduce(
+            values, types, combiner, reduce_index, reduce_extent, group_index, mask, NullOpt, &seq);
+      } else {
+        int n_warps = reduce_extent / warp_size_;
+        std::vector<Buffer> local_bufs;
 
-      // Emit reductions within a warp.
-      int start_offset = 1;
-      while (start_offset * 2 < reduce_extent) {
-        start_offset *= 2;
-      }
-      for (int offset = start_offset; offset > 0; offset /= 2) {
-        // Load reduction values, no synchronization needed.
-        Array<PrimExpr> a, b;
+        // 1. Create the staging buffer in shared memory.
+        std::vector<Buffer> staging_shared_bufs;
+        staging_shared_bufs.reserve(size);
         for (size_t i = 0; i < size; ++i) {
-          Buffer shared_buf = shared_bufs[i];
-          BufferLoad val(shared_buf, zero_indices);
-          ICHECK_EQ(val->dtype, types[i]);
-          a.push_back(val);
-
-          // __shfl_*sync calls shall not appear in if_then_else expressions
-          // as this is causing extra divergency. E.g.
-          //
-          // v1 = (v2 < v3) ? v3 : __shfl_sync(mask, v1, 0);
-          //
-          // behaves differently from
-          //
-          // int t = __shfl_sync(mask, v1, 0);
-          // v1 = (v2 < v3) ? v3 : t;
-          //
-          // The former may cause dead lock as there is a divergent
-          // branch with a warp sync call inside.
-          //
-          PrimExpr other = WarpShuffle(builtin::tvm_warp_shuffle_down(), mask_buffer, val, offset);
-          Buffer local_buf = local_bufs[i];
-          Stmt s = BufferStore(local_buf, other, zero_indices);
-          seq.push_back(s);
-
-          BufferLoad load = BufferLoad(local_buf, zero_indices);
-          ICHECK_EQ(load->dtype, types[i]);
-          b.push_back(load);
+          Buffer staging_shared_buf = decl_buffer(
+              /*shape=*/{make_const(reduce_index->dtype, n_warps * group_extent)},
+              /*dtype=*/buffers[i]->dtype, /*name=*/"red_buf_staging", /*storage_scope=*/"shared");
+          staging_shared_bufs.push_back(staging_shared_buf);
+          new_alloc_bufs.push_back(staging_shared_buf);
         }
 
-        // Do reductions.
-        Array<PrimExpr> ret = (*combiner)(a, b);
+        // 2. First round of allreduce.
+        std::tie(reduce_results, local_bufs) = MakeWarpAllreduce(
+            values, types, combiner, reduce_index, warp_size_, group_index, mask, NullOpt, &seq);
+        new_alloc_bufs.insert(new_alloc_bufs.end(), local_bufs.begin(), local_bufs.end());
 
-        // Store the reduction result to itself.
-        std::vector<Stmt> stores(size);
+        // 3. Write allreduce results to staging buffer.
+        std::vector<Stmt> write_staging_buf;
+        write_staging_buf.reserve(size);
         for (size_t i = 0; i < size; ++i) {
-          Buffer buf = shared_bufs[i];
-          stores[i] = BufferStore(buf, ret[i], zero_indices);
+          new_alloc_bufs.push_back(Downcast<BufferLoad>(reduce_results[i])->buffer);
+          write_staging_buf.push_back(BufferStore(
+              /*buffer=*/staging_shared_bufs[i],
+              /*value=*/reduce_results[i],
+              /*indices=*/{group_index * n_warps + floordiv(reduce_index, warp_size_)}));
         }
+        PrimExpr cond = floormod(reduce_index, warp_size_) == make_const(reduce_index->dtype, 0);
+        seq.push_back(IfThenElse(cond, SeqStmt::Flatten(write_staging_buf)));
+        seq.push_back(SyncThread("shared"));
 
-        // During the sub-warp reduction, values from inactive threads could be read,
-        // which is an undefined behavior according to the cuda document.
-        //
-        // In practice, the return value are usually 0, which does no harm to sum reduction.
-        // However, the result can be incorrect in max or prod reduction.
-        // Therefore an additional range check has to be performed to ensure the correctness.
-        if (offset * 2 > reduce_extent) {
-          PrimExpr cond = reduce_index + offset < reduce_extent;
-          seq.push_back(IfThenElse(cond, SeqStmt::Flatten(stores)));
-        } else {
-          seq.push_back(SeqStmt::Flatten(stores));
+        // 4. Load staging buffer.
+        // Second round of allreduce.
+        for (size_t i = 0; i < size; ++i) {
+          values[i] = BufferLoad(/*buffer=*/staging_shared_bufs[i], /*indices=*/{reduce_index});
         }
+        if (n_warps < warp_size_) {
+          mask = mask & (((1 << n_warps) - 1) << group_index);
+        }
+        std::tie(reduce_results, local_bufs) = MakeWarpAllreduce(
+            values, types, combiner, reduce_index, n_warps, group_index,
+            /*mask=*/mask,
+            /*predicate=*/reduce_index < make_const(reduce_index->dtype, group_extent * n_warps),
+            &seq);
+        new_alloc_bufs.insert(new_alloc_bufs.end(), local_bufs.begin(), local_bufs.end());
       }
 
-      // Broadcast the reduction result from lane 0 to all other lanes.
-      // This avoids to emit predicated stores, as all threads are
-      // uniformly writing the same result.
-      //
-      for (size_t i = 0; i < size; ++i) {
-        Buffer buf = shared_bufs[i];
-        PrimExpr val = BufferLoad(buf, zero_indices);
-        ICHECK_EQ(val->dtype, types[i]);
-        PrimExpr splat =
-            WarpShuffle(builtin::tvm_warp_shuffle(), mask_buffer, val, reduce_extent * group_index);
-        seq.push_back(BufferStore(buf, splat, zero_indices));
-      }
-
-      // Update existing allocations.
+      // Write back allreduce results and update existing allocations.
       for (size_t i = 0; i < size; ++i) {
         ICHECK(!load_remap_.count(buffers[i]->data.get()));
         PrimExpr pred = const_true(types[i].lanes());
-        Buffer buf = shared_bufs[i];
-        PrimExpr val = BufferLoad(buf, zero_indices);
-        ICHECK_EQ(val->dtype, types[i]);
-        load_remap_[buffers[i]->data.get()] = val;
+        Buffer buf = Downcast<BufferLoad>(reduce_results[i])->buffer;
+        ICHECK_EQ(reduce_results[i]->dtype, types[i]);
+        load_remap_[buffers[i]->data.get()] = reduce_results[i];
+
         Array<PrimExpr> extents{PrimExpr(1)};
         auto node = Allocate(buf->data, types[i], extents, pred, Evaluate(0));
         alloc_remap_[buffers[i]->data.get()] = node;
@@ -432,6 +375,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
         warp_allocs_.insert(node.get());
       }
     } else {
+      std::vector<Buffer> shared_bufs(size);
       if (reduce_extent == 1) {
         // special case, no reduction is needed.
         std::vector<Stmt> stores;
@@ -447,7 +391,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
         Buffer buffer = decl_buffer({1}, types[idx], "red_buf" + std::to_string(idx));
 
         shared_bufs[idx] = buffer;
-        shared_buffer_vars[idx] = buffer->data;
 
         PrimExpr pred = const_true(types[idx].lanes());
         seq.emplace_back(BufferStore(shared_bufs[idx], values[idx],
@@ -473,14 +416,153 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
 
     // Fix all local allocations as all statements are built.
     Stmt body = SeqStmt::Flatten(seq);
-    for (Buffer buf : local_bufs) {
+    for (Buffer buf : new_alloc_bufs) {
       body = Allocate(buf->data, buf->dtype, buf->shape, const_true(buf->dtype.lanes()), body);
-      new_storage_scopes_[buf->data.get()] = "local";
+      String scope = buf.scope();
+      if (buf.scope() != "shared") {
+        new_storage_scopes_[buf->data.get()] = "local";
+      }
     }
 
     return body;
   }
 
+  std::pair<std::vector<PrimExpr>, std::vector<Buffer>> MakeWarpAllreduce(
+      std::vector<PrimExpr> src_values,             //
+      std::vector<DataType> dtypes,                 //
+      const CommReducerNode* combiner,              //
+      PrimExpr reduce_index, int reduce_extent,     //
+      PrimExpr group_index,                         //
+      PrimExpr mask, Optional<PrimExpr> predicate,  //
+      std::vector<Stmt>* seq) {
+    int n_buffers = src_values.size();
+
+    std::vector<Buffer> shared_bufs;
+    std::vector<Buffer> local_bufs;
+    shared_bufs.reserve(n_buffers);
+
+    // This is the index to the reduction variable, one reduction
+    // variable per warp. Local scope seems easier to reason without
+    // relying on a pattern match pass to fix it later.
+    Array<PrimExpr> zero_indices = {0};
+
+    std::vector<Stmt> load_values;
+    load_values.reserve(n_buffers);
+    for (int idx = 0; idx < n_buffers; ++idx) {
+      Array<PrimExpr> shape = {1};
+
+      Buffer buffer = decl_buffer(shape, dtypes[idx], "red_buf" + std::to_string(idx));
+      Var buffer_var = buffer->data;
+
+      shared_bufs.push_back(buffer);
+
+      PrimExpr pred = const_true(dtypes[idx].lanes());
+      load_values.push_back(BufferStore(shared_bufs[idx], src_values[idx], zero_indices));
+
+      // Uses a local variable to store the shuffled data. Later
+      // on, an allocation will be built for this local variable.
+      local_bufs.push_back(decl_buffer(shape, dtypes[idx], "t" + std::to_string(idx)));
+    }
+
+    if (predicate.defined()) {
+      seq->push_back(IfThenElse(predicate.value(), SeqStmt::Flatten(load_values)));
+    } else {
+      seq->insert(seq->end(), load_values.begin(), load_values.end());
+    }
+
+    // The mask for this reducer, as this reducer may sit inside
+    // a divergent control flow. Here it uses a variable to cache the current
+    // active channels.
+    Buffer mask_buffer = decl_buffer({1}, mask->dtype, "mask");
+    {
+      seq->emplace_back(BufferStore(mask_buffer, mask, zero_indices));
+      // Push the buffer description. Later this will have an
+      // allocation built for it.
+      local_bufs.push_back(mask_buffer);
+    }
+
+    // Emit reductions within a warp.
+    int start_offset = 1;
+    while (start_offset * 2 < reduce_extent) {
+      start_offset *= 2;
+    }
+    for (int offset = start_offset; offset > 0; offset /= 2) {
+      // Load reduction values, no synchronization needed.
+      Array<PrimExpr> a, b;
+      for (int i = 0; i < n_buffers; ++i) {
+        Buffer shared_buf = shared_bufs[i];
+        BufferLoad val(shared_buf, zero_indices);
+        ICHECK_EQ(val->dtype, dtypes[i]);
+        a.push_back(val);
+
+        // __shfl_*sync calls shall not appear in if_then_else expressions
+        // as this is causing extra divergency. E.g.
+        //
+        // v1 = (v2 < v3) ? v3 : __shfl_sync(mask, v1, 0);
+        //
+        // behaves differently from
+        //
+        // int t = __shfl_sync(mask, v1, 0);
+        // v1 = (v2 < v3) ? v3 : t;
+        //
+        // The former may cause dead lock as there is a divergent
+        // branch with a warp sync call inside.
+        PrimExpr other = WarpShuffle(builtin::tvm_warp_shuffle_down(), mask_buffer, val, offset);
+        Buffer local_buf = local_bufs[i];
+        Stmt s = BufferStore(local_buf, other, zero_indices);
+        seq->push_back(s);
+
+        BufferLoad load = BufferLoad(local_buf, zero_indices);
+        ICHECK_EQ(load->dtype, dtypes[i]);
+        b.push_back(load);
+      }
+
+      // Do reductions.
+      Array<PrimExpr> ret = (*combiner)(a, b);
+
+      // Store the reduction result to itself.
+      std::vector<Stmt> stores;
+      stores.reserve(n_buffers);
+      for (int i = 0; i < n_buffers; ++i) {
+        Buffer buf = shared_bufs[i];
+        stores.push_back(BufferStore(buf, ret[i], zero_indices));
+      }
+
+      // During the sub-warp reduction, values from inactive threads could be read,
+      // which is an undefined behavior according to the cuda document.
+      //
+      // In practice, the return value are usually 0, which does no harm to sum reduction.
+      // However, the result can be incorrect in max or prod reduction.
+      // Therefore an additional range check has to be performed to ensure the correctness.
+      if (offset * 2 > reduce_extent) {
+        PrimExpr cond = reduce_index + offset < reduce_extent;
+        seq->push_back(IfThenElse(cond, SeqStmt::Flatten(stores)));
+      } else {
+        seq->push_back(SeqStmt::Flatten(stores));
+      }
+    }
+
+    // Broadcast the reduction result from lane 0 to all other lanes.
+    // This avoids to emit predicated stores, as all threads are
+    // uniformly writing the same result.
+    for (int i = 0; i < n_buffers; ++i) {
+      Buffer buf = shared_bufs[i];
+      PrimExpr val = BufferLoad(buf, zero_indices);
+      ICHECK_EQ(val->dtype, dtypes[i]);
+      PrimExpr splat =
+          WarpShuffle(builtin::tvm_warp_shuffle(), mask_buffer, val, reduce_extent * group_index);
+      seq->push_back(BufferStore(buf, splat, zero_indices));
+    }
+
+    std::vector<PrimExpr> reduce_results;
+    reduce_results.reserve(n_buffers);
+    for (int i = 0; i < n_buffers; ++i) {
+      reduce_results.push_back(BufferLoad(shared_bufs[i], zero_indices));
+    }
+
+    return {reduce_results, local_bufs};
+  }
+
   // make allreduce.
   Stmt MakeBufAllreduce(const CommReducerNode* combiner, const std::vector<DataType>& types,
                         const Array<Buffer>& shared_bufs, PrimExpr reduce_index,
@@ -676,8 +758,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     if (reduce_extent == 1) {
       return false;  // no need to warp reduce
     } else {
-      if (warp_size_ % reduce_extent == 0) {
-        return true;  // warp size is multiple of reduce extent
+      if (warp_size_ % reduce_extent == 0 || reduce_extent % warp_size_ == 0) {
+        return true;  // warp size is multiple or factor of reduce extent
      } else {
         return group_extent == 1 && reduce_extent <= warp_size_;
       }
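A note on the range check inside MakeWarpAllreduce above (the `offset * 2 > reduce_extent` branch): when the reduce extent is not a power of two, the first shuffle step would read values shuffled in from lanes outside the active set, which is undefined and can corrupt max or prod reductions. Below is a minimal hand-written CUDA analogue of that guarded step, assuming a max reduction over 24 active lanes of one warp; the function name and the extent of 24 are made up for this example and are not part of the pass.

// Assumed illustration (not pass output): max-reduction over the first 24
// lanes of a warp. Lanes >= 24 hold no valid data, so the first (partial)
// step is guarded by a range check, mirroring the pass's
// `reduce_index + offset < reduce_extent` condition.
__device__ float warp_max_24(float val, int lane, unsigned mask) {
  // First step: offset 16 pairs lane L with lane L+16, which only carries
  // valid data for L < 8; skip the update elsewhere.
  float other = __shfl_down_sync(mask, val, 16);
  if (lane + 16 < 24) val = fmaxf(val, other);
  // Remaining steps only feed lane 0 from lanes that hold valid data,
  // so no further guard is needed.
  for (int offset = 8; offset > 0; offset /= 2) {
    other = __shfl_down_sync(mask, val, offset);
    val = fmaxf(val, other);
  }
  return val;  // only lane 0 is guaranteed to hold the maximum of lanes 0..23
}

In the pass, the final result is additionally broadcast to all lanes with tvm_warp_shuffle, so every thread observes the same reduced value.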
