Skip to content

Commit cd1bcc4

Browse files
committed
[TIR] Allreduce broadcast result to each thread in multi-warp case
PR apache#15327 introduces the warp-level primitive support in multi-warp allreduce. However, due to the specialty of the two-stage shuffle-down reduction implementation of the allreduce in multi-warp scenarios, PR apache#15327 did not broadcast the allreduce result to each reduction thread. This behavior does not align with the semantics of allreduce and is not ideal for many use cases. Therefore, this PR completes the implementation by inserting a stage that writes the reduction results to shared memory, so that each reduction thread across all the reduction warps can access the reduction results. This shared memory write-back stage will only be inserted in multi-warp allreduce cases. In single-warp allreduce, a `shfl_sync` is used to broadcast the reduction results across reduction threads. Since in multi-warp settings we cannot leverage warp-level primitives to broadcast the value, we can only make use of shared memory. The numerical correctness is verified locally.
1 parent 03fecba commit cd1bcc4

File tree

2 files changed

+69
-75
lines changed

2 files changed

+69
-75
lines changed

src/tir/transforms/lower_thread_allreduce.cc

Lines changed: 37 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -38,27 +38,6 @@
3838
namespace tvm {
3939
namespace tir {
4040

41-
class UpdatePointerStorageScopeAllReduce final : public UpdatePointerStorageScope {
42-
public:
43-
explicit UpdatePointerStorageScopeAllReduce(
44-
const std::unordered_map<const VarNode*, String>& new_storage_scopes)
45-
: UpdatePointerStorageScope(new_storage_scopes) {}
46-
47-
Stmt VisitStmt_(const AllocateNode* op) final {
48-
auto remapped = Downcast<Var>(StmtExprMutator::VisitExpr(op->buffer_var));
49-
auto new_scope = GetPtrStorageScope(remapped);
50-
if (new_scope != GetPtrStorageScope(op->buffer_var)) {
51-
Stmt body = StmtExprMutator::VisitStmt(op->body);
52-
if (new_scope == "shared") {
53-
// use volatile access to shared buffer.
54-
body = AttrStmt(remapped, attr::volatile_scope, 1, body);
55-
}
56-
return Allocate(remapped, op->dtype, op->extents, op->condition, body, op->annotations);
57-
}
58-
return StmtExprMutator::VisitStmt_(op);
59-
}
60-
};
61-
6241
class ThreadAllreduceBuilder final : public StmtExprMutator {
6342
public:
6443
explicit ThreadAllreduceBuilder(const TargetNode* target)
@@ -98,11 +77,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
9877

9978
if (auto it = alloc_remap_.find(node->buffer_var.get()); it != alloc_remap_.end()) {
10079
const AllocateNode* repl = it->second.as<AllocateNode>();
101-
if (warp_allocs_.count(repl)) {
102-
new_storage_scopes_[repl->buffer_var.get()] = "local";
103-
} else {
104-
new_storage_scopes_[repl->buffer_var.get()] = "shared";
105-
}
10680
auto write_ptr = node.CopyOnWrite();
10781
write_ptr->buffer_var = repl->buffer_var;
10882
write_ptr->dtype = repl->dtype;
@@ -161,8 +135,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
161135
return std::move(store);
162136
}
163137

164-
std::unordered_map<const VarNode*, String> new_storage_scopes_;
165-
166138
private:
167139
// Thread entry
168140
struct ThreadEntry {
@@ -310,6 +282,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
310282
// In the second stage we use the first 16 lanes of the first warp to reduce
311283
// the remaining elements, and this reduction can also be optimized by
312284
// shuffle_down warp-level primitives.
285+
PrimExpr zero_index = make_const(reduce_index->dtype, 0);
313286
if (IsWarpReduction(types, group_extent, reduce_extent, contiguous_reduce_extent)) {
314287
std::vector<PrimExpr> reduce_results;
315288
DataType mask_dtype = DataType::UInt(32);
@@ -322,6 +295,18 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
322295
}
323296
std::tie(reduce_results, new_alloc_bufs) = MakeWarpAllreduce(
324297
values, types, combiner, reduce_index, reduce_extent, group_index, mask, NullOpt, &seq);
298+
299+
// Broadcast the reduction result from lane 0 to all other lanes.
300+
// This avoids to emit predicated stores, as all threads are
301+
// uniformly writing the same result.
302+
for (int i = 0; i < size; ++i) {
303+
Buffer buf = Downcast<BufferLoad>(reduce_results[i])->buffer;
304+
PrimExpr val = BufferLoad(buf, {zero_index});
305+
ICHECK_EQ(val->dtype, types[i]);
306+
PrimExpr splat = WarpShuffle(builtin::tvm_warp_shuffle(), new_alloc_bufs.back(), val,
307+
reduce_extent * group_index);
308+
seq.push_back(BufferStore(buf, splat, {zero_index}));
309+
}
325310
} else {
326311
int n_warps = reduce_extent / warp_size_;
327312
std::vector<Buffer> local_bufs;
@@ -352,7 +337,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
352337
/*value=*/reduce_results[i],
353338
/*indices=*/{group_index * n_warps + floordiv(reduce_index, warp_size_)}));
354339
}
355-
PrimExpr cond = floormod(reduce_index, warp_size_) == make_const(reduce_index->dtype, 0);
340+
PrimExpr cond = floormod(reduce_index, warp_size_) == zero_index;
356341
seq.push_back(IfThenElse(cond, SeqStmt::Flatten(write_staging_buf)));
357342
seq.push_back(SyncThread("shared"));
358343

@@ -369,6 +354,23 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
369354
/*predicate=*/reduce_index < make_const(reduce_index->dtype, group_extent * n_warps),
370355
&seq);
371356
new_alloc_bufs.insert(new_alloc_bufs.end(), local_bufs.begin(), local_bufs.end());
357+
358+
// 5. Create shared memory buffer(s) of `group_extent` elements, storing
359+
// the allreduce results so each thread can access.
360+
std::vector<Stmt> write_result;
361+
write_result.reserve(size);
362+
for (size_t i = 0; i < size; ++i) {
363+
new_alloc_bufs.push_back(Downcast<BufferLoad>(reduce_results[i])->buffer);
364+
Buffer broadcast_shared_buf = decl_buffer(
365+
/*shape=*/{make_const(reduce_index->dtype, group_extent)},
366+
/*dtype=*/buffers[i]->dtype, /*name=*/"red_result", /*storage_scope=*/"shared");
367+
write_result.push_back(
368+
BufferStore(broadcast_shared_buf, reduce_results[i], {zero_index}));
369+
// Update `reduce_results`, pointing to the value loaded from the shared memory buffer.
370+
reduce_results[i] = BufferLoad(broadcast_shared_buf, {zero_index});
371+
}
372+
seq.push_back(IfThenElse(reduce_index == zero_index, SeqStmt::Flatten(write_result)));
373+
seq.push_back(SyncThread("shared"));
372374
}
373375

374376
// Write back allreduce results and update existing allocations.
@@ -379,12 +381,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
379381
ICHECK_EQ(reduce_results[i]->dtype, types[i]);
380382
load_remap_[buffers[i]->data.get()] = reduce_results[i];
381383

382-
Array<PrimExpr> extents{PrimExpr(1)};
383-
auto node = Allocate(buf->data, types[i], extents, pred, Evaluate(0));
384+
auto node = Allocate(buf->data, types[i], buf->shape, pred, Evaluate(0));
384385
alloc_remap_[buffers[i]->data.get()] = node;
385386
var_remap_[buffers[i]->data.get()] = buf->data;
386387
buf_remap_[buffers[i].get()] = buf;
387-
warp_allocs_.insert(node.get());
388388
}
389389
} else {
390390
std::vector<Buffer> shared_bufs(size);
@@ -426,9 +426,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
426426
Stmt body = SeqStmt::Flatten(seq);
427427
for (Buffer buf : new_alloc_bufs) {
428428
body = Allocate(buf->data, buf->dtype, buf->shape, const_true(buf->dtype.lanes()), body);
429-
if (buf.scope() != "shared") {
430-
new_storage_scopes_[buf->data.get()] = "local";
431-
}
432429
}
433430

434431
return body;
@@ -457,12 +454,13 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
457454
std::vector<Stmt> load_values;
458455
load_values.reserve(n_buffers);
459456
for (int idx = 0; idx < n_buffers; ++idx) {
460-
shared_bufs.push_back(decl_buffer(shape, dtypes[idx], "red_buf" + std::to_string(idx)));
457+
shared_bufs.push_back(
458+
decl_buffer(shape, dtypes[idx], "red_buf" + std::to_string(idx), "local"));
461459
load_values.push_back(BufferStore(shared_bufs[idx], src_values[idx], zero_indices));
462460

463461
// Uses a local variable to store the shuffled data. Later
464462
// on, an allocation will be built for this local variable.
465-
local_bufs.push_back(decl_buffer(shape, dtypes[idx], "t" + std::to_string(idx)));
463+
local_bufs.push_back(decl_buffer(shape, dtypes[idx], "t" + std::to_string(idx), "local"));
466464
}
467465

468466
if (predicate.defined()) {
@@ -474,7 +472,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
474472
// The mask for this reducer, as this reducer may sit inside
475473
// a divergent control flow. Here it uses a variable to cache the current
476474
// active channels.
477-
Buffer mask_buffer = decl_buffer(shape, mask->dtype, "mask");
475+
Buffer mask_buffer = decl_buffer(shape, mask->dtype, "mask", "local");
478476
{
479477
seq->emplace_back(BufferStore(mask_buffer, mask, zero_indices));
480478
// Push the buffer description. Later this will have an
@@ -543,18 +541,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
543541
}
544542
}
545543

546-
// Broadcast the reduction result from lane 0 to all other lanes.
547-
// This avoids to emit predicated stores, as all threads are
548-
// uniformly writing the same result.
549-
for (int i = 0; i < n_buffers; ++i) {
550-
Buffer buf = shared_bufs[i];
551-
PrimExpr val = BufferLoad(buf, zero_indices);
552-
ICHECK_EQ(val->dtype, dtypes[i]);
553-
PrimExpr splat =
554-
WarpShuffle(builtin::tvm_warp_shuffle(), mask_buffer, val, reduce_extent * group_index);
555-
seq->push_back(BufferStore(buf, splat, zero_indices));
556-
}
557-
558544
std::vector<PrimExpr> reduce_results;
559545
reduce_results.reserve(n_buffers);
560546
for (int i = 0; i < n_buffers; ++i) {
@@ -791,8 +777,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
791777
std::unordered_map<const VarNode*, Var> var_remap_;
792778
// Buffer remap
793779
std::unordered_map<const BufferNode*, Buffer> buf_remap_;
794-
// Allocate from warp reductions
795-
std::unordered_set<const void*> warp_allocs_;
796780
// Internal analyzer
797781
arith::Analyzer analyzer_;
798782
};
@@ -806,9 +790,7 @@ Pass LowerThreadAllreduce() {
806790
ICHECK(target.defined()) << "LowerThreadAllreduce: Require the target attribute";
807791
const TargetNode* target_node = target.as<TargetNode>();
808792
ThreadAllreduceBuilder thread_all_reduce(target_node);
809-
auto reduce_body = thread_all_reduce(n->body);
810-
n->body =
811-
UpdatePointerStorageScopeAllReduce(thread_all_reduce.new_storage_scopes_)(reduce_body);
793+
n->body = thread_all_reduce(n->body);
812794
return f;
813795
};
814796
return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {});

tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -386,13 +386,14 @@ def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128,), "float32"))
386386
T.func_attr({"target": T.target("cuda", host="llvm")})
387387
for i in range(128):
388388
threadIdx_x = T.launch_thread("threadIdx.x", 128)
389-
red_buf0 = T.allocate([1], "float32", "local")
390-
red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
389+
red_result = T.allocate([1], "float32", "shared")
390+
red_result_1 = T.Buffer((1,), data=red_result, scope="shared")
391391
with T.attr(
392392
T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
393393
"reduce_scope",
394394
T.reinterpret("handle", T.uint64(0)),
395395
):
396+
red_buf0 = T.allocate([1], "float32", "local")
396397
mask = T.allocate([1], "uint32", "local")
397398
t0 = T.allocate([1], "float32", "local")
398399
red_buf0_1 = T.allocate([1], "float32", "local")
@@ -415,11 +416,11 @@ def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128,), "float32"))
415416
red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
416417
t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 1, 32, 32)
417418
red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
418-
red_buf0_2[0] = T.tvm_warp_shuffle(mask_2[0], red_buf0_2[0], 0, 32, 32)
419419
red_buf_staging_1 = T.Buffer((4,), data=red_buf_staging, scope="shared")
420420
if threadIdx_x % 32 == 0:
421421
red_buf_staging_1[threadIdx_x // 32] = red_buf0_2[0]
422422
T.tvm_storage_sync("shared")
423+
red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
423424
if threadIdx_x < 4:
424425
red_buf0_3[0] = red_buf_staging_1[threadIdx_x]
425426
mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local")
@@ -429,10 +430,12 @@ def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128,), "float32"))
429430
red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
430431
t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 32)
431432
red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
432-
red_buf0_3[0] = T.tvm_warp_shuffle(mask_3[0], red_buf0_3[0], 0, 32, 32)
433+
if threadIdx_x == 0:
434+
red_result_1[0] = red_buf0_3[0]
435+
T.tvm_storage_sync("shared")
433436
if threadIdx_x == 0:
434437
B_1 = T.Buffer((128,), data=B.data)
435-
B_1[i] = red_buf0_3[0]
438+
B_1[i] = red_result_1[0]
436439

437440

438441
class TestMultiWarpReduce2(BaseCompare):
@@ -459,13 +462,14 @@ def before(A: T.Buffer((1, 1024), "float32"), B: T.Buffer((1,), "float32")):
459462
def expected(A: T.Buffer((1, 1024), "float32"), B: T.Buffer((1,), "float32")):
460463
T.func_attr({"target": T.target("cuda", host="llvm")})
461464
threadIdx_x = T.launch_thread("threadIdx.x", 1024)
462-
red_buf0 = T.allocate([1], "float32", "local")
463-
red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
465+
red_result = T.allocate([1], "float32", "shared")
466+
red_result_1 = T.Buffer((1,), data=red_result, scope="shared")
464467
with T.attr(
465468
T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
466469
"reduce_scope",
467470
T.reinterpret("handle", T.uint64(0)),
468471
):
472+
red_buf0 = T.allocate([1], "float32", "local")
469473
mask = T.allocate([1], "uint32", "local")
470474
t0 = T.allocate([1], "float32", "local")
471475
red_buf0_1 = T.allocate([1], "float32", "local")
@@ -488,11 +492,11 @@ def expected(A: T.Buffer((1, 1024), "float32"), B: T.Buffer((1,), "float32")):
488492
red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
489493
t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 1, 32, 32)
490494
red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
491-
red_buf0_2[0] = T.tvm_warp_shuffle(mask_2[0], red_buf0_2[0], 0, 32, 32)
492495
red_buf_staging_1 = T.Buffer((32,), data=red_buf_staging, scope="shared")
493496
if threadIdx_x % 32 == 0:
494497
red_buf_staging_1[threadIdx_x // 32] = red_buf0_2[0]
495498
T.tvm_storage_sync("shared")
499+
red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
496500
if threadIdx_x < 32:
497501
red_buf0_3[0] = red_buf_staging_1[threadIdx_x]
498502
mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local")
@@ -508,10 +512,12 @@ def expected(A: T.Buffer((1, 1024), "float32"), B: T.Buffer((1,), "float32")):
508512
red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
509513
t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 32)
510514
red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
511-
red_buf0_3[0] = T.tvm_warp_shuffle(mask_3[0], red_buf0_3[0], 0, 32, 32)
515+
if threadIdx_x == 0:
516+
red_result_1[0] = red_buf0_3[0]
517+
T.tvm_storage_sync("shared")
512518
if threadIdx_x == 0:
513519
B_1 = T.Buffer((1,), data=B.data)
514-
B_1[0] = red_buf0_3[0]
520+
B_1[0] = red_result_1[0]
515521

516522

517523
class TestMultiGroupMultiWarpReduction(BaseCompare):
@@ -543,14 +549,15 @@ def before(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")):
543549
def expected(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")):
544550
T.func_attr({"target": T.target("cuda", host="llvm")})
545551
threadIdx_y = T.launch_thread("threadIdx.y", 4)
546-
red_buf0 = T.allocate([1], "float32", "local")
552+
red_result = T.allocate([4], "float32", "shared")
547553
threadIdx_x = T.launch_thread("threadIdx.x", 128)
548-
red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
554+
red_result_1 = T.Buffer((4,), data=red_result, scope="shared")
549555
with T.attr(
550556
T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
551557
"reduce_scope",
552558
T.reinterpret("handle", T.uint64(0)),
553559
):
560+
red_buf0 = T.allocate([1], "float32", "local")
554561
mask = T.allocate([1], "uint32", "local")
555562
t0 = T.allocate([1], "float32", "local")
556563
red_buf0_1 = T.allocate([1], "float32", "local")
@@ -573,11 +580,11 @@ def expected(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")):
573580
red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
574581
t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 1, 32, 32)
575582
red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
576-
red_buf0_2[0] = T.tvm_warp_shuffle(mask_2[0], red_buf0_2[0], 32 * threadIdx_y, 32, 32)
577583
red_buf_staging_1 = T.Buffer((16,), data=red_buf_staging, scope="shared")
578584
if threadIdx_x % 32 == 0:
579585
red_buf_staging_1[threadIdx_y * 4 + threadIdx_x // 32] = red_buf0_2[0]
580586
T.tvm_storage_sync("shared")
587+
red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
581588
if threadIdx_x < 16:
582589
red_buf0_3[0] = red_buf_staging_1[threadIdx_x]
583590
mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local")
@@ -589,10 +596,12 @@ def expected(A: T.Buffer((4, 128), "float32"), B: T.Buffer((4,), "float32")):
589596
red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
590597
t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 32)
591598
red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
592-
red_buf0_3[0] = T.tvm_warp_shuffle(mask_3[0], red_buf0_3[0], 4 * threadIdx_y, 32, 32)
599+
if threadIdx_x == 0:
600+
red_result_1[0] = red_buf0_3[0]
601+
T.tvm_storage_sync("shared")
593602
if threadIdx_x == 0:
594603
B_1 = T.Buffer((4,), data=B.data)
595-
B_1[threadIdx_y] = red_buf0_3[0]
604+
B_1[threadIdx_y] = red_result_1[0]
596605

597606

598607
class TestMultiGroupMultiWarpPredicatedReduction(BaseCompare):
@@ -626,19 +635,20 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")):
626635
T.func_attr({"target": T.target("cuda", host="llvm")})
627636
threadIdx_y = T.launch_thread("threadIdx.y", 2)
628637
in_thread_B = T.allocate([1], "float32", "local")
629-
red_buf0 = T.allocate([1], "float32", "local")
638+
red_result = T.allocate([2], "float32", "shared")
630639
threadIdx_x = T.launch_thread("threadIdx.x", 512)
631640
in_thread_B_1 = T.Buffer((1,), data=in_thread_B, scope="local")
632641
in_thread_B_1[0] = T.float32(0)
633642
if threadIdx_x < 70:
634643
A_1 = T.Buffer((140,), data=A.data)
635644
in_thread_B_1[0] = in_thread_B_1[0] + A_1[threadIdx_y * 70 + threadIdx_x]
636-
red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
645+
red_result_1 = T.Buffer((2,), data=red_result, scope="shared")
637646
with T.attr(
638647
T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
639648
"reduce_scope",
640649
T.reinterpret("handle", T.uint64(0)),
641650
):
651+
red_buf0 = T.allocate([1], "float32", "local")
642652
mask = T.allocate([1], "uint32", "local")
643653
t0 = T.allocate([1], "float32", "local")
644654
red_buf0_1 = T.allocate([1], "float32", "local")
@@ -660,11 +670,11 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")):
660670
red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
661671
t0_2[0] = T.tvm_warp_shuffle_down(mask_2[0], red_buf0_2[0], 1, 32, 32)
662672
red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
663-
red_buf0_2[0] = T.tvm_warp_shuffle(mask_2[0], red_buf0_2[0], 32 * threadIdx_y, 32, 32)
664673
red_buf_staging_1 = T.Buffer((32,), data=red_buf_staging, scope="shared")
665674
if threadIdx_x % 32 == 0:
666675
red_buf_staging_1[threadIdx_y * 16 + threadIdx_x // 32] = red_buf0_2[0]
667676
T.tvm_storage_sync("shared")
677+
red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
668678
if threadIdx_x < 32:
669679
red_buf0_3[0] = red_buf_staging_1[threadIdx_x]
670680
mask_3 = T.Buffer((1,), "uint32", data=mask, scope="local")
@@ -680,10 +690,12 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")):
680690
red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
681691
t0_3[0] = T.tvm_warp_shuffle_down(mask_3[0], red_buf0_3[0], 1, 32, 32)
682692
red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
683-
red_buf0_3[0] = T.tvm_warp_shuffle(mask_3[0], red_buf0_3[0], 16 * threadIdx_y, 32, 32)
693+
if threadIdx_x == 0:
694+
red_result_1[0] = red_buf0_3[0]
695+
T.tvm_storage_sync("shared")
684696
if threadIdx_x == 0:
685697
B_1 = T.Buffer((2,), data=B.data)
686-
B_1[threadIdx_y] = red_buf0_3[0]
698+
B_1[threadIdx_y] = red_result_1[0]
687699

688700

689701
if __name__ == "__main__":

0 commit comments

Comments
 (0)