Skip to content

Commit 68be158

Browse files
[ROCm] Some fixes of ROCm codegen (#16404)
- Handle `tvm_thread_invariant` as a no-op.
- `llvm.amdgcn.ds.bpermute` requires an i32 as its input, but it can handle all 32-bit types.
- OCML intrinsics lead to incorrect codegen when used with vectorization; remove them and use LLVM intrinsics instead.
1 parent a7dd32c commit 68be158

File tree

4 files changed

+108
-36
lines changed

4 files changed

+108
-36
lines changed

src/target/llvm/codegen_llvm.cc

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1476,6 +1476,8 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
14761476
} else if (op->op.same_as(builtin::assume())) {
14771477
llvm::Value* cond = MakeValue(op->args[0]);
14781478
return builder_->CreateAssumption(cond);
1479+
} else if (op->op.same_as(builtin::tvm_thread_invariant())) {
1480+
return MakeValue(op->args[0]);
14791481
} else {
14801482
LOG(FATAL) << "unknown intrinsic " << op->op;
14811483
}

src/target/llvm/intrin_rule_rocm.cc

Lines changed: 52 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -89,8 +89,14 @@ inline PrimExpr DispatchShuffle(const PrimExpr& e) {
8989
index = self + delta;
9090
index = Select((self & (width - 1)) + delta >= width, self, index);
9191
}
92+
// reinterpret var as int32
93+
bool is_int32 = var.dtype().is_int() && var.dtype().bits() == 32;
94+
PrimExpr source = is_int32 ? var : reinterpret(DataType::Int(32), var);
9295
PrimExpr res = Call(DataType::Int(32), builtin::call_pure_extern(),
93-
{StringImm("llvm.amdgcn.ds.bpermute"), index << 2, var});
96+
{StringImm("llvm.amdgcn.ds.bpermute"), index << 2, source});
97+
if (!is_int32) {
98+
res = reinterpret(var.dtype(), res);
99+
}
94100
return res;
95101
}
96102

@@ -114,73 +120,84 @@ TVM_REGISTER_OP("tir.tvm_warp_shuffle_down")
114120
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchShuffle);
115121

116122
TVM_REGISTER_OP("tir.floor")
117-
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
123+
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
124+
DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>);
118125

119126
TVM_REGISTER_OP("tir.ceil")
120-
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
127+
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
128+
DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>);
121129

122130
TVM_REGISTER_OP("tir.round")
123-
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
131+
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
132+
DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>);
124133

125134
TVM_REGISTER_OP("tir.nearbyint")
126-
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
135+
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
136+
DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>);
127137

128138
TVM_REGISTER_OP("tir.trunc")
129-
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
139+
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
140+
DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>);
130141

131142
TVM_REGISTER_OP("tir.fabs")
132-
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
143+
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
144+
DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>);
133145

134-
TVM_REGISTER_OP("tir.exp").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
135-
DispatchPureExternOCML);
146+
TVM_REGISTER_OP("tir.exp").set_attr<FLowerIntrinsic>(
147+
"rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>);
136148

137149
TVM_REGISTER_OP("tir.exp2")
138-
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
150+
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
151+
DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>);
139152

140-
TVM_REGISTER_OP("tir.exp10")
141-
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
153+
// TVM_REGISTER_OP("tir.exp10")
154+
// .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
155+
// DispatchLLVMPureIntrin<::llvm::Intrinsic::exp10, 1>);
142156

143-
TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
144-
DispatchPureExternOCML);
157+
// TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
158+
// DispatchPureExternOCML);
145159

146160
TVM_REGISTER_OP("tir.fma").set_attr<FLowerIntrinsic>(
147161
"rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>);
148162

149-
TVM_REGISTER_OP("tir.log").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
150-
DispatchPureExternOCML);
163+
TVM_REGISTER_OP("tir.log").set_attr<FLowerIntrinsic>(
164+
"rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>);
151165

152166
TVM_REGISTER_OP("tir.log2")
153-
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
167+
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
168+
DispatchLLVMPureIntrin<::llvm::Intrinsic::log2, 1>);
154169

155170
TVM_REGISTER_OP("tir.log10")
156-
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
171+
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
172+
DispatchLLVMPureIntrin<::llvm::Intrinsic::log10, 1>);
157173

158174
TVM_REGISTER_OP("tir.sqrt")
159-
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
175+
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
176+
DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>);
160177

161-
TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
162-
DispatchPureExternOCML);
178+
TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>(
179+
"rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>);
163180

164-
TVM_REGISTER_OP("tir.tanh")
165-
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
181+
// TVM_REGISTER_OP("tir.tanh")
182+
// .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
166183

167-
TVM_REGISTER_OP("tir.tan").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
168-
DispatchPureExternOCML);
184+
// TVM_REGISTER_OP("tir.tan").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
185+
// DispatchPureExternOCML);
169186

170-
TVM_REGISTER_OP("tir.cos").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
171-
DispatchPureExternOCML);
187+
TVM_REGISTER_OP("tir.cos").set_attr<FLowerIntrinsic>(
188+
"rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>);
172189

173-
TVM_REGISTER_OP("tir.cosh")
174-
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
190+
// TVM_REGISTER_OP("tir.cosh")
191+
// .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
175192

176-
TVM_REGISTER_OP("tir.sin").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
177-
DispatchPureExternOCML);
193+
TVM_REGISTER_OP("tir.sin").set_attr<FLowerIntrinsic>(
194+
"rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>);
178195

179-
TVM_REGISTER_OP("tir.sinh")
180-
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
196+
// TVM_REGISTER_OP("tir.sinh")
197+
// .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
181198

182-
TVM_REGISTER_OP("tir.atan")
183-
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
199+
// TVM_REGISTER_OP("tir.atan")
200+
// .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
184201

185202
} // namespace llvm
186203
} // namespace codegen

src/tir/transforms/lower_thread_allreduce.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -730,7 +730,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
730730
// rocm only supports 32 bit operands for shuffling at the moment
731731
if ((target_->kind->name == "rocm") &&
732732
(std::any_of(types.begin(), types.end(), [](DataType ty) {
733-
if ((ty.is_vector()) || !ty.is_int()) return true;
733+
if (ty.is_vector()) return ty.bits() * ty.lanes() != 32;
734734
return ty.bits() != 32;
735735
}))) {
736736
return false;

tests/python/codegen/test_target_codegen_rocm.py

Lines changed: 53 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919
from tvm import te
2020
import numpy as np
2121
import unittest
22+
from tvm.script import tir as T
2223

2324
tx = te.thread_axis("threadIdx.x")
2425
ty = te.thread_axis("threadIdx.y")
@@ -130,9 +131,61 @@ def check_rocm(dtype, n, lanes):
130131
check_rocm("float16", 64, 2)
131132

132133

134+
@tvm.testing.requires_rocm
135+
def test_rocm_warp_shuffle():
136+
@T.prim_func
137+
def func(
138+
A_handle: T.handle,
139+
):
140+
A = T.match_buffer(A_handle, (32,), dtype="float32")
141+
142+
for bx in T.thread_binding(1, thread="blockIdx.x"):
143+
for tx in T.thread_binding(32, thread="threadIdx.x"):
144+
with T.block("test"):
145+
A_local = T.alloc_buffer((1,), "float32", scope="local")
146+
mask = T.alloc_buffer((1,), "uint32", scope="local")
147+
t0 = T.alloc_buffer((1,), "float32", scope="local")
148+
149+
A_local[0] = A[tx]
150+
A_local[0] = T.tvm_warp_shuffle(mask[0], A_local[0], 0, 32, 32)
151+
A[tx] = A_local[0]
152+
153+
mod = tvm.build(func, target="rocm")
154+
dev = tvm.rocm(0)
155+
a = tvm.nd.array(np.random.uniform(size=(32,)).astype("float32"), dev)
156+
mod(a)
157+
tvm.testing.assert_allclose(a.numpy(), np.ones((32,)) * a.numpy()[0])
158+
159+
160+
@tvm.testing.requires_rocm
161+
def test_rocm_vectorized_exp():
162+
@T.prim_func
163+
def func(
164+
A_handle: T.handle,
165+
B_handle: T.handle,
166+
):
167+
A = T.match_buffer(A_handle, (4,), dtype="float32")
168+
B = T.match_buffer(B_handle, (4,), dtype="float32")
169+
170+
for bx in T.thread_binding(1, thread="blockIdx.x"):
171+
for tx in T.thread_binding(1, thread="threadIdx.x"):
172+
with T.block("test"):
173+
for i in T.vectorized(0, 4):
174+
B[i] = T.exp2(A[i])
175+
176+
mod = tvm.build(func, target="rocm")
177+
dev = tvm.rocm(0)
178+
a = tvm.nd.array(np.ones((4,)).astype("float32"), dev)
179+
b = tvm.nd.array(np.zeros((4,)).astype("float32"), dev)
180+
mod(a, b)
181+
tvm.testing.assert_allclose(b.numpy(), np.exp2(a.numpy()))
182+
183+
133184
if __name__ == "__main__":
134185
test_rocm_cross_thread_reduction()
135186
test_rocm_inf_nan()
136187
test_rocm_reduction_binding()
137188
test_rocm_copy()
138189
test_rocm_vectorize_add()
190+
test_rocm_warp_shuffle()
191+
test_rocm_vectorized_exp()

0 commit comments

Comments
 (0)