Skip to content

Commit 89536ff

Browse files
committed
improvement
1 parent ca92032 commit 89536ff

File tree

1 file changed

+25
-12
lines changed

1 file changed

+25
-12
lines changed

src/target/source/ptx_mma.cc

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ namespace ptx {
3737

3838
/*!
3939
* \brief PTX data type.
40+
* \note
41+
* PTX fundamental data types:
42+
* https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#fundamental-types
43+
* PTX matrix data types:
44+
* https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-data-types
4045
*/
4146
enum class DataType : int {
4247
kInt4 = 0,
@@ -173,6 +178,11 @@ struct MMAConfig {
173178
}
174179
};
175180

181+
/*!
182+
* \brief Valid MMA configurations
183+
* \note Reference:
184+
* https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-shape
185+
*/
176186
const MMAConfig valid_mma_configs[] = {
177187
MMAConfig(8, 8, 4, DataType::kFloat64, false, false),
178188
MMAConfig(8, 8, 4, DataType::kFloat16, false, false),
@@ -219,6 +229,8 @@ const MMAConfig valid_mma_configs[] = {
219229
* \param dtype_a The data type of multiplicand a.
220230
* \param dtype_b The data type of multiplicand b.
221231
* \param dtype_c The data type of accumulator c.
232+
* \note Reference:
233+
* https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-data-types
222234
*/
223235
void CheckMMADTypeCompatible(DataType dtype_a, DataType dtype_b, DataType dtype_c) {
224236
std::string ab_not_match_err_str = "The multiplicands' data type " + DTypeToString(dtype_a) +
@@ -296,7 +308,7 @@ void CheckMMAConfigValidity(int m, int n, int k, LayoutType layout_a, LayoutType
296308
<< "Unrecognized 1-bit operation " << bit_op << " , can only be xor/and.";
297309
bool use_bit_op = !bit_op.empty();
298310
if (use_bit_op) {
299-
CHECK(dtype_a == DataType::kBit1) << "Bit operator is only compatible with 1bit multiplicand.";
311+
CHECK(dtype_a == DataType::kBit1) << "Bit operator is only compatible with 1-bit multiplicand.";
300312
}
301313
CheckMMADTypeCompatible(dtype_a, dtype_b, dtype_c);
302314
if (saturate) {
@@ -328,14 +340,14 @@ void CheckMMAConfigValidity(int m, int n, int k, LayoutType layout_a, LayoutType
328340
*/
329341
class FragAttrs {
330342
public:
331-
explicit FragAttrs(char reg_type, uint32_t size, std::string ptr_sig)
332-
: reg_type(reg_type), size(size), ptr_sig(ptr_sig) {}
343+
explicit FragAttrs(char reg_type, uint32_t size, std::string ptr_type)
344+
: reg_type(reg_type), size(size), ptr_type(ptr_type) {}
333345
/*! \brief PTX register type */
334346
char reg_type;
335347
/*! \brief Fragment size */
336348
uint32_t size;
337-
/*! \brief Fragment pointer signature */
338-
std::string ptr_sig;
349+
/*! \brief Fragment pointer type */
350+
std::string ptr_type;
339351
};
340352

341353
/*!
@@ -466,14 +478,15 @@ inline std::tuple<std::string, std::string, std::string> GetMMAOperands(int m, i
466478
if (i != 0) {
467479
inputs << ", ";
468480
}
469-
inputs << "\"" << frag_attr_a.reg_type << "\"((" << frag_attr_a.ptr_sig << "(A))[" << i << "])";
481+
inputs << "\"" << frag_attr_a.reg_type << "\"((" << frag_attr_a.ptr_type << "(A))[" << i
482+
<< "])";
470483
}
471484
for (int i = 0; i < num_operands_b; ++i) {
472-
inputs << ", \"" << frag_attr_b.reg_type << "\"((" << frag_attr_b.ptr_sig << "(B))[" << i
485+
inputs << ", \"" << frag_attr_b.reg_type << "\"((" << frag_attr_b.ptr_type << "(B))[" << i
473486
<< "])";
474487
}
475488
for (int i = 0; i < num_operands_c; ++i) {
476-
inputs << ", \"" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_sig << "(C))[" << i
489+
inputs << ", \"" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_type << "(C))[" << i
477490
<< "])";
478491
}
479492
// Input of metadata for sparse MMA.
@@ -486,7 +499,7 @@ inline std::tuple<std::string, std::string, std::string> GetMMAOperands(int m, i
486499
if (i != 0) {
487500
outputs << ",";
488501
}
489-
outputs << " \"=" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_sig << "(D))[" << i
502+
outputs << " \"=" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_type << "(D))[" << i
490503
<< "])";
491504
}
492505
return std::make_tuple(templates.str(), inputs.str(), outputs.str());
@@ -512,7 +525,7 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo
512525
std::string asm_code = R"(
513526
{
514527
__asm__ __volatile__(
515-
"mma{sparse}.sync.aligned.{shape}.{alayout}.{blayout}{satinite}{dtype}{atype}{btype}{ctype}{1bit}"
528+
"mma{sparse}.sync.aligned.{shape}.{alayout}.{blayout}{saturate}{dtype}{atype}{btype}{ctype}{bitop}"
516529
"{templates};\n"
517530
: {outputs}
518531
: {inputs});
@@ -526,14 +539,14 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo
526539
Replacer replacer;
527540
replacer.register_rule("{sparse}", sparse ? ".sp" : "");
528541
replacer.register_rule("{shape}", shape);
529-
replacer.register_rule("{satinite}", saturate ? ".satfinite" : "");
542+
replacer.register_rule("{saturate}", saturate ? ".satfinite" : "");
530543
replacer.register_rule("{alayout}", A_layout);
531544
replacer.register_rule("{blayout}", B_layout);
532545
replacer.register_rule("{atype}", ptx::DTypeToString(dtype_a));
533546
replacer.register_rule("{btype}", ptx::DTypeToString(dtype_b));
534547
replacer.register_rule("{ctype}", ptx::DTypeToString(dtype_c));
535548
replacer.register_rule("{dtype}", ptx::DTypeToString(dtype_c));
536-
replacer.register_rule("{1bit}", bit_op.empty() ? "" : "." + bit_op + ".popc");
549+
replacer.register_rule("{bitop}", bit_op.empty() ? "" : "." + bit_op + ".popc");
537550
replacer.register_rule("{templates}", templates_str);
538551
replacer.register_rule("{outputs}", outputs_str);
539552
replacer.register_rule("{inputs}", inputs_str);

0 commit comments

Comments
 (0)