Commit 9ad568f

Added block info gathering support to IndexTreeToSCF #20
1 parent 9a1a276 commit 9ad568f

5 files changed, +203 -28 lines changed

first.mlir

+107 -22
@@ -1,30 +1,115 @@
 module {
 func.func @main() {
+%cst = arith.constant 1.000000e+00 : f64
+%cst_0 = arith.constant 0.000000e+00 : f64
+%c10 = arith.constant 10 : index
+%c9 = arith.constant 9 : index
+%c8 = arith.constant 8 : index
+%c7 = arith.constant 7 : index
+%c6 = arith.constant 6 : index
+%c5 = arith.constant 5 : index
+%c1_i32 = arith.constant 1 : i32
+%c0_i32 = arith.constant 0 : i32
+%c3 = arith.constant 3 : index
+%c2 = arith.constant 2 : index
 %c0 = arith.constant 0 : index
 %c1 = arith.constant 1 : index
 %c4 = arith.constant 4 : index
-%0 = "ta.static_index_label"(%c0, %c4, %c1) : (index, index, index) -> !ta.range
-%c0_0 = arith.constant 0 : index
-%c1_1 = arith.constant 1 : index
-%1 = "ta.dynamic_index_label"(%c0_0, %c1_1) : (index, index) -> !ta.range
-%c0_2 = arith.constant 0 : index
-%c1_3 = arith.constant 1 : index
-%2 = "ta.dynamic_index_label"(%c0_2, %c1_3) : (index, index) -> !ta.range
-%3 = "ta.sparse_tensor_decl"(%1, %2) {format = "CSR", temporal_tensor = false} : (!ta.range, !ta.range) -> tensor<?x?xf64>
-%4 = "ta.dense_tensor_decl"(%0, %1) {format = "Dense"} : (!ta.range, !ta.range) -> tensor<4x?xf64>
-%5 = "ta.dense_tensor_decl"(%0, %2) {format = "Dense"} : (!ta.range, !ta.range) -> tensor<4x?xf64>
-"ta.fill"(%4) {value = 1.000000e+00 : f64} : (tensor<4x?xf64>) -> ()
-"ta.fill_from_file"(%3) {filename = "SPARSE_FILE_NAME0", readMode = 1 : i32} : (tensor<?x?xf64>) -> ()
-"ta.fill"(%5) {value = 0.000000e+00 : f64} : (tensor<4x?xf64>) -> ()
-%6 = "it.ComputeRHS"(%4, %3) {allBlocks = [["UNK", "UNK"], ["UNK", "UNK"]], allFormats = [["D", "D"], ["D", "CU"]], allPerms = [[0, 1], [1, 2]]} : (tensor<4x?xf64>, tensor<?x?xf64>) -> tensor<*xf64>
-%7 = "it.ComputeLHS"(%5) {allBlocks = [["UNK", "UNK"]], allFormats = [["D", "D"]], allPerms = [[0, 2]]} : (tensor<4x?xf64>) -> tensor<*xf64>
-%8 = "it.Compute"(%6, %7) {MaskType = "none", comp_worksp_opt = false, semiring = "plusxy_times"} : (tensor<*xf64>, tensor<*xf64>) -> i64
-%9 = "it.Indices"(%8) {indices = [2]} : (i64) -> i64
-%10 = "it.Indices"(%9) {indices = [1]} : (i64) -> i64
-%11 = "it.Indices"(%10) {indices = [0]} : (i64) -> i64
-%12 = "it.itree"(%11) : (i64) -> i64
-"ta.print"(%5) : (tensor<4x?xf64>) -> ()
-"ta.print"(%3) : (tensor<?x?xf64>) -> ()
+%alloc = memref.alloc() : memref<13xindex>
+%cast = memref.cast %alloc : memref<13xindex> to memref<*xindex>
+call @read_input_sizes_2D_f64(%c0_i32, %c0, %c0, %c2, %c0, %cast, %c1_i32) {filename = "SPARSE_FILE_NAME0"} : (i32, index, index, index, index, memref<*xindex>, i32) -> ()
+%0 = memref.load %alloc[%c0] : memref<13xindex>
+%1 = memref.load %alloc[%c1] : memref<13xindex>
+%2 = memref.load %alloc[%c2] : memref<13xindex>
+%3 = memref.load %alloc[%c3] : memref<13xindex>
+%4 = memref.load %alloc[%c4] : memref<13xindex>
+%5 = memref.load %alloc[%c5] : memref<13xindex>
+%6 = memref.load %alloc[%c6] : memref<13xindex>
+%7 = memref.load %alloc[%c7] : memref<13xindex>
+%8 = memref.load %alloc[%c8] : memref<13xindex>
+%9 = memref.load %alloc[%c9] : memref<13xindex>
+%10 = memref.load %alloc[%c10] : memref<13xindex>
+%alloc_1 = memref.alloc(%0) : memref<?xindex>
+scf.for %arg0 = %c0 to %0 step %c1 {
+memref.store %c0, %alloc_1[%arg0] : memref<?xindex>
+}
+%cast_2 = memref.cast %alloc_1 : memref<?xindex> to memref<*xindex>
+%alloc_3 = memref.alloc(%1) : memref<?xindex>
+scf.for %arg0 = %c0 to %1 step %c1 {
+memref.store %c0, %alloc_3[%arg0] : memref<?xindex>
+}
+%cast_4 = memref.cast %alloc_3 : memref<?xindex> to memref<*xindex>
+%alloc_5 = memref.alloc(%2) : memref<?xindex>
+scf.for %arg0 = %c0 to %2 step %c1 {
+memref.store %c0, %alloc_5[%arg0] : memref<?xindex>
+}
+%cast_6 = memref.cast %alloc_5 : memref<?xindex> to memref<*xindex>
+%alloc_7 = memref.alloc(%3) : memref<?xindex>
+scf.for %arg0 = %c0 to %3 step %c1 {
+memref.store %c0, %alloc_7[%arg0] : memref<?xindex>
+}
+%cast_8 = memref.cast %alloc_7 : memref<?xindex> to memref<*xindex>
+%alloc_9 = memref.alloc(%4) : memref<?xindex>
+scf.for %arg0 = %c0 to %4 step %c1 {
+memref.store %c0, %alloc_9[%arg0] : memref<?xindex>
+}
+%cast_10 = memref.cast %alloc_9 : memref<?xindex> to memref<*xindex>
+%alloc_11 = memref.alloc(%5) : memref<?xindex>
+scf.for %arg0 = %c0 to %5 step %c1 {
+memref.store %c0, %alloc_11[%arg0] : memref<?xindex>
+}
+%cast_12 = memref.cast %alloc_11 : memref<?xindex> to memref<*xindex>
+%alloc_13 = memref.alloc(%6) : memref<?xindex>
+scf.for %arg0 = %c0 to %6 step %c1 {
+memref.store %c0, %alloc_13[%arg0] : memref<?xindex>
+}
+%cast_14 = memref.cast %alloc_13 : memref<?xindex> to memref<*xindex>
+%alloc_15 = memref.alloc(%7) : memref<?xindex>
+scf.for %arg0 = %c0 to %7 step %c1 {
+memref.store %c0, %alloc_15[%arg0] : memref<?xindex>
+}
+%cast_16 = memref.cast %alloc_15 : memref<?xindex> to memref<*xindex>
+%alloc_17 = memref.alloc(%8) : memref<?xf64>
+scf.for %arg0 = %c0 to %8 step %c1 {
+memref.store %cst_0, %alloc_17[%arg0] : memref<?xf64>
+}
+%cast_18 = memref.cast %alloc_17 : memref<?xf64> to memref<*xf64>
+call @read_input_2D_f64(%c0_i32, %c0, %c0, %c2, %c0, %cast_2, %cast_4, %cast_6, %cast_8, %cast_10, %cast_12, %cast_14, %cast_16, %cast_18, %c1_i32) {filename = "SPARSE_FILE_NAME0"} : (i32, index, index, index, index, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xf64>, i32) -> ()
+%alloc_19 = memref.alloc(%9) {alignment = 32 : i64} : memref<4x?xf64>
+%alloc_20 = memref.alloc(%10) {alignment = 32 : i64} : memref<4x?xf64>
+linalg.fill ins(%cst : f64) outs(%alloc_19 : memref<4x?xf64>)
+linalg.fill ins(%cst_0 : f64) outs(%alloc_20 : memref<4x?xf64>)
+scf.for %arg0 = %c0 to %c4 step %c1 {
+scf.for %arg1 = %c0 to %9 step %c1 {
+%11 = memref.load %alloc_9[%c0] : memref<?xindex>
+%12 = memref.load %alloc_9[%c1] : memref<?xindex>
+scf.for %arg2 = %11 to %12 step %c1 {
+%13 = memref.load %alloc_11[%arg2] : memref<?xindex>
+%14 = memref.load %alloc_19[%arg0, %arg1] : memref<4x?xf64>
+%15 = memref.load %alloc_17[%arg2] : memref<?xf64>
+%16 = memref.load %alloc_20[%arg0, %13] : memref<4x?xf64>
+%17 = arith.mulf %14, %15 : f64
+%18 = arith.addf %16, %17 : f64
+memref.store %18, %alloc_20[%arg0, %13] : memref<4x?xf64>
+}
+}
+}
+%cast_21 = memref.cast %alloc_20 : memref<4x?xf64> to memref<*xf64>
+call @comet_print_memref_f64(%cast_21) : (memref<*xf64>) -> ()
+call @comet_print_memref_i64(%cast_2) : (memref<*xindex>) -> ()
+call @comet_print_memref_i64(%cast_4) : (memref<*xindex>) -> ()
+call @comet_print_memref_i64(%cast_6) : (memref<*xindex>) -> ()
+call @comet_print_memref_i64(%cast_8) : (memref<*xindex>) -> ()
+call @comet_print_memref_i64(%cast_10) : (memref<*xindex>) -> ()
+call @comet_print_memref_i64(%cast_12) : (memref<*xindex>) -> ()
+call @comet_print_memref_i64(%cast_14) : (memref<*xindex>) -> ()
+call @comet_print_memref_i64(%cast_16) : (memref<*xindex>) -> ()
+call @comet_print_memref_f64(%cast_18) : (memref<*xf64>) -> ()
 return
 }
+func.func private @read_input_2D_f64(i32, index, index, index, index, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xf64>, i32)
+func.func private @read_input_sizes_2D_f64(i32, index, index, index, index, memref<*xindex>, i32)
+func.func private @comet_sort_index(memref<*xindex>, index, index)
+func.func private @comet_print_memref_f64(memref<*xf64>)
+func.func private @comet_print_memref_i64(memref<*xindex>)
 }

first.ta

+1 -1
@@ -5,7 +5,7 @@ def main() {
 IndexLabel [c] = [?];

 #Tensor Declarations
-Tensor<double> B([b, c], {CSR}); #sparse tensor declarations should be before dense tensor declarations
+Tensor<double> B([b, c], {BCSR}); #sparse tensor declarations should be before dense tensor declarations
 Tensor<double> A([a, b], {Dense});
 Tensor<double> C([a, c], {Dense});

include/comet/Dialect/Utils/Utils.h

+7 -1
@@ -211,6 +211,11 @@ namespace mlir
 void getFormatsOfComputeOp(Value computeOp, std::vector<std::vector<std::string>> &opFormats);
 void getRHSFormatsOfComputeOp(Value computeOp, std::vector<std::vector<std::string>> &opFormats);
 void getLHSFormatsOfComputeOp(Value computeOp, std::vector<std::vector<std::string>> &opFormats);
+
+/// TODO(patrick): Do we want to merge these with the getFormat functions above?
+void getBlocksOfComputeOp(Value computeOp, std::vector<std::vector<std::string>> &opFormats);
+void getRHSBlocksOfComputeOp(Value computeOp, std::vector<std::vector<std::string>> &opFormats);
+void getLHSBlocksOfComputeOp(Value computeOp, std::vector<std::vector<std::string>> &opFormats);

 void getFormatsPermsOfComputeOp(Value computeOp,
 std::vector<std::vector<std::string>> &opFormats,
@@ -223,7 +228,8 @@ namespace mlir
 std::vector<Value> &leafs,
 std::vector<Value> &tensors,
 std::vector<unsigned int> &ids,
-std::vector<std::string> &formats);
+std::vector<std::string> &formats,
+std::vector<std::string> &blocks);

 void replaceOperands(Operation *itComputeOp, std::vector<Value> newComputeOps);
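
Note: the matching definitions in Utils.cpp are not part of the hunks shown in this commit view. As a rough sketch only (the body below is an assumption for illustration, not code from this commit), getBlocksOfComputeOp could mirror the existing getFormatsOfComputeOp helpers by reading the allBlocks attribute that first.mlir carries on the it.ComputeRHS / it.ComputeLHS ops:

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include <string>
#include <vector>

// Hypothetical sketch, not the committed implementation: walk the nested
// "allBlocks" ArrayAttr (one inner array per tensor operand, one string per
// dimension, e.g. "UNK") and copy it into a vector of vectors of strings.
void getBlocksOfComputeOp(mlir::Value computeOp,
                          std::vector<std::vector<std::string>> &opBlocks)
{
  // Defining op of the compute value, assumed to carry the allBlocks attribute.
  mlir::Operation *op = computeOp.getDefiningOp();
  auto allBlocksAttr = op->getAttrOfType<mlir::ArrayAttr>("allBlocks");
  if (!allBlocksAttr)
    return; // no block info attached to this compute op
  for (mlir::Attribute perTensor : allBlocksAttr)
  {
    std::vector<std::string> blocks;
    for (mlir::Attribute dim : perTensor.cast<mlir::ArrayAttr>())
      blocks.push_back(dim.cast<mlir::StringAttr>().getValue().str());
    opBlocks.push_back(blocks);
  }
}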

lib/Conversion/IndexTreeToSCF/IndexTreeToSCF.cpp

+36 -3
@@ -1044,6 +1044,7 @@ namespace
 void genForOps(std::vector<Value> &tensors,
 std::vector<unsigned int> &ids,
 std::vector<std::string> &formats,
+std::vector<std::string> &blocks,
 indexTree::IndexTreeOp rootOp,
 OpBuilder &builder,
 OpsTree *opstree,
@@ -2383,6 +2384,7 @@ namespace
 std::vector<std::vector<Value>> &tensors_lhs_Allocs /* output */,
 std::vector<std::vector<Value>> &tensors_rhs_Allocs /* output */,
 std::vector<std::vector<std::string>> &allFormats /*output*/,
+std::vector<std::vector<std::string>> &allBlocks /*output*/,
 std::vector<std::vector<int>> &allPerms /* output */,
 std::vector<std::vector<int>> &allPerms_rhs /* output */,
 std::vector<Value> &main_tensors_all /* output */,
@@ -2452,6 +2454,18 @@ namespace
 }
 comet_debug() << "\n";
 }
+
+getBlocksOfComputeOp(cur_op.getOperation()->getResult(0), allBlocks);
+comet_debug() << " allBlocks: \n";
+for (auto m : allBlocks)
+{
+comet_debug() << " ";
+for (auto n : m)
+{
+comet_debug() << n << " ";
+}
+comet_debug() << "\n";
+}

 comet_debug() << " ";
 comet_vdump(cur_op);
@@ -2510,6 +2524,7 @@ namespace
 int main_tensor_nums,
 std::vector<std::vector<int>> &allPerms,
 std::vector<std::vector<std::string>> &allFormats,
+std::vector<std::vector<std::string>> &allBlocks,
 std::vector<Value> &main_tensors_all,
 std::vector<scf::ForOp> &nested_forops,
 std::vector<Value> &nested_AccessIdx,
@@ -2529,6 +2544,7 @@ namespace
 comet_debug() << " index_loc " << index_loc << "\n";
 comet_debug() << " Perm: " << allPerms[i][j] << "\n";
 comet_debug() << " Format: " << allFormats[i][j] << "\n";
+comet_debug() << " Block: " << allBlocks[i][j] << "\n";
 assert(index_loc < nested_forops.size() &&
 "index_loc < nested_forops.size(), i.e. the index not exist in nested for loop\n");
 allLoopsArg[i].push_back(nested_forops[index_loc].getInductionVar());
@@ -3705,6 +3721,7 @@ namespace
 std::vector<std::vector<Value>> tensors_lhs_Allocs;
 std::vector<std::vector<Value>> tensors_rhs_Allocs;
 std::vector<std::vector<std::string>> allFormats;
+std::vector<std::vector<std::string>> allBlocks;
 std::vector<std::vector<int>> allPerms;
 std::vector<std::vector<int>> allPerms_rhs;
 std::vector<Value> main_tensors_all; /// main_tensors_all has first RHS tensors then LHS tensors
@@ -3714,6 +3731,7 @@ namespace
 tensors_lhs_Allocs /* output */,
 tensors_rhs_Allocs /* output */,
 allFormats /* output */,
+allBlocks /* output */,
 allPerms /* output */,
 allPerms_rhs /* output */,
 main_tensors_all /* output */,
@@ -3743,6 +3761,7 @@ namespace
 main_tensor_nums,
 allPerms,
 allFormats,
+allBlocks,
 main_tensors_all,
 nested_forops,
 nested_AccessIdx,
@@ -3776,6 +3795,7 @@ namespace
 main_tensor_nums,
 allPerms,
 allFormats,
+allBlocks,
 main_tensors_all,
 symbolic_nested_forops,
 symbolic_nested_AccessIdx,
@@ -4390,6 +4410,7 @@ void LowerIndexTreeToSCFPass::doLoweringIndexTreeToSCF(indexTree::IndexTreeOp &r
 std::vector<Value> tensors;
 std::vector<unsigned int> ids;
 std::vector<std::string> formats;
+std::vector<std::string> blocks;

 comet_vdump(cur_op);

@@ -4398,17 +4419,29 @@ void LowerIndexTreeToSCFPass::doLoweringIndexTreeToSCF(indexTree::IndexTreeOp &r
 leafs,
 tensors /* output */,
 ids /* output */,
-formats /* output */);
+formats /* output */,
+blocks /* output */);

 comet_debug() << " indices.size(): " << indices.size() << " tensors.size(): " << tensors.size() << "\n";
 for (unsigned int m = 0; m < tensors.size(); m++)
 {
-comet_debug() << " Formats:" << formats[m] << " " << ids[m] << " ";
+comet_debug() << " Formats:" << formats[m] << " " << ids[m] << " \n";
+comet_debug() << " Blocks:" << blocks[m] << " " << ids[m] << " \n";
 comet_vdump(tensors[m]);
+comet_debug() << "\n";
 }
+comet_debug() << "---------------\n";
+
+//debug
+//for (auto fmt : formats) {
+// std::cout << "FMT: " << fmt << std::endl;
+//}
+//for (auto block : blocks) {
+// std::cout << "BLOCK: " << block << std::endl;
+//}

 comet_debug() << " call genForOps, i = " << i << "\n";
-genForOps(tensors, ids, formats, rootOp, builder, opstree_vec[i], symbolicInfo);
+genForOps(tensors, ids, formats, blocks, rootOp, builder, opstree_vec[i], symbolicInfo);
 {
 comet_pdump(rootOp->getParentOfType<ModuleOp>());
 }
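
For the running first.mlir example every gathered block entry is still the placeholder "UNK", and allBlocks stays parallel to allFormats: the new "Block:" debug line indexes both with the same [i][j]. A small usage illustration (the computeRHS value name is assumed here for the example, not taken from the commit):

std::vector<std::vector<std::string>> allFormats; // gathered per-dimension format strings
std::vector<std::vector<std::string>> allBlocks;  // gathered per-dimension block strings
getRHSFormatsOfComputeOp(computeRHS, allFormats); // e.g. {{"D", "D"}, {"D", "CU"}}
getRHSBlocksOfComputeOp(computeRHS, allBlocks);   // e.g. {{"UNK", "UNK"}, {"UNK", "UNK"}}
// Same shape by construction, so allBlocks[i][j] names the block size of the
// dimension whose storage format is allFormats[i][j].
assert(allFormats.size() == allBlocks.size());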
