@@ -37,6 +37,11 @@ namespace ptx {
3737
3838/* !
3939 * \brief PTX data type.
40+ * \note
41+ * PTX fundamental data types:
42+ * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#fundamental-types
43+ * PTX matrix data types:
44+ * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-data-types
4045 */
4146enum class DataType : int {
4247 kInt4 = 0 ,
@@ -173,6 +178,11 @@ struct MMAConfig {
173178 }
174179};
175180
181+ /* !
182+ * \brief Valid MMA configurations
183+ * \note Reference:
184+ * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-shape
185+ */
176186const MMAConfig valid_mma_configs[] = {
177187 MMAConfig (8 , 8 , 4 , DataType::kFloat64 , false , false ),
178188 MMAConfig (8 , 8 , 4 , DataType::kFloat16 , false , false ),
@@ -219,6 +229,8 @@ const MMAConfig valid_mma_configs[] = {
219229 * \param dtype_a The data type of multiplicand a.
220230 * \param dtype_b The data type of multiplicand b.
221231 * \param dtype_c The data type of accumulator c.
232+ * \note Reference:
233+ * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-data-types
222234 */
223235void CheckMMADTypeCompatible (DataType dtype_a, DataType dtype_b, DataType dtype_c) {
224236 std::string ab_not_match_err_str = " The multiplicands' data type " + DTypeToString (dtype_a) +
@@ -296,7 +308,7 @@ void CheckMMAConfigValidity(int m, int n, int k, LayoutType layout_a, LayoutType
296308 << " Unrecognized 1-bit operation " << bit_op << " , can only be xor/and." ;
297309 bool use_bit_op = !bit_op.empty ();
298310 if (use_bit_op) {
299- CHECK (dtype_a == DataType::kBit1 ) << " Bit operator is only compatible with 1bit multiplicand." ;
311+ CHECK (dtype_a == DataType::kBit1 ) << " Bit operator is only compatible with 1-bit multiplicand." ;
300312 }
301313 CheckMMADTypeCompatible (dtype_a, dtype_b, dtype_c);
302314 if (saturate) {
@@ -328,14 +340,14 @@ void CheckMMAConfigValidity(int m, int n, int k, LayoutType layout_a, LayoutType
328340 */
329341class FragAttrs {
330342 public:
331- explicit FragAttrs (char reg_type, uint32_t size, std::string ptr_sig )
332- : reg_type(reg_type), size(size), ptr_sig(ptr_sig ) {}
 // Constructor: copies each argument into the public member of the same name;
 // no validation is performed on any of the three values.
343+ explicit FragAttrs (char reg_type, uint32_t size, std::string ptr_type )
344+ : reg_type(reg_type), size(size), ptr_type(ptr_type ) {}
333345 /* ! \brief PTX register type */
 // GetMMAOperands streams this single character between escaped double quotes
 // in front of every operand expression (with '=' prepended for the D outputs),
 // i.e. it is emitted as the per-operand constraint letter of the generated asm.
334346 char reg_type;
335347 /* ! \brief Fragment size */
 // NOTE(review): the unit is not visible in this hunk (presumably bits per
 // register-sized fragment element) -- confirm against the code that builds FragAttrs.
336348 uint32_t size;
337- /* ! \brief Fragment pointer signature */
338- std::string ptr_sig ;
349+ /* ! \brief Fragment pointer type */
 // GetMMAOperands writes this string directly before the parenthesised operand
 // name (A, B, C or D) and then indexes the result with [i], so it is expected
 // to be text that is valid in that position of the generated source
 // (presumably a pointer cast such as a parenthesised pointer type -- confirm).
350+ std::string ptr_type ;
351};
340352
341353/* !
@@ -466,14 +478,15 @@ inline std::tuple<std::string, std::string, std::string> GetMMAOperands(int m, i
466478 if (i != 0 ) {
467479 inputs << " , " ;
468480 }
469- inputs << " \" " << frag_attr_a.reg_type << " \" ((" << frag_attr_a.ptr_sig << " (A))[" << i << " ])" ;
481+ inputs << " \" " << frag_attr_a.reg_type << " \" ((" << frag_attr_a.ptr_type << " (A))[" << i
482+ << " ])" ;
470483 }
471484 for (int i = 0 ; i < num_operands_b; ++i) {
472- inputs << " , \" " << frag_attr_b.reg_type << " \" ((" << frag_attr_b.ptr_sig << " (B))[" << i
485+ inputs << " , \" " << frag_attr_b.reg_type << " \" ((" << frag_attr_b.ptr_type << " (B))[" << i
473486 << " ])" ;
474487 }
475488 for (int i = 0 ; i < num_operands_c; ++i) {
476- inputs << " , \" " << frag_attr_c.reg_type << " \" ((" << frag_attr_c.ptr_sig << " (C))[" << i
489+ inputs << " , \" " << frag_attr_c.reg_type << " \" ((" << frag_attr_c.ptr_type << " (C))[" << i
477490 << " ])" ;
478491 }
479492 // input of metadata for sparse mma.
@@ -486,7 +499,7 @@ inline std::tuple<std::string, std::string, std::string> GetMMAOperands(int m, i
486499 if (i != 0 ) {
487500 outputs << " ," ;
488501 }
489- outputs << " \" =" << frag_attr_c.reg_type << " \" ((" << frag_attr_c.ptr_sig << " (D))[" << i
502+ outputs << " \" =" << frag_attr_c.reg_type << " \" ((" << frag_attr_c.ptr_type << " (D))[" << i
490503 << " ])" ;
491504 }
492505 return std::make_tuple (templates.str (), inputs.str (), outputs.str ());
@@ -512,7 +525,7 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo
512525 std::string asm_code = R"(
513526 {
514527 __asm__ __volatile__(
515- "mma{sparse}.sync.aligned.{shape}.{alayout}.{blayout}{satinite }{dtype}{atype}{btype}{ctype}{1bit }"
528+ "mma{sparse}.sync.aligned.{shape}.{alayout}.{blayout}{saturate }{dtype}{atype}{btype}{ctype}{bitop }"
516529 "{templates};\n"
517530 : {outputs}
518531 : {inputs});
@@ -526,14 +539,14 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo
526539 Replacer replacer;
527540 replacer.register_rule (" {sparse}" , sparse ? " .sp" : " " );
528541 replacer.register_rule (" {shape}" , shape);
529- replacer.register_rule (" {satinite }" , saturate ? " .satfinite" : " " );
542+ replacer.register_rule (" {saturate }" , saturate ? " .satfinite" : " " );
530543 replacer.register_rule (" {alayout}" , A_layout);
531544 replacer.register_rule (" {blayout}" , B_layout);
532545 replacer.register_rule (" {atype}" , ptx::DTypeToString (dtype_a));
533546 replacer.register_rule (" {btype}" , ptx::DTypeToString (dtype_b));
534547 replacer.register_rule (" {ctype}" , ptx::DTypeToString (dtype_c));
535548 replacer.register_rule (" {dtype}" , ptx::DTypeToString (dtype_c));
536- replacer.register_rule (" {1bit }" , bit_op.empty () ? " " : " ." + bit_op + " .popc" );
549+ replacer.register_rule (" {bitop }" , bit_op.empty () ? " " : " ." + bit_op + " .popc" );
537550 replacer.register_rule (" {templates}" , templates_str);
538551 replacer.register_rule (" {outputs}" , outputs_str);
539552 replacer.register_rule (" {inputs}" , inputs_str);
0 commit comments