Skip to content

Commit 08b366d

Browse files
committed
add efficient cuda support for vectorized if_then_else
1 parent b724c87 commit 08b366d

File tree

4 files changed

+265
-24
lines changed

4 files changed

+265
-24
lines changed

src/target/source/codegen_cuda.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,13 @@ std::string CodeGenCUDA::Finish() {
134134
decl_stream << "#include <mma.h>\n";
135135
}
136136

137+
decl_stream << "\n#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \\\n";
138+
decl_stream << " (__CUDACC_VER_MAJOR__ > 11))\n";
139+
decl_stream << "#define TVM_ENABLE_L2_PREFETCH 1\n";
140+
decl_stream << "#else\n";
141+
decl_stream << "#define TVM_ENABLE_L2_PREFETCH 0\n";
142+
decl_stream << "#endif\n";
143+
137144
decl_stream << "\n#ifdef _WIN32\n";
138145
decl_stream << " using uint = unsigned int;\n";
139146
decl_stream << " using uchar = unsigned char;\n";

src/target/source/ptx.cc

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -645,8 +645,12 @@ std::string PrintCpAsyncAssembly(const std::string& shared_ptr,
645645
: "l"((void *)({smem_addr}))
646646
);
647647
__asm__ __volatile__(
648-
"cp.async.{cg_or_ca}.shared.global [%0], [%1], %2;"
649-
:: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes})
648+
#if TVM_ENABLE_L2_PREFETCH
649+
"cp.async.{cg_or_ca}.shared.global.L2::128B [%0], [%1], %2;"
650+
#else
651+
"cp.async.{cg_or_ca}.shared.global [%0], [%1], %2;"
652+
#endif
653+
:: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes})
650654
);
651655
}
652656
)";
@@ -665,26 +669,56 @@ std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr,
665669
const std::string& global_elem_offset,
666670
const std::string& bytes,
667671
const std::string& predicate_value) {
672+
CHECK(bytes == "16" || bytes == "12" || bytes == "8" || bytes == "4" || bytes == "2" ||
673+
bytes == "1")
674+
<< "Only support 16, 12, 8, 4, 2, 1 bytes for predicated cp.async";
668675
std::string predicated_asm_code = R"(
669676
{
670677
unsigned int addr;
671678
__asm__ __volatile__(
672-
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
679+
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }"
673680
: "=r"(addr)
674681
: "l"((void *)({smem_addr}))
675682
);
676-
int src_bytes = {pred_guard} ? {bytes} : 0;
683+
int pred_guard = (int){pred_guard};
677684
__asm__ __volatile__(
678-
"cp.async.{cg_or_ca}.shared.global [%0], [%1], %2, %3;"
679-
:: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}), "r"(src_bytes)
685+
"{ .reg .pred p;"
686+
" setp.ne.b32 p, %0, 0;"
687+
#if TVM_ENABLE_L2_PREFETCH
688+
" @p cp.async.{cg_or_ca}.shared.global.L2::128B [%1], [%2], %3;"
689+
#else
690+
" @p cp.async.{cg_or_ca}.shared.global [%1], [%2], %3;"
691+
#endif
692+
" @!p {store_shared};}"
693+
:: "r"(pred_guard), "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}), {nopreg}
680694
);
681695
}
682696
)";
697+
auto [store_shared, nopreg] = [](const std::string& bytes) {
698+
if (bytes == "16")
699+
return std::make_tuple("st.shared.v4.u32 [%1], {%4, %5, %6, %7}",
700+
"\"r\"(0), \"r\"(0), \"r\"(0),\"r\"(0)");
701+
else if (bytes == "12")
702+
return std::make_tuple("st.shared.v3.u32 [%1], {%4, %5, %6}", "\"r\"(0), \"r\"(0), \"r\"(0)");
703+
else if (bytes == "8")
704+
return std::make_tuple("st.shared.v2.u32 [%1], {%4, %5}", "\"r\"(0), \"r\"(0)");
705+
else if (bytes == "4")
706+
return std::make_tuple("st.shared.u32 [%1], {%4}", "\"r\"(0)");
707+
else if (bytes == "2")
708+
return std::make_tuple("st.shared.u16 [%1], {%4}", "\"r\"(0)");
709+
else if (bytes == "1")
710+
return std::make_tuple("st.shared.u8 [%1], {%4}", "\"r\"(0)");
711+
else
712+
return std::make_tuple("","");
713+
}(bytes);
714+
683715
Replacer replacer;
684716
replacer.register_rule("{smem_addr}", shared_ptr + " + " + shared_elem_offset);
685717
replacer.register_rule("{global_ptr}", global_ptr + " + " + global_elem_offset);
686718
replacer.register_rule("{bytes}", bytes);
687719
replacer.register_rule("{cg_or_ca}", bytes == "16" ? "cg" : "ca");
720+
replacer.register_rule("{store_shared}", store_shared);
721+
replacer.register_rule("{nopreg}", nopreg);
688722
replacer.register_rule("{pred_guard}", predicate_value);
689723
predicated_asm_code = replacer.rewrite(predicated_asm_code);
690724
return predicated_asm_code;

src/tir/transforms/inject_ptx_async_copy.cc

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,41 @@ class PTXAsyncCopyInjector : public StmtMutator {
112112
}
113113
return PrimExpr();
114114
}();
115-
116115
if (src_offset.defined() && dst_offset.defined()) {
117116
return Evaluate(Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
118117
{store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)),
119118
load->buffer->data, src_offset, PrimExpr(bytes)}));
120119
}
120+
} else {
121+
// Only some vectorized indexing patterns are supported for now.
122+
auto src_offset = [=]() -> PrimExpr {
123+
if (load->indices[0]->IsInstance<RampNode>()) {
124+
return load->indices[0].as<RampNode>()->base;
125+
}
126+
return PrimExpr();
127+
}();
128+
129+
auto dst_offset = [=]() -> PrimExpr {
130+
if (store->indices[0].as<RampNode>()) {
131+
return store->indices[0].as<RampNode>()->base;
132+
} else if (store->indices[0].as<AddNode>()) {
133+
// The case where the dst buffer is a byte buffer generated by merging dynamic
134+
// shared memory.
135+
// A_shared.dyn[ramp(..., 1, 8) + x8(17408)] = A_global[ramp(..., 1, 8)]
136+
auto* add = store->indices[0].as<AddNode>();
137+
if (!add->a->IsInstance<RampNode>()) return PrimExpr();
138+
if (!add->b->IsInstance<BroadcastNode>()) return PrimExpr();
139+
return tir::Add(add->a.as<RampNode>()->base, add->b.as<BroadcastNode>()->value);
140+
}
141+
return PrimExpr();
142+
}();
143+
144+
if (src_offset.defined() && dst_offset.defined()) {
145+
return Evaluate(
146+
Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
147+
{store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)),
148+
load->buffer->data, src_offset, PrimExpr(bytes), predicate_value}));
149+
}
121150
}
122151
}
123152
}

0 commit comments

Comments
 (0)