Skip to content

Commit 45ac02b

Browse files
pflynn157pthomadakis
authored andcommitted
Began adding allBlocks attribute to index tree #20
1 parent 672884e commit 45ac02b

File tree

13 files changed

+843
-356
lines changed

13 files changed

+843
-356
lines changed

first.mlir

+542-312
Large diffs are not rendered by default.

first.ta

+12-4
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
11
def main() {
22
#IndexLabel Declarations
33
IndexLabel [a] = [?];
4-
IndexLabel [b] = [?];
5-
6-
#Tensor Declarations
7-
Tensor<double> A([a, b], {BCSR});
4+
IndexLabel [b] = [?];
5+
6+
Tensor<double> A([a, b], {CSR});
7+
Tensor<double> B([b, a], {Dense});
8+
Tensor<double> C([b, a], {Dense});
9+
810
A[a, b] = comet_read(0);
11+
B[b, a] = 1.0;
12+
C[b, a] = A[a, b] * B[b, a];
13+
#C[b, a] = 1.0;
14+
915
print(A);
16+
print(B);
17+
print(C);
1018
}
1119

include/comet/Dialect/IndexTree/IR/IndexTree.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ class Index_Tree
214214

215215
public:
216216
IndicesType getIndices(std::vector<mlir::Value> &lbls);
217-
Tensor *getOrCreateTensor(mlir::Value v, std::vector<mlir::Value> &allIndexLabels, FormatsType &formats);
217+
Tensor *getOrCreateTensor(mlir::Value v, std::vector<mlir::Value> &allIndexLabels, FormatsType &formats, BlocksType &blocks);
218218

219219
vector<TreeNode *> getNodes();
220220

include/comet/Dialect/IndexTree/IR/IndexTreeOps.td

+3-3
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,15 @@ def IndexTreeComputeLHSOp : IndexTree_Op<"ComputeLHS", [Pure]>{
6464
let summary = "";
6565
let description = [{}];
6666

67-
let arguments = (ins Variadic<AnyType>:$tensors, ArrayAttr:$allPerms, ArrayAttr:$allFormats);
67+
let arguments = (ins Variadic<AnyType>:$tensors, ArrayAttr:$allPerms, ArrayAttr:$allFormats, ArrayAttr:$allBlocks);
6868
let results = (outs AnyType:$output);
6969
}
7070

7171
def IndexTreeComputeRHSOp : IndexTree_Op<"ComputeRHS", [Pure]>{
7272
let summary = "";
7373
let description = [{}];
7474

75-
let arguments = (ins Variadic<AnyType>:$tensors, ArrayAttr:$allPerms, ArrayAttr:$allFormats);
75+
let arguments = (ins Variadic<AnyType>:$tensors, ArrayAttr:$allPerms, ArrayAttr:$allFormats, ArrayAttr:$allBlocks);
7676
let results = (outs AnyType:$output);
7777
}
7878

@@ -118,4 +118,4 @@ def IndexTreeOp : IndexTree_Op<"itree", [Pure]>{
118118
//let hasVerifier = 1;
119119
}
120120

121-
#endif // INDEXTREE_OPS
121+
#endif // INDEXTREE_OPS

include/comet/Dialect/IndexTree/Transforms/Tensor.h

+19-1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
typedef std::vector<unsigned int> IndicesType;
3737
typedef std::vector<std::string> FormatsType;
38+
typedef std::vector<std::string> BlocksType;
3839
typedef std::vector<IterDomain *> DomainsType;
3940

4041
using std::make_shared;
@@ -53,6 +54,7 @@ class Tensor
5354
string name;
5455
IndicesType indices;
5556
FormatsType format;
57+
BlocksType block;
5658
IndicesType hidden;
5759
DomainsType domains;
5860
UnitExpression *definingExpr = nullptr;
@@ -61,12 +63,13 @@ class Tensor
6163
public:
6264
static int count;
6365

64-
Tensor(mlir::Value &value, IndicesType &indices, vector<string> &format)
66+
Tensor(mlir::Value &value, IndicesType &indices, vector<string> &format, vector<string> &block)
6567
{
6668
// assert(format.size() == indices.size());
6769
this->value = value;
6870
this->indices = indices;
6971
this->format = format;
72+
this->block = block;
7073
id = count++;
7174
domains = std::vector<IterDomain *>(indices.size());
7275
}
@@ -154,6 +157,21 @@ class Tensor
154157
{
155158
Tensor::format = format;
156159
}
160+
161+
const string &getBlock(int i) const
162+
{
163+
return block.at(i);
164+
}
165+
166+
BlocksType &getBlocks()
167+
{
168+
return block;
169+
}
170+
171+
void setBlock(const vector<string> &block)
172+
{
173+
Tensor::block = block;
174+
}
157175

158176
void setHiddenIndices(IndicesType &hiddenIndices)
159177
{

include/comet/Dialect/Utils/Utils.h

+1
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ namespace mlir
126126
std::vector<unsigned> getSumIndices(std::vector<unsigned> rhs_perm, std::vector<unsigned> rhs_perm_free);
127127
std::vector<unsigned> getIndexIterateOrder(std::vector<unsigned> rhs1_perm, std::vector<unsigned> rhs2_perm);
128128
std::vector<std::vector<std::string>> getAllFormats(ArrayAttr opFormatsArrayAttr, std::vector<std::vector<int64_t>> allPerms);
129+
std::vector<std::vector<std::string>> getAllBlocks(ArrayAttr opFormatsArrayAttr, std::vector<std::vector<int64_t>> allPerms);
129130
bool checkIsElementwise(std::vector<std::vector<int>> allPerms);
130131
bool checkIsMixedMode(std::vector<std::vector<std::string>> formats);
131132
bool checkIsDense(std::vector<std::string> format);

lib/Conversion/TensorAlgebraToIndexTree/TensorAlgebraToIndexTree.cpp

+37-9
Original file line numberDiff line numberDiff line change
@@ -264,24 +264,28 @@ void doTensorMultOp(TensorMultOp op, unique_ptr<Index_Tree> &tree, TargetDevice
264264
// comet_debug() << allPerms;
265265
#endif
266266

267+
268+
auto allBlocks = getAllBlocks(op.getFormatsAttr(), allPerms);
267269
auto allFormats = getAllFormats(op.getFormatsAttr(), allPerms);
268270
auto SemiringOp = op.getSemiringAttr();
269271
auto MaskingTypeAttr = op.getMaskTypeAttr();
270272

271273
/// If the operation is one of the chosen operations, then record output indices as parallel interators.
272274
bool is_chosen_operations = check_chosen_operations(allPerms, allFormats);
273275

274-
auto B = tree->getOrCreateTensor(rhs1_tensor, rhs1_labels, allFormats[0]);
275-
auto C = tree->getOrCreateTensor(rhs2_tensor, rhs2_labels, allFormats[1]);
276-
auto A = tree->getOrCreateTensor(lhs_tensor, lhs_labels, allFormats[2]);
276+
277+
278+
auto B = tree->getOrCreateTensor(rhs1_tensor, rhs1_labels, allFormats[0], allBlocks[0]);
279+
auto C = tree->getOrCreateTensor(rhs2_tensor, rhs2_labels, allFormats[1], allBlocks[1]);
280+
auto A = tree->getOrCreateTensor(lhs_tensor, lhs_labels, allFormats[2], allBlocks[2]);
277281

278282
Tensor *M;
279283
std::unique_ptr<UnitExpression> e;
280284
std::vector<mlir::Value> empty;
281285
if (mask_tensor != nullptr) /// mask is an optional input
282286
{
283287
comet_debug() << "mask input provided by user\n";
284-
M = tree->getOrCreateTensor(mask_tensor, empty, allFormats[2]); /// We don't need indexlabel info for the mask
288+
M = tree->getOrCreateTensor(mask_tensor, empty, allFormats[2], allBlocks[2]); /// We don't need indexlabel info for the mask
285289
e = make_unique<UnitExpression>(A, B, C, M, "*");
286290
}
287291
else
@@ -377,14 +381,15 @@ void doElementWiseOp(T op, unique_ptr<Index_Tree> &tree)
377381

378382
auto allPerms = getAllPerms(op.getIndexingMaps());
379383
auto allFormats = getAllFormats(op.getFormatsAttr(), allPerms);
384+
auto allBlocks = getAllBlocks(op.getFormatsAttr(), allPerms);
380385
auto SemiringOp = op.getSemiringAttr();
381386
auto maskAttr = "none";
382387

383388
assert(allPerms.size() == 3);
384389

385-
auto B = tree->getOrCreateTensor(rhs1_tensor, rhs1_labels, allFormats[0]);
386-
auto C = tree->getOrCreateTensor(rhs2_tensor, rhs2_labels, allFormats[1]);
387-
auto A = tree->getOrCreateTensor(lhs_tensor, lhs_labels, allFormats[2]);
390+
auto B = tree->getOrCreateTensor(rhs1_tensor, rhs1_labels, allFormats[0], allBlocks[0]);
391+
auto C = tree->getOrCreateTensor(rhs2_tensor, rhs2_labels, allFormats[1], allBlocks[1]);
392+
auto A = tree->getOrCreateTensor(lhs_tensor, lhs_labels, allFormats[2], allBlocks[2]);
388393

389394
auto e = make_unique<UnitExpression>(A, B, C, "*");
390395

@@ -498,6 +503,27 @@ IndexTreeComputeOp createComputeNodeOp(OpBuilder &builder, TreeNode *node, Locat
498503
}
499504
allFormats_lhs.push_back(builder.getStrArrayAttr(formats));
500505
}
506+
507+
SmallVector<Attribute, 8> allBlocks_rhs;
508+
for (auto t : expr->getOperands())
509+
{
510+
SmallVector<StringRef, 8> blocks;
511+
for (auto &b : t->getBlocks())
512+
{
513+
blocks.push_back(b);
514+
}
515+
allBlocks_rhs.push_back(builder.getStrArrayAttr(blocks));
516+
}
517+
SmallVector<Attribute, 8> allBlocks_lhs;
518+
for (auto t : expr->getResults())
519+
{
520+
SmallVector<StringRef, 8> blocks;
521+
for (auto &b : t->getBlocks())
522+
{
523+
blocks.push_back(b);
524+
}
525+
allBlocks_lhs.push_back(builder.getStrArrayAttr(blocks));
526+
}
501527

502528
std::vector<Value> t_rhs;
503529
Value t_lhs = expr->getLHS()->getValue();
@@ -516,12 +542,14 @@ IndexTreeComputeOp createComputeNodeOp(OpBuilder &builder, TreeNode *node, Locat
516542
Value leafop_rhs = builder.create<indexTree::IndexTreeComputeRHSOp>(loc,
517543
mlir::UnrankedTensorType::get(builder.getF64Type()), t_rhs,
518544
builder.getArrayAttr(allIndices_rhs),
519-
builder.getArrayAttr(allFormats_rhs));
545+
builder.getArrayAttr(allFormats_rhs),
546+
builder.getArrayAttr(allBlocks_rhs));
520547
comet_vdump(leafop_rhs);
521548
Value leafop_lhs = builder.create<indexTree::IndexTreeComputeLHSOp>(loc,
522549
mlir::UnrankedTensorType::get(builder.getF64Type()), t_lhs,
523550
builder.getArrayAttr(allIndices_lhs),
524-
builder.getArrayAttr(allFormats_lhs));
551+
builder.getArrayAttr(allFormats_lhs),
552+
builder.getArrayAttr(allBlocks_lhs));
525553
comet_vdump(leafop_lhs);
526554

527555
bool comp_worksp_opt = false; /// non-compressed workspace, this is a place-holder and it is updated in workspace transform pass.

lib/Conversion/TensorAlgebraToSCF/LowerPCToLoops.cpp

+5-3
Original file line numberDiff line numberDiff line change
@@ -140,23 +140,25 @@ void PCToLoopsLoweringPass::replicateOpsForLoopBody(Location loc, OpBuilder &bui
140140
{
141141
indexTree::IndexTreeComputeRHSOp it_compute_rhs_op = llvm::dyn_cast<indexTree::IndexTreeComputeRHSOp>(op);
142142
ArrayAttr op_formats_ArrayAttr = it_compute_rhs_op.getAllFormats();
143+
ArrayAttr op_blocks_ArrayAttr = it_compute_rhs_op.getAllBlocks();
143144
ArrayAttr op_perms_ArrayAttr = it_compute_rhs_op.getAllPerms();
144145

145146
rhs = builder.create<indexTree::IndexTreeComputeRHSOp>(loc, mlir::UnrankedTensorType::get(builder.getF64Type()),
146147
it_compute_rhs_op->getOperands(), // tensors
147-
op_perms_ArrayAttr, op_formats_ArrayAttr);
148+
op_perms_ArrayAttr, op_formats_ArrayAttr, op_blocks_ArrayAttr);
148149
}
149150

150151
// create IndexTreeComputeLHSOp, no dependency to earlier replications
151152
if (isa<indexTree::IndexTreeComputeLHSOp>(op))
152153
{
153154
indexTree::IndexTreeComputeLHSOp it_compute_lhs_op = llvm::dyn_cast<indexTree::IndexTreeComputeLHSOp>(op);
154155
ArrayAttr op_formats_ArrayAttr = it_compute_lhs_op.getAllFormats();
156+
ArrayAttr op_blocks_ArrayAttr = it_compute_lhs_op.getAllBlocks();
155157
ArrayAttr op_perms_ArrayAttr = it_compute_lhs_op.getAllPerms();
156158

157159
lhs = builder.create<indexTree::IndexTreeComputeLHSOp>(loc, mlir::UnrankedTensorType::get(builder.getF64Type()),
158160
it_compute_lhs_op->getOperands(), /// tensors
159-
op_perms_ArrayAttr, op_formats_ArrayAttr);
161+
op_perms_ArrayAttr, op_formats_ArrayAttr, op_blocks_ArrayAttr);
160162
}
161163

162164
/// create IndexTreeComputeOp only if rhs and lhs are ready
@@ -338,4 +340,4 @@ void PCToLoopsLoweringPass::runOnOperation()
338340
std::unique_ptr<Pass> mlir::comet::createPCToLoopsLoweringPass()
339341
{
340342
return std::make_unique<PCToLoopsLoweringPass>();
341-
}
343+
}

lib/Dialect/IndexTree/IR/IndexTree.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -117,12 +117,12 @@ IndicesType Index_Tree::getIndices(std::vector<mlir::Value> &lbls)
117117
return indices;
118118
}
119119

120-
Tensor *Index_Tree::getOrCreateTensor(mlir::Value v, std::vector<mlir::Value> &allIndexLabels, FormatsType &formats)
120+
Tensor *Index_Tree::getOrCreateTensor(mlir::Value v, std::vector<mlir::Value> &allIndexLabels, FormatsType &formats, BlocksType &blocks)
121121
{
122122
IndicesType indices = getIndices(allIndexLabels);
123123
comet_debug() << "Num Indices: " << indices.size() << ", Num formats " << formats.size() << "\n";
124124

125-
return new Tensor(v, indices, formats);
125+
return new Tensor(v, indices, formats, blocks);
126126
}
127127

128128
TreeNode *Index_Tree::addComputeNode(unique_ptr<UnitExpression> expr, TreeNode *parent)
@@ -261,4 +261,4 @@ void IteratorType::setType(std::string t) {
261261
} else {
262262
llvm::errs() << "Unsupported iterator type " + t + "\n";
263263
}
264-
}
264+
}

lib/Dialect/IndexTree/Transforms/Fusion.cpp

+45-4
Original file line numberDiff line numberDiff line change
@@ -440,8 +440,10 @@ mlir::Value IndexTreeKernelFusionPass::createReducedComputeLHS(
440440
mlir::indexTree::IndexTreeComputeLHSOp it_compute_lhs_op = llvm::dyn_cast<mlir::indexTree::IndexTreeComputeLHSOp>(
441441
lhs_op);
442442
ArrayAttr op_formats_ArrayAttr = it_compute_lhs_op.getAllFormats();
443+
ArrayAttr op_blocks_ArrayAttr = it_compute_lhs_op.getAllBlocks();
443444
ArrayAttr op_perms_ArrayAttr = it_compute_lhs_op.getAllPerms();
444445
std::vector<std::vector<std::string>> old_formats_strs = convertArrayAttrStrTo2DVector(op_formats_ArrayAttr);
446+
std::vector<std::vector<std::string>> old_blocks_strs = convertArrayAttrStrTo2DVector(op_blocks_ArrayAttr);
445447
std::vector<std::vector<int>> old_perms_ints = convertArrayAttrIntTo2DVector(op_perms_ArrayAttr);
446448

447449
/// Create the new formats
@@ -450,6 +452,13 @@ mlir::Value IndexTreeKernelFusionPass::createReducedComputeLHS(
450452
SmallVector<StringRef, 8> formats;
451453
formats.insert(formats.end(), old_formats_strs[0].begin() + rank_base, old_formats_strs[0].end());
452454
new_formats.push_back(builder.getStrArrayAttr(formats));
455+
456+
/// Create the new blocks
457+
/// i.g., convert [["D", "D"]] to [["D"]]
458+
SmallVector<Attribute, 8> new_blocks;
459+
SmallVector<StringRef, 8> blocks;
460+
blocks.insert(blocks.end(), old_blocks_strs[0].begin() + rank_base, old_blocks_strs[0].end());
461+
new_blocks.push_back(builder.getStrArrayAttr(blocks));
453462

454463
/// Create the new perms
455464
/// i.g., convert [[1, 0]] to [[0]]
@@ -466,7 +475,8 @@ mlir::Value IndexTreeKernelFusionPass::createReducedComputeLHS(
466475
mlir::UnrankedTensorType::get(builder.getF64Type()),
467476
tensors,
468477
builder.getArrayAttr(new_perms),
469-
builder.getArrayAttr(new_formats));
478+
builder.getArrayAttr(new_formats),
479+
builder.getArrayAttr(new_blocks));
470480

471481
return new_lhs_op;
472482
}
@@ -485,8 +495,10 @@ mlir::Value IndexTreeKernelFusionPass::createReducedComputeRHS(
485495
mlir::indexTree::IndexTreeComputeRHSOp it_compute_rhs_op = llvm::dyn_cast<mlir::indexTree::IndexTreeComputeRHSOp>(
486496
rhs_op);
487497
ArrayAttr op_formats_ArrayAttr = it_compute_rhs_op.getAllFormats();
498+
ArrayAttr op_blocks_ArrayAttr = it_compute_rhs_op.getAllBlocks();
488499
ArrayAttr op_perms_ArrayAttr = it_compute_rhs_op.getAllPerms();
489500
std::vector<std::vector<std::string>> old_formats_strs = convertArrayAttrStrTo2DVector(op_formats_ArrayAttr);
501+
std::vector<std::vector<std::string>> old_blocks_strs = convertArrayAttrStrTo2DVector(op_blocks_ArrayAttr);
490502
std::vector<std::vector<int>> old_perms_ints = convertArrayAttrIntTo2DVector(op_perms_ArrayAttr);
491503

492504
/// Locate the operand to be reduced
@@ -516,6 +528,23 @@ mlir::Value IndexTreeKernelFusionPass::createReducedComputeRHS(
516528
}
517529
new_formats.push_back(builder.getStrArrayAttr(formats));
518530
}
531+
532+
/// Create the new blocks
533+
/// Basically same algorithm as the formats
534+
SmallVector<Attribute, 8> new_blocks;
535+
for (uint32_t b_i = 0; b_i < old_blocks_strs.size(); ++b_i)
536+
{
537+
SmallVector<StringRef, 8> blocks;
538+
if (b_i == tensor_id)
539+
{ /// for the new reduced tensor
540+
blocks.insert(blocks.end(), old_formats_strs[b_i].begin() + rank_base, old_blocks_strs[b_i].end());
541+
}
542+
else
543+
{ /// for other remaining old operands
544+
blocks.insert(blocks.end(), old_blocks_strs[b_i].begin(), old_blocks_strs[b_i].end());
545+
}
546+
new_blocks.push_back(builder.getStrArrayAttr(blocks));
547+
}
519548

520549
/// Create the new perms
521550
/// i.g., convert [[1, 0], [0, 2]] to [[0], [0, 2]]
@@ -554,7 +583,8 @@ mlir::Value IndexTreeKernelFusionPass::createReducedComputeRHS(
554583
mlir::UnrankedTensorType::get(builder.getF64Type()),
555584
tensors,
556585
builder.getArrayAttr(new_perms),
557-
builder.getArrayAttr(new_formats));
586+
builder.getArrayAttr(new_formats),
587+
builder.getArrayAttr(new_blocks));
558588

559589
return new_rhs_op;
560590
}
@@ -729,14 +759,19 @@ mlir::Value IndexTreeKernelFusionPass::createResetComputeRHS(
729759
SmallVector<Attribute, 1> formats_rhs;
730760
SmallVector<StringRef, 1> empty_format;
731761
formats_rhs.push_back(builder.getStrArrayAttr(empty_format));
762+
763+
SmallVector<Attribute, 1> blocks_rhs;
764+
SmallVector<StringRef, 1> empty_block;
765+
blocks_rhs.push_back(builder.getStrArrayAttr(empty_block));
732766

733767
/// TODO(zpeng): What if the type is not F64?
734768
mlir::Value compute_rhs = builder.create<indexTree::IndexTreeComputeRHSOp>(
735769
loc,
736770
mlir::UnrankedTensorType::get(builder.getF64Type()),
737771
tensors_rhs,
738772
builder.getArrayAttr(indices_rhs),
739-
builder.getArrayAttr(formats_rhs));
773+
builder.getArrayAttr(formats_rhs),
774+
builder.getArrayAttr(blocks_rhs));
740775

741776
return compute_rhs;
742777
}
@@ -770,13 +805,19 @@ mlir::Value IndexTreeKernelFusionPass::createResetComputeLHS(
770805
SmallVector<Attribute, 1> formats_lhs;
771806
SmallVector<StringRef, 1> one_format(rank, "D");
772807
formats_lhs.push_back(builder.getStrArrayAttr(one_format));
808+
809+
/// Get blocks [["UNK"]]
810+
SmallVector<Attribute, 1> blocks_lhs;
811+
SmallVector<StringRef, 1> one_block(rank, "UNK");
812+
blocks_lhs.push_back(builder.getStrArrayAttr(one_block));
773813

774814
mlir::Value compute_lhs = builder.create<indexTree::IndexTreeComputeLHSOp>(
775815
loc,
776816
mlir::UnrankedTensorType::get(builder.getF64Type()),
777817
tensors_lhs,
778818
builder.getArrayAttr(indices_lhs),
779-
builder.getArrayAttr(formats_lhs));
819+
builder.getArrayAttr(formats_lhs),
820+
builder.getArrayAttr(blocks_lhs));
780821

781822
return compute_lhs;
782823
}

0 commit comments

Comments
 (0)