@@ -645,8 +645,12 @@ std::string PrintCpAsyncAssembly(const std::string& shared_ptr,
645645 : "l"((void *)({smem_addr}))
646646 );
647647 __asm__ __volatile__(
648- "cp.async.{cg_or_ca}.shared.global [%0], [%1], %2;"
649- :: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes})
648+ #if TVM_ENABLE_L2_PREFETCH
649+ "cp.async.{cg_or_ca}.shared.global.L2::128B [%0], [%1], %2;"
650+ #else
651+ "cp.async.{cg_or_ca}.shared.global [%0], [%1], %2;"
652+ #endif
653+ :: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes})
650654 );
651655 }
652656)" ;
@@ -665,26 +669,56 @@ std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr,
665669 const std::string& global_elem_offset,
666670 const std::string& bytes,
667671 const std::string& predicate_value) {
672+ CHECK (bytes == " 16" || bytes == " 12" || bytes == " 8" || bytes == " 4" || bytes == " 2" ||
673+ bytes == " 1" )
674+ << " Only support 16, 12, 8, 4, 2, 1 bytes for predicated cp.async" ;
668675 std::string predicated_asm_code = R"(
669676 {
670677 unsigned int addr;
671678 __asm__ __volatile__(
672- "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n "
679+ "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }"
673680 : "=r"(addr)
674681 : "l"((void *)({smem_addr}))
675682 );
676- int src_bytes = {pred_guard} ? {bytes} : 0 ;
683+ int pred_guard = (int) {pred_guard};
677684 __asm__ __volatile__(
678- "cp.async.{cg_or_ca}.shared.global [%0], [%1], %2, %3;"
679- :: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}), "r"(src_bytes)
685+ "{ .reg .pred p;"
686+ " setp.ne.b32 p, %0, 0;"
687+ #if TVM_ENABLE_L2_PREFETCH
688+ " @p cp.async.{cg_or_ca}.shared.global.L2::128B [%1], [%2], %3;"
689+ #else
690+ " @p cp.async.{cg_or_ca}.shared.global [%1], [%2], %3;"
691+ #endif
692+ " @!p {store_shared};}"
693+ :: "r"(pred_guard), "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}), {nopreg}
680694 );
681695 }
682696)" ;
697+ auto [store_shared, nopreg] = [](const std::string& bytes) {
698+ if (bytes == " 16" )
699+ return std::make_tuple (" st.shared.v4.u32 [%1], {%4, %5, %6, %7}" ,
700+ " \" r\" (0), \" r\" (0), \" r\" (0),\" r\" (0)" );
701+ else if (bytes == " 12" )
702+ return std::make_tuple (" st.shared.v3.u32 [%1], {%4, %5, %6}" , " \" r\" (0), \" r\" (0), \" r\" (0)" );
703+ else if (bytes == " 8" )
704+ return std::make_tuple (" st.shared.v2.u32 [%1], {%4, %5}" , " \" r\" (0), \" r\" (0)" );
705+ else if (bytes == " 4" )
706+ return std::make_tuple (" st.shared.u32 [%1], {%4}" , " \" r\" (0)" );
707+ else if (bytes == " 2" )
708+ return std::make_tuple (" st.shared.u16 [%1], {%4}" , " \" r\" (0)" );
709+ else if (bytes == " 1" )
710+ return std::make_tuple (" st.shared.u8 [%1], {%4}" , " \" r\" (0)" );
711+ else
712+ return std::make_tuple (" " ," " );
713+ }(bytes);
714+
683715 Replacer replacer;
684716 replacer.register_rule (" {smem_addr}" , shared_ptr + " + " + shared_elem_offset);
685717 replacer.register_rule (" {global_ptr}" , global_ptr + " + " + global_elem_offset);
686718 replacer.register_rule (" {bytes}" , bytes);
687719 replacer.register_rule (" {cg_or_ca}" , bytes == " 16" ? " cg" : " ca" );
720+ replacer.register_rule (" {store_shared}" , store_shared);
721+ replacer.register_rule (" {nopreg}" , nopreg);
688722 replacer.register_rule (" {pred_guard}" , predicate_value);
689723 predicated_asm_code = replacer.rewrite (predicated_asm_code);
690724 return predicated_asm_code;
0 commit comments