Skip to content

Commit a7818f0

Browse files
junrushao, spectrometerHBH, LeshengJin
committed
[Bugfix][Unity] Recover MSVC/NVCC/ROCm/Vulkan
This PR upstreams a few commits that recover the unity branch from broken wheel packages. It includes the following changes: - Fix the MSVC build in `pipe.h`, where `DWORD` is not cast to the proper return type (#306); - Fix MSVC build warnings caused by the unrecognized "#pragma GCC" (#307); - Fix NVCC build warnings where it fails to infer that a "[[noreturn]]" function actually does not return (#308); - Fix the ROCm/Vulkan backends, which fail compilation for operators like group GEMM, paged attention, etc. (apache/tvm#16404, apache/tvm#16405) Co-authored-by: Bohan Hou <[email protected]> Co-authored-by: Lesheng Jin <[email protected]>
1 parent 042c44e commit a7818f0

File tree

10 files changed

+133
-31
lines changed

10 files changed

+133
-31
lines changed

include/tvm/runtime/logging.h

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,10 @@ class LogFatal {
353353
#pragma disagnostic push
354354
#pragma warning(disable : 4722)
355355
#endif
356-
[[noreturn]] ~LogFatal() TVM_THROW_EXCEPTION { GetEntry().Finalize(); }
356+
[[noreturn]] ~LogFatal() TVM_THROW_EXCEPTION {
357+
GetEntry().Finalize();
358+
throw;
359+
}
357360
#ifdef _MSC_VER
358361
#pragma disagnostic pop
359362
#endif
@@ -366,7 +369,7 @@ class LogFatal {
366369
this->file_ = file;
367370
this->lineno_ = lineno;
368371
}
369-
[[noreturn]] TVM_NO_INLINE dmlc::Error Finalize() {
372+
[[noreturn]] TVM_NO_INLINE dmlc::Error Finalize() TVM_THROW_EXCEPTION {
370373
InternalError error(file_, lineno_, stream_.str());
371374
#if DMLC_LOG_BEFORE_THROW
372375
std::cerr << error.what() << std::endl;
@@ -560,15 +563,26 @@ std::unique_ptr<std::string> LogCheckFormat(const X& x, const Y& y) {
560563
return LogCheck##name<int, int>(x, y); \
561564
}
562565

566+
#if defined(__GNUC__) || defined(__clang__) // GCC and Clang
563567
#pragma GCC diagnostic push
564568
#pragma GCC diagnostic ignored "-Wsign-compare"
569+
#elif defined(_MSC_VER) // MSVC
570+
#pragma warning(push)
571+
#pragma warning(disable : 4389) // '==' : signed/unsigned mismatch
572+
#endif
573+
565574
TVM_CHECK_FUNC(_LT, <)
566575
TVM_CHECK_FUNC(_GT, >)
567576
TVM_CHECK_FUNC(_LE, <=)
568577
TVM_CHECK_FUNC(_GE, >=)
569578
TVM_CHECK_FUNC(_EQ, ==)
570579
TVM_CHECK_FUNC(_NE, !=)
580+
581+
#if defined(__GNUC__) || defined(__clang__) // GCC and Clang
571582
#pragma GCC diagnostic pop
583+
#elif defined(_MSC_VER) // MSVC
584+
#pragma warning(pop)
585+
#endif
572586

573587
} // namespace detail
574588

src/support/pipe.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,11 @@ class Pipe : public dmlc::Stream {
7777
size_t Read(void* ptr, size_t size) final {
7878
if (size == 0) return 0;
7979
#ifdef _WIN32
80-
auto fread = [&]() {
80+
auto fread = [&]() -> ssize_t {
8181
DWORD nread;
82-
if (!ReadFile(handle_, static_cast<TCHAR*>(ptr), size, &nread, nullptr)) return -1;
83-
return nread;
82+
if (!ReadFile(handle_, static_cast<TCHAR*>(ptr), size, &nread, nullptr))
83+
return static_cast<ssize_t>(-1);
84+
return static_cast<ssize_t>(nread);
8485
};
8586
DWORD nread = static_cast<DWORD>(RetryCallOnEINTR(fread, GetLastErrorCode));
8687
ICHECK_EQ(static_cast<size_t>(nread), size) << "Read Error: " << GetLastError();
@@ -99,10 +100,11 @@ class Pipe : public dmlc::Stream {
99100
void Write(const void* ptr, size_t size) final {
100101
if (size == 0) return;
101102
#ifdef _WIN32
102-
auto fwrite = [&]() {
103+
auto fwrite = [&]() -> ssize_t {
103104
DWORD nwrite;
104-
if (!WriteFile(handle_, static_cast<const TCHAR*>(ptr), size, &nwrite, nullptr)) return -1;
105-
return nwrite;
105+
if (!WriteFile(handle_, static_cast<const TCHAR*>(ptr), size, &nwrite, nullptr))
106+
return static_cast<ssize_t>(-1);
107+
return static_cast<ssize_t>(nwrite);
106108
};
107109
DWORD nwrite = static_cast<DWORD>(RetryCallOnEINTR(fwrite, GetLastErrorCode));
108110
ICHECK_EQ(static_cast<size_t>(nwrite), size) << "Write Error: " << GetLastError();

src/target/llvm/codegen_llvm.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1476,6 +1476,8 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
14761476
} else if (op->op.same_as(builtin::assume())) {
14771477
llvm::Value* cond = MakeValue(op->args[0]);
14781478
return builder_->CreateAssumption(cond);
1479+
} else if (op->op.same_as(builtin::tvm_thread_invariant())) {
1480+
return MakeValue(op->args[0]);
14791481
} else {
14801482
LOG(FATAL) << "unknown intrinsic " << op->op;
14811483
}

src/target/llvm/intrin_rule_rocm.cc

Lines changed: 37 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,14 @@ inline PrimExpr DispatchShuffle(const PrimExpr& e) {
8989
index = self + delta;
9090
index = Select((self & (width - 1)) + delta >= width, self, index);
9191
}
92+
// reinterpret var as int32
93+
bool is_int32 = var.dtype().is_int() && var.dtype().bits() == 32;
94+
PrimExpr source = is_int32 ? var : reinterpret(DataType::Int(32), var);
9295
PrimExpr res = Call(DataType::Int(32), builtin::call_pure_extern(),
93-
{StringImm("llvm.amdgcn.ds.bpermute"), index << 2, var});
96+
{StringImm("llvm.amdgcn.ds.bpermute"), index << 2, source});
97+
if (!is_int32) {
98+
res = reinterpret(var.dtype(), res);
99+
}
94100
return res;
95101
}
96102

@@ -114,28 +120,35 @@ TVM_REGISTER_OP("tir.tvm_warp_shuffle_down")
114120
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchShuffle);
115121

116122
TVM_REGISTER_OP("tir.floor")
117-
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
123+
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
124+
DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>);
118125

119126
TVM_REGISTER_OP("tir.ceil")
120-
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
127+
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
128+
DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>);
121129

122130
TVM_REGISTER_OP("tir.round")
123-
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
131+
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
132+
DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>);
124133

125134
TVM_REGISTER_OP("tir.nearbyint")
126-
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
135+
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
136+
DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>);
127137

128138
TVM_REGISTER_OP("tir.trunc")
129-
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
139+
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
140+
DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>);
130141

131142
TVM_REGISTER_OP("tir.fabs")
132-
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
143+
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
144+
DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>);
133145

134-
TVM_REGISTER_OP("tir.exp").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
135-
DispatchPureExternOCML);
146+
TVM_REGISTER_OP("tir.exp").set_attr<FLowerIntrinsic>(
147+
"rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>);
136148

137149
TVM_REGISTER_OP("tir.exp2")
138-
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
150+
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
151+
DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>);
139152

140153
TVM_REGISTER_OP("tir.exp10")
141154
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
@@ -146,35 +159,38 @@ TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
146159
TVM_REGISTER_OP("tir.fma").set_attr<FLowerIntrinsic>(
147160
"rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>);
148161

149-
TVM_REGISTER_OP("tir.log").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
150-
DispatchPureExternOCML);
162+
TVM_REGISTER_OP("tir.log").set_attr<FLowerIntrinsic>(
163+
"rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>);
151164

152165
TVM_REGISTER_OP("tir.log2")
153-
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
166+
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
167+
DispatchLLVMPureIntrin<::llvm::Intrinsic::log2, 1>);
154168

155169
TVM_REGISTER_OP("tir.log10")
156-
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
170+
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
171+
DispatchLLVMPureIntrin<::llvm::Intrinsic::log10, 1>);
157172

158173
TVM_REGISTER_OP("tir.sqrt")
159-
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
174+
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
175+
DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>);
160176

161-
TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
162-
DispatchPureExternOCML);
177+
TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>(
178+
"rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>);
163179

164180
TVM_REGISTER_OP("tir.tanh")
165181
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
166182

167183
TVM_REGISTER_OP("tir.tan").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
168184
DispatchPureExternOCML);
169185

170-
TVM_REGISTER_OP("tir.cos").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
171-
DispatchPureExternOCML);
186+
TVM_REGISTER_OP("tir.cos").set_attr<FLowerIntrinsic>(
187+
"rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>);
172188

173189
TVM_REGISTER_OP("tir.cosh")
174190
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
175191

176-
TVM_REGISTER_OP("tir.sin").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
177-
DispatchPureExternOCML);
192+
TVM_REGISTER_OP("tir.sin").set_attr<FLowerIntrinsic>(
193+
"rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>);
178194

179195
TVM_REGISTER_OP("tir.sinh")
180196
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);

src/target/spirv/codegen_spirv.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,8 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) {
509509
spirv::SType ptr_type = builder_->GetPointerType(ele_stype, buffer_val.stype.storage_class);
510510
ICHECK(var_map_.count(buffer_node));
511511
return builder_->StructArrayAccess(ptr_type, var_map_[buffer_node], MakeValue(index));
512+
} else if (op->op.same_as(builtin::tvm_thread_invariant())) {
513+
return MakeValue(op->args[0]);
512514
} else {
513515
LOG(FATAL) << "Unresolved call " << op->op;
514516
}

src/target/spirv/intrin_rule_spirv.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ TVM_REGISTER_OP("tir.fabs")
8282
TVM_REGISTER_OP("tir.exp").set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
8383
DispatchGLSLPureIntrin<GLSLstd450Exp>);
8484

85+
TVM_REGISTER_OP("tir.exp2")
86+
.set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Exp2>);
87+
8588
TVM_REGISTER_OP("tir.sin").set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
8689
DispatchGLSLPureIntrin<GLSLstd450Sin>);
8790

src/tir/transforms/lower_thread_allreduce.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -730,7 +730,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
730730
// rocm only supports 32 bit operands for shuffling at the moment
731731
if ((target_->kind->name == "rocm") &&
732732
(std::any_of(types.begin(), types.end(), [](DataType ty) {
733-
if ((ty.is_vector()) || !ty.is_int()) return true;
733+
if (ty.is_vector()) return ty.bits() * ty.lanes() != 32;
734734
return ty.bits() != 32;
735735
}))) {
736736
return false;

src/tir/transforms/merge_shared_memory_allocations.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,11 @@ namespace transform {
662662
Pass MergeSharedMemoryAllocations() {
663663
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
664664
bool merge_static_smem = ctx->GetConfig<Bool>("tir.merge_static_smem", Bool(false)).value();
665+
// disable this pass for Vulkan
666+
auto target = Target::Current(true);
667+
if (target.defined() && target->kind->name == "vulkan") {
668+
return f;
669+
}
665670
auto* n = f.CopyOnWrite();
666671
n->body = MergeSharedMemoryAllocations(std::move(n->body), merge_static_smem);
667672
return f;

src/tir/transforms/storage_rewrite.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1705,8 +1705,13 @@ namespace transform {
17051705
Pass StorageRewrite() {
17061706
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
17071707
bool merge_static_smem = ctx->GetConfig<Bool>("tir.merge_static_smem", Bool(false)).value();
1708+
// disable merge_static_smem for Vulkan
1709+
auto target = Target::Current(true);
1710+
if (target.defined() && target->kind->name == "vulkan") {
1711+
merge_static_smem = false;
1712+
}
17081713
auto* n = f.CopyOnWrite();
1709-
n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true, !merge_static_smem);
1714+
n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true, merge_static_smem);
17101715
// Parameters may not be rewritten, but internal allocations may.
17111716
// Vectorization of AllocateConst is currently disabled, as it has
17121717
// indexing issues for types that include padding (e.g. int8x3

tests/python/codegen/test_target_codegen_rocm.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from tvm import te
2020
import numpy as np
2121
import unittest
22+
from tvm.script import tir as T
2223

2324
tx = te.thread_axis("threadIdx.x")
2425
ty = te.thread_axis("threadIdx.y")
@@ -130,9 +131,61 @@ def check_rocm(dtype, n, lanes):
130131
check_rocm("float16", 64, 2)
131132

132133

134+
@tvm.testing.requires_rocm
135+
def test_rocm_warp_shuffle():
136+
@T.prim_func
137+
def func(
138+
A_handle: T.handle,
139+
):
140+
A = T.match_buffer(A_handle, (32,), dtype="float32")
141+
142+
for bx in T.thread_binding(1, thread="blockIdx.x"):
143+
for tx in T.thread_binding(32, thread="threadIdx.x"):
144+
with T.block("test"):
145+
A_local = T.alloc_buffer((1,), "float32", scope="local")
146+
mask = T.alloc_buffer((1,), "uint32", scope="local")
147+
t0 = T.alloc_buffer((1,), "float32", scope="local")
148+
149+
A_local[0] = A[tx]
150+
A_local[0] = T.tvm_warp_shuffle(mask[0], A_local[0], 0, 32, 32)
151+
A[tx] = A_local[0]
152+
153+
mod = tvm.build(func, target="rocm")
154+
dev = tvm.rocm(0)
155+
a = tvm.nd.array(np.random.uniform(size=(32,)).astype("float32"), dev)
156+
mod(a)
157+
tvm.testing.assert_allclose(a.numpy(), np.ones((32,)) * a.numpy()[0])
158+
159+
160+
@tvm.testing.requires_rocm
161+
def test_rocm_vectorized_exp():
162+
@T.prim_func
163+
def func(
164+
A_handle: T.handle,
165+
B_handle: T.handle,
166+
):
167+
A = T.match_buffer(A_handle, (4,), dtype="float32")
168+
B = T.match_buffer(B_handle, (4,), dtype="float32")
169+
170+
for bx in T.thread_binding(1, thread="blockIdx.x"):
171+
for tx in T.thread_binding(1, thread="threadIdx.x"):
172+
with T.block("test"):
173+
for i in T.vectorized(0, 4):
174+
B[i] = T.exp2(A[i])
175+
176+
mod = tvm.build(func, target="rocm")
177+
dev = tvm.rocm(0)
178+
a = tvm.nd.array(np.ones((4,)).astype("float32"), dev)
179+
b = tvm.nd.array(np.zeros((4,)).astype("float32"), dev)
180+
mod(a, b)
181+
tvm.testing.assert_allclose(b.numpy(), np.exp2(a.numpy()))
182+
183+
133184
if __name__ == "__main__":
134185
test_rocm_cross_thread_reduction()
135186
test_rocm_inf_nan()
136187
test_rocm_reduction_binding()
137188
test_rocm_copy()
138189
test_rocm_vectorize_add()
190+
test_rocm_warp_shuffle()
191+
test_rocm_vectorized_exp()

0 commit comments

Comments
 (0)