
Commit fb416ef

[Codegen][Metal] Support metal warp-level primitive
This PR introduces the warp-level shuffle primitives of the Metal Shading Language and uses them in the lowering of allreduce. The introduced primitives are:

* `simd_shuffle`,
* `simd_shuffle_up`,
* `simd_shuffle_down`.

See section 6.9.2 of https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf for details.

Correctness is validated by `test_allreduce_cuda` with the backend changed to Metal. Since we do not have Metal CI tests, correctness is checked only locally.

Because the Metal shuffle primitives do not support (or need) masking, the LowerThreadAllreduce pass is updated to handle backends that have no masks. One unit test for Metal is added to ensure that no mask is used.
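For context, this is roughly what a simdgroup sum reduction looks like in raw Metal Shading Language. The kernel name and buffer layout below are hypothetical, for illustration only; the `simd_shuffle_down` builtin is the one from spec section 6.9.2. Note the absence of any mask argument:

#include <metal_stdlib>
using namespace metal;

// Hypothetical kernel: each 32-lane simdgroup sums its values into
// lane 0. No mask operand exists anywhere -- Metal shuffles always
// act on the whole simdgroup, which is why the lowering can skip masks.
kernel void simdgroup_sum(device const float* in [[buffer(0)]],
                          device float* out [[buffer(1)]],
                          uint tid [[thread_position_in_grid]],
                          uint lane [[thread_index_in_simdgroup]]) {
  float v = in[tid];
  for (uint delta = 16; delta > 0; delta >>= 1) {
    v += simd_shuffle_down(v, delta);  // read v from lane + delta
  }
  if (lane == 0) {
    out[tid / 32] = v;  // lane 0 holds the simdgroup's sum
  }
}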
1 parent 9ff74fb commit fb416ef

File tree

3 files changed: +180 -11 lines changed

src/target/source/intrin_rule_metal.cc

Lines changed: 53 additions & 0 deletions
@@ -30,6 +30,28 @@ namespace codegen {
 namespace intrin {
 using tir::FLowerIntrinsic;
 
+struct MetalWarpIntrinsic {
+  const Op operator()(DataType t, const Op& orig_op) const {
+    if (orig_op.same_as(builtin::tvm_warp_shuffle())) {
+      return Op::Get("tir.metal.simd_shuffle");
+    } else if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) {
+      return Op::Get("tir.metal.simd_shuffle_up");
+    } else {
+      ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down()));
+      return Op::Get("tir.metal.simd_shuffle_down");
+    }
+  }
+};
+
+template <typename T>
+static PrimExpr DispatchMetalShuffle(const PrimExpr& e) {
+  const CallNode* call = e.as<CallNode>();
+  ICHECK(call != nullptr);
+  ICHECK_EQ(call->args.size(), 5);  // mask, value, warp_id, width, warp_size
+  Array<PrimExpr> metal_args{{call->args[1], call->args[2]}};
+  return Call(call->dtype, T()(call->dtype, Downcast<Op>(call->op)), metal_args);
+}
+
 TVM_REGISTER_OP("tir.floor")
     .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);
 
@@ -95,6 +117,37 @@ TVM_REGISTER_OP("tir.cosh")
 
 TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchFastErf);
 
+TVM_REGISTER_OP("tir.tvm_warp_shuffle")
+    .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchMetalShuffle<MetalWarpIntrinsic>);
+
+TVM_REGISTER_OP("tir.tvm_warp_shuffle_up")
+    .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchMetalShuffle<MetalWarpIntrinsic>);
+
+TVM_REGISTER_OP("tir.tvm_warp_shuffle_down")
+    .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchMetalShuffle<MetalWarpIntrinsic>);
+
+// Register low-level builtin ops.
+TVM_REGISTER_OP("tir.metal.simd_shuffle")
+    .set_num_inputs(2)
+    .add_argument("var", "Expr", "The variable to sync.")
+    .add_argument("lane", "Expr", "The source thread id.")
+    .set_attr<TGlobalSymbol>("TGlobalSymbol", "simd_shuffle")
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TVM_REGISTER_OP("tir.metal.simd_shuffle_up")
+    .set_num_inputs(2)
+    .add_argument("var", "Expr", "The variable to sync.")
+    .add_argument("delta", "Expr", "The source lane id offset to be added.")
+    .set_attr<TGlobalSymbol>("TGlobalSymbol", "simd_shuffle_up")
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TVM_REGISTER_OP("tir.metal.simd_shuffle_down")
+    .set_num_inputs(2)
+    .add_argument("var", "Expr", "The variable to sync.")
+    .add_argument("delta", "Expr", "The source lane id offset to be subtracted.")
+    .set_attr<TGlobalSymbol>("TGlobalSymbol", "simd_shuffle_down")
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
 }  // namespace intrin
 }  // namespace codegen
 }  // namespace tvm
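To make the argument mapping concrete: the TIR warp-shuffle builtins carry five operands (mask, value, lane or delta, width, warp size), while the Metal builtins take only two, so `DispatchMetalShuffle` keeps just `call->args[1]` and `call->args[2]`. Sketched with hypothetical values:

// TIR call (5 operands):  tvm_warp_shuffle_down(mask, v, 16, 32, 32)
// lowers on Metal to the two-operand builtin; the mask, width, and
// warp size are dropped because the simdgroup covers them implicitly:
float t = simd_shuffle_down(v, 16);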

src/tir/transforms/lower_thread_allreduce.cc

Lines changed: 24 additions & 11 deletions
@@ -476,12 +476,13 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     // The mask for this reducer, as this reducer may sit inside
     // a divergent control flow. Here it uses a variable to cache the current
     // active channels.
-    Buffer mask_buffer = decl_buffer(shape, mask->dtype, "mask", "local");
-    {
-      seq->emplace_back(BufferStore(mask_buffer, mask, zero_indices));
+    Optional<Buffer> mask_buffer;
+    if (need_warp_shuffle_mask_) {
+      mask_buffer = decl_buffer(shape, mask->dtype, "mask", "local");
+      seq->emplace_back(BufferStore(mask_buffer.value(), mask, zero_indices));
       // Push the buffer description. Later this will have an
       // allocation built for it.
-      local_bufs.push_back(mask_buffer);
+      local_bufs.push_back(mask_buffer.value());
     }
 
@@ -698,9 +699,15 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   }
 
   // Emit warp shuffle calls.
-  PrimExpr WarpShuffle(const Op& op, Buffer mask_buffer, PrimExpr val, PrimExpr delta_or_lane) {
+  PrimExpr WarpShuffle(const Op& op, Optional<Buffer> mask_buffer, PrimExpr val,
+                       PrimExpr delta_or_lane) {
     Array<PrimExpr> indices = {0};
-    PrimExpr mask = BufferLoad(mask_buffer, indices);
+    PrimExpr mask;
+    if (mask_buffer.defined()) {
+      mask = BufferLoad(mask_buffer.value(), indices);
+    } else {
+      mask = IntImm(DataType::Int(32), 0);
+    }
     PrimExpr width = IntImm(DataType::Int(32), warp_size_);
     Array<PrimExpr> args{mask, val, delta_or_lane, width, width};
     return Call(val.dtype(), op, args);
@@ -709,11 +716,15 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   // Check if we can use warp level reduction.
   //
   // Note: The ROCm backend will only have warp reductions for now.
-  // Also, the warp/wavefront size differs (64 on rocm, 32 on cuda).
+  // Also, the warp/wavefront size differs (64 on rocm, 32 on cuda and metal).
   bool IsWarpReduction(const std::vector<DataType>& types, int group_extent, int reduce_extent,
-                       int contiguous_reduce_extent) const {
-    // Only cuda target supports warp reductions.
-    if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm")) return false;
+                       int contiguous_reduce_extent) {
+    if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm") &&
+        (target_->kind->name != "metal")) {
+      return false;
+    }
+
+    need_warp_shuffle_mask_ = target_->kind->name != "metal";
 
     // rocm only supports 32 bit operands for shuffling at the moment
     if ((target_->kind->name == "rocm") &&
@@ -745,7 +756,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     // whether reduce_extent and group_extent are valid for warp reduction.
     if (target_->kind->name == "rocm") {
       return reduce_extent == warp_size_;
-    } else {  // target_->kind->name == "cuda"
+    } else {
       if (reduce_extent == 1) {
         return false;  // no need to warp reduce
       } else {
@@ -769,6 +780,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   int warp_size_{1};
   // The maximum number of threads of the device. "-1" denotes unknown.
   int max_num_threads_{-1};
+  // A boolean indicating if the target needs a warp shuffle mask.
+  bool need_warp_shuffle_mask_;
 
   // surrounding scope of thread extent.
   std::vector<const AttrStmtNode*> thread_extents_;
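The per-target decision the pass now records can be summarized with a small standalone sketch; `NeedWarpShuffleMask` below is a hypothetical helper mirroring the flag set in `IsWarpReduction`, not TVM code:

#include <cassert>
#include <string>

// CUDA/ROCm shuffle intrinsics take an active-lane mask, so the pass
// caches one in a local buffer; Metal's simd_shuffle* builtins have no
// mask operand, so a dummy zero fills the TIR slot and the Metal
// intrinsic dispatch later drops it entirely.
static bool NeedWarpShuffleMask(const std::string& target_kind) {
  return target_kind != "metal";
}

int main() {
  assert(NeedWarpShuffleMask("cuda"));
  assert(NeedWarpShuffleMask("rocm"));
  assert(!NeedWarpShuffleMask("metal"));
  return 0;
}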

tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py

Lines changed: 103 additions & 0 deletions
@@ -702,5 +702,108 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")):
         B_1[threadIdx_y] = red_result_1[threadIdx_y]
 
 
+class TestMetalNoMask(BaseCompare):
+    @T.prim_func
+    def before(A: T.Buffer((1, 1, 2, 128), "float32"), B: T.Buffer((1, 1, 2), "float32")):
+        T.func_attr(
+            {
+                "target": T.target(
+                    {
+                        "kind": "metal",
+                        "max_threads_per_block": 1024,
+                        "thread_warp_size": 32,
+                        "host": "llvm",
+                    }
+                ),
+            }
+        )
+        blockIdx_x = T.launch_thread("blockIdx.x", 1)
+        cross_thread_B = T.allocate([1], "float32", "local")
+        threadIdx_z = T.launch_thread("threadIdx.z", 1)
+        threadIdx_y = T.launch_thread("threadIdx.y", 2)
+        threadIdx_x = T.launch_thread("threadIdx.x", 128)
+        cross_thread_B_1 = T.Buffer((1,), data=cross_thread_B, scope="local")
+        with T.attr(
+            T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+            "reduce_scope",
+            T.reinterpret("handle", T.uint64(0)),
+        ):
+            A_1 = T.Buffer((256,), data=A.data)
+            T.tvm_thread_allreduce(
+                T.uint32(1),
+                A_1[threadIdx_y * 128 + threadIdx_x],
+                T.bool(True),
+                cross_thread_B_1[0],
+                threadIdx_x,
+            )
+        if threadIdx_x == 0:
+            B_1 = T.Buffer((2,), data=B.data)
+            B_1[threadIdx_y] = cross_thread_B_1[0]
+
+    @T.prim_func
+    def expected(A: T.Buffer((1, 1, 2, 128), "float32"), B: T.Buffer((1, 1, 2), "float32")):
+        T.func_attr(
+            {
+                "target": T.target(
+                    {
+                        "kind": "metal",
+                        "max_threads_per_block": 1024,
+                        "thread_warp_size": 32,
+                        "host": "llvm",
+                    }
+                ),
+            }
+        )
+        blockIdx_x = T.launch_thread("blockIdx.x", 1)
+        red_result = T.allocate([2], "float32", "shared")
+        T.attr(red_result, "volatile_scope", 1)
+        threadIdx_z = T.launch_thread("threadIdx.z", 1)
+        threadIdx_y = T.launch_thread("threadIdx.y", 2)
+        threadIdx_x = T.launch_thread("threadIdx.x", 128)
+        red_result_1 = T.Buffer((2,), data=red_result, scope="shared")
+        with T.attr(
+            T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+            "reduce_scope",
+            T.reinterpret("handle", T.uint64(0)),
+        ):
+            red_buf0 = T.allocate([1], "float32", "local")
+            t0 = T.allocate([1], "float32", "local")
+            red_buf0_1 = T.allocate([1], "float32", "local")
+            t0_1 = T.allocate([1], "float32", "local")
+            red_buf_staging = T.allocate([8], "float32", "shared")
+            red_buf0_2 = T.Buffer((1,), data=red_buf0_1, scope="local")
+            A_1 = T.Buffer((256,), data=A.data)
+            red_buf0_2[0] = A_1[threadIdx_y * 128 + threadIdx_x]
+            t0_2 = T.Buffer((1,), data=t0_1, scope="local")
+            t0_2[0] = T.tvm_warp_shuffle_down(0, red_buf0_2[0], 16, 32, 32)
+            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
+            t0_2[0] = T.tvm_warp_shuffle_down(0, red_buf0_2[0], 8, 32, 32)
+            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
+            t0_2[0] = T.tvm_warp_shuffle_down(0, red_buf0_2[0], 4, 32, 32)
+            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
+            t0_2[0] = T.tvm_warp_shuffle_down(0, red_buf0_2[0], 2, 32, 32)
+            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
+            t0_2[0] = T.tvm_warp_shuffle_down(0, red_buf0_2[0], 1, 32, 32)
+            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
+            red_buf_staging_1 = T.Buffer((8,), data=red_buf_staging, scope="shared")
+            if threadIdx_x % 32 == 0:
+                red_buf_staging_1[threadIdx_y * 4 + threadIdx_x // 32] = red_buf0_2[0]
+            T.tvm_storage_sync("shared")
+            red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
+            if threadIdx_x < 4:
+                red_buf0_3[0] = red_buf_staging_1[threadIdx_y * 4 + threadIdx_x]
+                t0_3 = T.Buffer((1,), data=t0, scope="local")
+                t0_3[0] = T.tvm_warp_shuffle_down(0, red_buf0_3[0], 2, 32, 32)
+                red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
+                t0_3[0] = T.tvm_warp_shuffle_down(0, red_buf0_3[0], 1, 32, 32)
+                red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
+            if threadIdx_x == 0:
+                red_result_1[threadIdx_y] = red_buf0_3[0]
+        T.tvm_storage_sync("shared")
+        if threadIdx_x == 0:
+            B_1 = T.Buffer((2,), data=B.data)
+            B_1[threadIdx_y] = red_result_1[threadIdx_y]
+
+
 if __name__ == "__main__":
     tvm.testing.main()
