|
34 | 34 |
|
35 | 35 | #include "literal/cuda_half_t.h" |
36 | 36 | #include "ptx.h" |
| 37 | +#include "tvm/arith/iter_affine_map.h" |
37 | 38 |
|
38 | 39 | namespace tvm { |
39 | 40 | namespace codegen { |
@@ -838,21 +839,38 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { |
838 | 839 | std::string dst = this->PrintExpr(op->args[2]); |
839 | 840 | std::string src = this->PrintExpr(op->args[3]); |
840 | 841 | std::string src_offset = this->PrintExpr(op->args[4]); |
841 | | - std::string stride = this->PrintExpr(op->args[5]); |
842 | 842 |
|
843 | 843 | if (m == 16 && n == 8) { |
| 844 | + std::string stride = this->PrintExpr(op->args[5]); |
844 | 845 | os << "for (int i = 0; i < 4; ++i) {\n"; |
845 | 846 | os << dst << "[(i / 2 * 8 + threadIdx.x / 4) * " << stride |
846 | 847 | << " + (threadIdx.x % 4) * 2 + i % 2]" |
847 | 848 | << " = " << src << "[" << src_offset << " + i];\n"; |
848 | 849 | os << "}\n"; |
849 | 850 | } else if (m == 16 && n == 16) { |
850 | | - os << "for (int outer = 0; outer < 2; ++outer) {\n"; |
851 | | - os << "for (int i = 0; i < 4; ++i) {\n"; |
852 | | - os << dst << "[(i / 2 * 8 + threadIdx.x / 4) * " << stride |
853 | | - << " + outer * 8 + (threadIdx.x % 4) * 2 + i % 2]" |
854 | | - << " = " << src << "[" << src_offset << " + outer * 4 + i];\n"; |
855 | | - os << "}\n"; |
| 851 | + const auto* index_map = |
| 852 | + runtime::Registry::Get("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout"); |
| 853 | + ICHECK(index_map); |
| 854 | + |
| 855 | + Var var_i("i"); |
| 856 | + Var var_j("j"); |
| 857 | + Array<PrimExpr> forward_map = (*index_map)(var_i, var_j); |
| 858 | + |
| 859 | + arith::Analyzer ana; |
| 860 | + auto iter_map = arith::DetectIterMap( |
| 861 | + forward_map, {{var_i, Range(0, 16)}, {var_j, Range(0, 16)}}, true, true, &ana, true); |
| 862 | + |
| 863 | + Var thread_id("threadIdx.x"); |
| 864 | + Var local_id("local_id"); |
| 865 | + auto inverse_map = arith::InverseAffineIterMap(iter_map, {thread_id, local_id}); |
| 866 | + PrimExpr stride = op->args[5]; |
| 867 | + auto dst_idx = inverse_map[var_i] * stride + inverse_map[var_j]; |
| 868 | + |
| 869 | + var_idmap_[thread_id.get()] = "threadIdx.x"; |
| 870 | + var_idmap_[local_id.get()] = "local_id"; |
| 871 | + os << "for (int local_id = 0; local_id < 8; ++local_id) {\n"; |
| 872 | + os << dst << "[" + this->PrintExpr(dst_idx) + "]" |
| 873 | + << " = " << src << "[" << src_offset << " + local_id];\n"; |
856 | 874 | os << "}\n"; |
857 | 875 | } |
858 | 876 | } else if (op->op.same_as(builtin::mma_fill())) { |
|
0 commit comments