@@ -216,35 +216,61 @@ const MMAConfig valid_mma_configs[] = {
216216/* !
217217 * \brief Check whether the multiplicand data type and accumulator data type is valid for MMA
218218 * computation.
219- * \param mul The multiplicand data type.
220- * \param acc The accumulator data type.
219+ * \param dtype_a The data type of multiplicand a.
220+ * \param dtype_b The data type of multiplicand b.
221+ * \param dtype_c The data type of accumulator c.
221222 */
222- void CheckMMADTypeCompatible (DataType mul, DataType acc) {
223- switch (mul) {
223+ void CheckMMADTypeCompatible (DataType dtype_a, DataType dtype_b, DataType dtype_c) {
224+ std::string ab_not_match_err_str = " The multiplicands' data type " + DTypeToString (dtype_a) +
225+ DTypeToString (dtype_b) + " do not match." ;
226+ // check a and b
227+ switch (dtype_a) {
224228 case DataType::kBit1 :
229+ case DataType::kFloat16 :
230+ case DataType::kBFloat16 :
231+ case DataType::kTensorFloat32 :
232+ case DataType::kFloat64 :
233+ CHECK (dtype_a == dtype_b) << ab_not_match_err_str;
234+ break ;
225235 case DataType::kInt4 :
226236 case DataType::kUInt4 :
237+ CHECK (dtype_b == DataType::kInt4 || dtype_b == DataType::kUInt4 ) << ab_not_match_err_str;
238+ break ;
227239 case DataType::kInt8 :
228240 case DataType::kUInt8 :
229- CHECK (acc == DataType::kInt32 ) << " For multiplicand data type " << DTypeToString (mul)
230- << " , accumulator data type should be s32." ;
241+ CHECK (dtype_b == DataType::kInt8 || dtype_b == DataType::kUInt8 ) << ab_not_match_err_str;
242+ break ;
243+ default :
244+ CHECK (false ) << " Invalid multiplicand data types: " << DTypeToString (dtype_a)
245+ << DTypeToString (dtype_b);
246+ }
247+ // check a,b and c
248+ switch (dtype_a) {
249+ case DataType::kBit1 :
250+ case DataType::kInt4 :
251+ case DataType::kUInt4 :
252+ case DataType::kInt8 :
253+ case DataType::kUInt8 :
254+ CHECK (dtype_c == DataType::kInt32 )
255+ << " For multiplicand data type " << DTypeToString (dtype_a) << DTypeToString (dtype_b)
256+ << " , accumulator data type should be s32." ;
231257 break ;
232258 case DataType::kFloat16 :
233- CHECK (acc == DataType::kFloat16 || acc == DataType::kFloat32 )
259+ CHECK (dtype_c == DataType::kFloat16 || dtype_c == DataType::kFloat32 )
234260 << " For multiplicand data type f16, accumulator data type should be f16/f32." ;
235261 break ;
236262 case DataType::kBFloat16 :
237263 case DataType::kTensorFloat32 :
238- CHECK (acc == DataType::kFloat32 )
239- << " For multiplicand data type bf16/tf32, accumulator data type can only be f32" ;
264+ CHECK (dtype_c == DataType::kFloat32 )
265+ << " For multiplicand data type bf16/tf32, accumulator data type can only be f32. " ;
240266 break ;
241267 case DataType::kFloat64 :
242- CHECK (acc == DataType::kFloat64 )
243- << " For multiplicand data type f64, accumulator data type can only be f64" ;
268+ CHECK (dtype_c == DataType::kFloat64 )
269+ << " For multiplicand data type f64, accumulator data type can only be f64. " ;
244270 break ;
245271 default :
246- CHECK (false ) << " Invalid multiplicand/accumulator data type pair : " << DTypeToString (mul )
247- << " , " << DTypeToString (acc ) << " ." ;
272+ CHECK (false ) << " Invalid multiplicand/accumulator data types : " << DTypeToString (dtype_a )
273+ << DTypeToString (dtype_b) << DTypeToString (dtype_c ) << " ." ;
248274 }
249275}
250276
@@ -272,10 +298,7 @@ void CheckMMAConfigValidity(int m, int n, int k, LayoutType layout_a, LayoutType
272298 if (use_bit_op) {
273299 CHECK (dtype_a == DataType::kBit1 ) << " Bit operator is only compatible with 1bit multiplicand." ;
274300 }
275- CHECK (dtype_a == dtype_b) << " The multiplicand data type must be equal, found "
276- << DTypeToString (dtype_a) << " and " << ptx::DTypeToString (dtype_b)
277- << " ." ;
278- CheckMMADTypeCompatible (dtype_a, dtype_c);
301+ CheckMMADTypeCompatible (dtype_a, dtype_b, dtype_c);
279302 if (saturate) {
280303 CHECK (dtype_a == DataType::kInt4 || dtype_a == DataType::kUInt4 || dtype_a == DataType::kInt8 ||
281304 dtype_a == DataType::kUInt8 )
@@ -389,23 +412,26 @@ inline uint32_t GetNumMMAComputations(int m, int n, int k, ptx::DataType dtype)
389412 * \param m The M in mMnNkK of MMA instructions.
390413 * \param n The N in mMnNkK of MMA instructions.
391414 * \param k The K in mMnNkK of MMA instructions.
392- * \param dtype_mul The data type of multiplicand.
393- * \param dtype_acc The data type of accumulator.
415+ * \param dtype_a The data type of multiplicand a.
416+ * \param dtype_b The data type of multiplicand b.
417+ * \param dtype_c The data type of accumulator c.
394418 * \param sparse Whether it's Sparse MMA or not.
395419 */
396420inline std::tuple<std::string, std::string, std::string> GetMMAOperands (int m, int n, int k,
397- ptx::DataType dtype_mul,
398- ptx::DataType dtype_acc,
421+ ptx::DataType dtype_a,
422+ ptx::DataType dtype_b,
423+ ptx::DataType dtype_c,
399424 bool sparse) {
400425 std::stringstream templates, inputs, outputs;
401- const ptx::FragAttrs frag_attr_mul = ptx::GetFragAttrs (dtype_mul),
402- frag_attr_acc = ptx::GetFragAttrs (dtype_acc);
426+ const ptx::FragAttrs frag_attr_a = ptx::GetFragAttrs (dtype_a),
427+ frag_attr_b = ptx::GetFragAttrs (dtype_b),
428+ frag_attr_c = ptx::GetFragAttrs (dtype_c);
403429 constexpr uint32_t warp_size = 32 ;
404- const uint32_t threads = warp_size / GetNumMMAComputations (m, n, k, dtype_mul );
405- const int num_operands_a = (m * k) * ptx::DTypeBits (dtype_mul) / frag_attr_acc. size / threads /
406- (sparse ? 2 : 1 ),
407- num_operands_b = (k * n) * ptx::DTypeBits (dtype_mul ) / frag_attr_mul .size / threads,
408- num_operands_c = (m * n) * ptx::DTypeBits (dtype_acc ) / frag_attr_acc .size / threads;
430+ const uint32_t threads = warp_size / GetNumMMAComputations (m, n, k, dtype_a );
431+ const int num_operands_a =
432+ (m * k) * ptx::DTypeBits (dtype_a) / frag_attr_a. size / threads / (sparse ? 2 : 1 ),
433+ num_operands_b = (k * n) * ptx::DTypeBits (dtype_b ) / frag_attr_b .size / threads,
434+ num_operands_c = (m * n) * ptx::DTypeBits (dtype_c ) / frag_attr_c .size / threads;
409435
410436 // generate templates;
411437 int arg_counter = 0 ;
@@ -440,15 +466,14 @@ inline std::tuple<std::string, std::string, std::string> GetMMAOperands(int m, i
440466 if (i != 0 ) {
441467 inputs << " , " ;
442468 }
443- inputs << " \" " << frag_attr_mul.reg_type << " \" ((" << frag_attr_mul.ptr_sig << " (A))[" << i
444- << " ])" ;
469+ inputs << " \" " << frag_attr_a.reg_type << " \" ((" << frag_attr_a.ptr_sig << " (A))[" << i << " ])" ;
445470 }
446471 for (int i = 0 ; i < num_operands_b; ++i) {
447- inputs << " , \" " << frag_attr_mul .reg_type << " \" ((" << frag_attr_mul .ptr_sig << " (B))[" << i
472+ inputs << " , \" " << frag_attr_b .reg_type << " \" ((" << frag_attr_b .ptr_sig << " (B))[" << i
448473 << " ])" ;
449474 }
450475 for (int i = 0 ; i < num_operands_c; ++i) {
451- inputs << " , \" " << frag_attr_acc .reg_type << " \" ((" << frag_attr_acc .ptr_sig << " (C))[" << i
476+ inputs << " , \" " << frag_attr_c .reg_type << " \" ((" << frag_attr_c .ptr_sig << " (C))[" << i
452477 << " ])" ;
453478 }
454479 // input of metadata for sparse mma.
@@ -461,7 +486,7 @@ inline std::tuple<std::string, std::string, std::string> GetMMAOperands(int m, i
461486 if (i != 0 ) {
462487 outputs << " ," ;
463488 }
464- outputs << " \" =" << frag_attr_acc .reg_type << " \" ((" << frag_attr_acc .ptr_sig << " (D))[" << i
489+ outputs << " \" =" << frag_attr_c .reg_type << " \" ((" << frag_attr_c .ptr_sig << " (D))[" << i
465490 << " ])" ;
466491 }
467492 return std::make_tuple (templates.str (), inputs.str (), outputs.str ());
@@ -495,7 +520,7 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo
495520)" ;
496521 std::string templates_str, inputs_str, outputs_str;
497522 std::tie (templates_str, inputs_str, outputs_str) =
498- GetMMAOperands (m, n, k, dtype_a, dtype_c, sparse);
523+ GetMMAOperands (m, n, k, dtype_a, dtype_b, dtype_c, sparse);
499524
500525 // replace patterns
501526 Replacer replacer;
0 commit comments