Skip to content

Commit 86a3c68

Browse files
yzh119 and pfk-beta
authored and committed
[PTX] Support mma.sp to use Sparse Tensor Cores and refactor mma codegen (apache#10339)
* init
* upd
* upd
* lint
* lint again
* upd
* add m16n8k32 testcase
* format
* use make_tuple instead of initializer list
* add metadata offset
* upd
* docstring and sanity
* add u8s8s32 back
* improvement
* compatible apache#9727
1 parent 060bdee commit 86a3c68

File tree

7 files changed

+934
-1314
lines changed

7 files changed

+934
-1314
lines changed

include/tvm/tir/builtin.h

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -596,6 +596,19 @@ TVM_DLL const Op& tvm_store_matrix_sync();
596596
*/
597597
TVM_DLL const Op& ptx_mma();
598598

599+
/*!
600+
* \brief tvm intrinsic for sparse tensor core ptx instructions.
601+
*
602+
* void ptx_mma_sp(StringImm shape, StringImm A_layout, StringImm B_layout,
603+
* StringImm A_dtype, StringImm B_dtype, StringImm C_dtype,
604+
* Var multiplicand_a, Expr a_index,
605+
* Var multiplicand_b, Expr b_index,
606+
* Var accumulator, Expr c_index,
607+
* Var metadata, Expr meta_index,
608+
* Var sparse_selector, bool saturate);
609+
*/
610+
TVM_DLL const Op& ptx_mma_sp();
611+
599612
// TODO(tvm-team) replace the usage of the vector operations by Shuffle.
600613
/*!
601614
* \brief Get the high level half of the vector

src/target/source/codegen_cuda.cc

Lines changed: 45 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -744,7 +744,8 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
744744
// arg 10: C accumulator
745745
// arg 11: C accumulator index
746746
// arg 12: saturate
747-
ICHECK_EQ(op->args.size(), 13U);
747+
// arg 13: (optional) 1-bit operator (xor or and)
748+
ICHECK(op->args.size() == 13U || op->args.size() == 14U);
748749
std::string shape = Downcast<StringImm>(op->args[0])->value;
749750
std::string A_layout = Downcast<StringImm>(op->args[1])->value;
750751
std::string B_layout = Downcast<StringImm>(op->args[2])->value;
@@ -757,11 +758,51 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
757758
std::string b_bias = this->PrintExpr(op->args[9]);
758759
std::string c_ref = this->PrintExpr(op->args[10]);
759760
std::string c_bias = this->PrintExpr(op->args[11]);
760-
bool saturate = (Downcast<IntImm>(op->args[12])->value != 0);
761-
std::string asm_code = PrintMMAAssembly(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype,
762-
a_ref, a_bias, b_ref, b_bias, c_ref, c_bias, saturate);
761+
bool saturate = Downcast<Bool>(op->args[12])->value;
762+
std::string bit_op = op->args.size() > 13 ? Downcast<StringImm>(op->args[13])->value : "";
763+
std::string asm_code =
764+
PrintMMAAssembly(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, b_ref,
765+
b_bias, c_ref, c_bias, "", "", "", bit_op, false, saturate);
763766

764767
this->stream << asm_code;
768+
} else if (op->op.same_as(builtin::ptx_mma_sp())) {
769+
// arg 0: shape: mXnXkX
770+
// arg 1: A layout: row/col
771+
// arg 2: B layout: row/col
772+
// arg 3: A precision: fp16, fp32, ...
773+
// arg 4: B precision: fp16, fp32, ...
774+
// arg 5: C precision: fp16, fp32, ...
775+
// arg 6: A multiplicand
776+
// arg 7: A multiplicand index
777+
// arg 8: B multiplicand
778+
// arg 9: B multiplicand index
779+
// arg 10: C accumulator
780+
// arg 11: C accumulator index
781+
// arg 12: metadata
782+
// arg 13: metadata index
783+
// arg 14: sparse_selector
784+
// arg 15: saturate
785+
ICHECK_EQ(op->args.size(), 16U);
786+
std::string shape = Downcast<StringImm>(op->args[0])->value;
787+
std::string A_layout = Downcast<StringImm>(op->args[1])->value;
788+
std::string B_layout = Downcast<StringImm>(op->args[2])->value;
789+
std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
790+
std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
791+
std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
792+
std::string a_ref = this->PrintExpr(op->args[6]);
793+
std::string a_offset = this->PrintExpr(op->args[7]);
794+
std::string b_ref = this->PrintExpr(op->args[8]);
795+
std::string b_offset = this->PrintExpr(op->args[9]);
796+
std::string c_ref = this->PrintExpr(op->args[10]);
797+
std::string c_offset = this->PrintExpr(op->args[11]);
798+
std::string metadata = this->PrintExpr(op->args[12]);
799+
std::string metadata_offset = this->PrintExpr(op->args[13]);
800+
std::string sparse_selector = this->PrintExpr(op->args[14]);
801+
bool saturate = Downcast<Bool>(op->args[15])->value;
802+
std::string asm_code = PrintMMAAssembly(
803+
shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset, b_ref, b_offset,
804+
c_ref, c_offset, metadata, metadata_offset, sparse_selector, "", true, saturate);
805+
this->stream << asm_code;
765806
} else {
766807
CodeGenC::VisitExpr_(op, os);
767808
}

0 commit comments

Comments
 (0)