Skip to content

Commit 065f76d

Browse files
committed
Began adding allBlocks attribute to index tree #20
1 parent cdbacd8 commit 065f76d

File tree

14 files changed

+841
-357
lines changed

14 files changed

+841
-357
lines changed

first.mlir

+542-312
Large diffs are not rendered by default.

first.ta

+12-4
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
11
def main() {
22
#IndexLabel Declarations
33
IndexLabel [a] = [?];
4-
IndexLabel [b] = [?];
5-
6-
#Tensor Declarations
7-
Tensor<double> A([a, b], {BCSR});
4+
IndexLabel [b] = [?];
5+
6+
Tensor<double> A([a, b], {CSR});
7+
Tensor<double> B([b, a], {Dense});
8+
Tensor<double> C([b, a], {Dense});
9+
810
A[a, b] = comet_read(0);
11+
B[b, a] = 1.0;
12+
C[b, a] = A[a, b] * B[b, a];
13+
#C[b, a] = 1.0;
14+
915
print(A);
16+
print(B);
17+
print(C);
1018
}
1119

include/comet/Dialect/IndexTree/IR/IndexTree.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ class Index_Tree
198198
unsigned int indexID = 0;
199199

200200
public:
201-
Tensor *getOrCreateTensor(mlir::Value v, FormatsType &formats);
201+
Tensor *getOrCreateTensor(mlir::Value v, FormatsType &formats, BlocksType &blocks);
202202
IndicesType getIndices(mlir::Value v);
203203

204204
vector<TreeNode *> getNodes();

include/comet/Dialect/IndexTree/IR/IndexTreeOps.td

+3-3
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,15 @@ def IndexTreeComputeLHSOp : IndexTree_Op<"ComputeLHS", [Pure]>{
6464
let summary = "";
6565
let description = [{}];
6666

67-
let arguments = (ins Variadic<AnyType>:$tensors, ArrayAttr:$allPerms, ArrayAttr:$allFormats);
67+
let arguments = (ins Variadic<AnyType>:$tensors, ArrayAttr:$allPerms, ArrayAttr:$allFormats, ArrayAttr:$allBlocks);
6868
let results = (outs AnyType:$output);
6969
}
7070

7171
def IndexTreeComputeRHSOp : IndexTree_Op<"ComputeRHS", [Pure]>{
7272
let summary = "";
7373
let description = [{}];
7474

75-
let arguments = (ins Variadic<AnyType>:$tensors, ArrayAttr:$allPerms, ArrayAttr:$allFormats);
75+
let arguments = (ins Variadic<AnyType>:$tensors, ArrayAttr:$allPerms, ArrayAttr:$allFormats, ArrayAttr:$allBlocks);
7676
let results = (outs AnyType:$output);
7777
}
7878

@@ -116,4 +116,4 @@ def IndexTreeOp : IndexTree_Op<"itree", [Pure]>{
116116
//let hasVerifier = 1;
117117
}
118118

119-
#endif // INDEXTREE_OPS
119+
#endif // INDEXTREE_OPS

include/comet/Dialect/IndexTree/Transforms/Tensor.h

+19-1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
typedef std::vector<unsigned int> IndicesType;
3737
typedef std::vector<std::string> FormatsType;
38+
typedef std::vector<std::string> BlocksType;
3839
typedef std::vector<IterDomain *> DomainsType;
3940

4041
using std::make_shared;
@@ -53,6 +54,7 @@ class Tensor
5354
string name;
5455
IndicesType indices;
5556
FormatsType format;
57+
BlocksType block;
5658
IndicesType hidden;
5759
DomainsType domains;
5860
UnitExpression *definingExpr = nullptr;
@@ -61,12 +63,13 @@ class Tensor
6163
public:
6264
static int count;
6365

64-
Tensor(mlir::Value &value, IndicesType &indices, vector<string> &format)
66+
Tensor(mlir::Value &value, IndicesType &indices, vector<string> &format, vector<string> &block)
6567
{
6668
assert(format.size() == indices.size());
6769
this->value = value;
6870
this->indices = indices;
6971
this->format = format;
72+
this->block = block;
7073
id = count++;
7174
domains = std::vector<IterDomain *>(indices.size());
7275
}
@@ -154,6 +157,21 @@ class Tensor
154157
{
155158
Tensor::format = format;
156159
}
160+
161+
const string &getBlock(int i) const
162+
{
163+
return block.at(i);
164+
}
165+
166+
BlocksType &getBlocks()
167+
{
168+
return block;
169+
}
170+
171+
void setBlock(const vector<string> &block)
172+
{
173+
Tensor::block = block;
174+
}
157175

158176
void setHiddenIndices(IndicesType &hiddenIndices)
159177
{

include/comet/Dialect/Utils/Utils.h

+1
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ namespace mlir
122122
std::vector<unsigned> getSumIndices(std::vector<unsigned> rhs_perm, std::vector<unsigned> rhs_perm_free);
123123
std::vector<unsigned> getIndexIterateOrder(std::vector<unsigned> rhs1_perm, std::vector<unsigned> rhs2_perm);
124124
std::vector<std::vector<std::string>> getAllFormats(ArrayAttr opFormatsArrayAttr, std::vector<std::vector<int64_t>> allPerms);
125+
std::vector<std::vector<std::string>> getAllBlocks(ArrayAttr opFormatsArrayAttr, std::vector<std::vector<int64_t>> allPerms);
125126
bool checkIsElementwise(std::vector<std::vector<int>> allPerms);
126127
bool checkIsMixedMode(std::vector<std::vector<std::string>> formats);
127128
bool checkIsDense(std::vector<std::string> format);

lib/Conversion/TensorAlgebraToIndexTree/TensorAlgebraToIndexTree.cpp

+34-9
Original file line numberDiff line numberDiff line change
@@ -149,21 +149,22 @@ void doTensorMultOp(TensorMultOp op, unique_ptr<Index_Tree> &tree)
149149
comet_vdump(mask_tensor);
150150

151151
auto allPerms = getAllPerms(op.getIndexingMaps());
152+
auto allBlocks = getAllBlocks(op.getFormatsAttr(), allPerms);
152153
auto allFormats = getAllFormats(op.getFormatsAttr(), allPerms);
153154
auto SemiringOp = op.getSemiringAttr();
154155
auto MaskingTypeAttr = op.getMaskTypeAttr();
155156

156157
assert(allPerms.size() == 3);
157158

158-
auto B = tree->getOrCreateTensor(rhs1_tensor, allFormats[0]);
159-
auto C = tree->getOrCreateTensor(rhs2_tensor, allFormats[1]);
160-
auto A = tree->getOrCreateTensor(lhs_tensor, allFormats[2]);
159+
auto B = tree->getOrCreateTensor(rhs1_tensor, allFormats[0], allBlocks[0]);
160+
auto C = tree->getOrCreateTensor(rhs2_tensor, allFormats[1], allBlocks[1]);
161+
auto A = tree->getOrCreateTensor(lhs_tensor, allFormats[2], allBlocks[2]);
161162
Tensor *M;
162163
std::unique_ptr<UnitExpression> e;
163164
if (mask_tensor != nullptr) /// mask is an optional input
164165
{
165166
comet_debug() << "mask input provided by user\n";
166-
M = tree->getOrCreateTensor(mask_tensor, allFormats[2]); /// format same as lhs_tensor
167+
M = tree->getOrCreateTensor(mask_tensor, allFormats[2], allBlocks[2]); /// format same as lhs_tensor
167168
e = make_unique<UnitExpression>(A, B, C, M, "*");
168169
}
169170
else
@@ -225,14 +226,15 @@ void doElementWiseOp(T op, unique_ptr<Index_Tree> &tree)
225226

226227
auto allPerms = getAllPerms(op.getIndexingMaps());
227228
auto allFormats = getAllFormats(op.getFormatsAttr(), allPerms);
229+
auto allBlocks = getAllBlocks(op.getFormatsAttr(), allPerms);
228230
auto SemiringOp = op.getSemiringAttr();
229231
auto maskAttr = "none";
230232

231233
assert(allPerms.size() == 3);
232234

233-
auto B = tree->getOrCreateTensor(rhs1_tensor, allFormats[0]);
234-
auto C = tree->getOrCreateTensor(rhs2_tensor, allFormats[1]);
235-
auto A = tree->getOrCreateTensor(lhs_tensor, allFormats[2]);
235+
auto B = tree->getOrCreateTensor(rhs1_tensor, allFormats[0], allBlocks[0]);
236+
auto C = tree->getOrCreateTensor(rhs2_tensor, allFormats[1], allBlocks[1]);
237+
auto A = tree->getOrCreateTensor(lhs_tensor, allFormats[2], allBlocks[2]);
236238

237239
auto e = make_unique<UnitExpression>(A, B, C, "*");
238240

@@ -335,6 +337,27 @@ IndexTreeComputeOp createComputeNodeOp(OpBuilder &builder, TreeNode *node, Locat
335337
}
336338
allFormats_lhs.push_back(builder.getStrArrayAttr(formats));
337339
}
340+
341+
SmallVector<Attribute, 8> allBlocks_rhs;
342+
for (auto t : expr->getOperands())
343+
{
344+
SmallVector<StringRef, 8> blocks;
345+
for (auto &b : t->getBlocks())
346+
{
347+
blocks.push_back(b);
348+
}
349+
allBlocks_rhs.push_back(builder.getStrArrayAttr(blocks));
350+
}
351+
SmallVector<Attribute, 8> allBlocks_lhs;
352+
for (auto t : expr->getResults())
353+
{
354+
SmallVector<StringRef, 8> blocks;
355+
for (auto &b : t->getBlocks())
356+
{
357+
blocks.push_back(b);
358+
}
359+
allBlocks_lhs.push_back(builder.getStrArrayAttr(blocks));
360+
}
338361

339362
std::vector<Value> t_rhs;
340363
Value t_lhs = expr->getLHS()->getValue();
@@ -353,12 +376,14 @@ IndexTreeComputeOp createComputeNodeOp(OpBuilder &builder, TreeNode *node, Locat
353376
Value leafop_rhs = builder.create<indexTree::IndexTreeComputeRHSOp>(loc,
354377
mlir::UnrankedTensorType::get(builder.getF64Type()), t_rhs,
355378
builder.getArrayAttr(allIndices_rhs),
356-
builder.getArrayAttr(allFormats_rhs));
379+
builder.getArrayAttr(allFormats_rhs),
380+
builder.getArrayAttr(allBlocks_rhs));
357381
comet_vdump(leafop_rhs);
358382
Value leafop_lhs = builder.create<indexTree::IndexTreeComputeLHSOp>(loc,
359383
mlir::UnrankedTensorType::get(builder.getF64Type()), t_lhs,
360384
builder.getArrayAttr(allIndices_lhs),
361-
builder.getArrayAttr(allFormats_lhs));
385+
builder.getArrayAttr(allFormats_lhs),
386+
builder.getArrayAttr(allBlocks_lhs));
362387
bool comp_worksp_opt = false; /// non-compressed workspace, this is a place-holder and it is updated in workspace transform pass.
363388
llvm::StringRef semiring = expr->getSemiring();
364389
llvm::StringRef maskType = expr->getMaskType();

lib/Conversion/TensorAlgebraToSCF/LowerPCToLoops.cpp

+5-3
Original file line numberDiff line numberDiff line change
@@ -140,23 +140,25 @@ void PCToLoopsLoweringPass::replicateOpsForLoopBody(Location loc, OpBuilder &bui
140140
{
141141
indexTree::IndexTreeComputeRHSOp it_compute_rhs_op = llvm::dyn_cast<indexTree::IndexTreeComputeRHSOp>(op);
142142
ArrayAttr op_formats_ArrayAttr = it_compute_rhs_op.getAllFormats();
143+
ArrayAttr op_blocks_ArrayAttr = it_compute_rhs_op.getAllBlocks();
143144
ArrayAttr op_perms_ArrayAttr = it_compute_rhs_op.getAllPerms();
144145

145146
rhs = builder.create<indexTree::IndexTreeComputeRHSOp>(loc, mlir::UnrankedTensorType::get(builder.getF64Type()),
146147
it_compute_rhs_op->getOperands(), // tensors
147-
op_perms_ArrayAttr, op_formats_ArrayAttr);
148+
op_perms_ArrayAttr, op_formats_ArrayAttr, op_blocks_ArrayAttr);
148149
}
149150

150151
// create IndexTreeComputeLHSOp, no dependency to earlier replications
151152
if (isa<indexTree::IndexTreeComputeLHSOp>(op))
152153
{
153154
indexTree::IndexTreeComputeLHSOp it_compute_lhs_op = llvm::dyn_cast<indexTree::IndexTreeComputeLHSOp>(op);
154155
ArrayAttr op_formats_ArrayAttr = it_compute_lhs_op.getAllFormats();
156+
ArrayAttr op_blocks_ArrayAttr = it_compute_lhs_op.getAllBlocks();
155157
ArrayAttr op_perms_ArrayAttr = it_compute_lhs_op.getAllPerms();
156158

157159
lhs = builder.create<indexTree::IndexTreeComputeLHSOp>(loc, mlir::UnrankedTensorType::get(builder.getF64Type()),
158160
it_compute_lhs_op->getOperands(), /// tensors
159-
op_perms_ArrayAttr, op_formats_ArrayAttr);
161+
op_perms_ArrayAttr, op_formats_ArrayAttr, op_blocks_ArrayAttr);
160162
}
161163

162164
/// create IndexTreeComputeOp only if rhs and lhs are ready
@@ -338,4 +340,4 @@ void PCToLoopsLoweringPass::runOnOperation()
338340
std::unique_ptr<Pass> mlir::comet::createPCToLoopsLoweringPass()
339341
{
340342
return std::make_unique<PCToLoopsLoweringPass>();
341-
}
343+
}

lib/Conversion/TensorAlgebraToSCF/TensorAlgebraToSCF.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ using namespace mlir::bufferization;
4343
using namespace mlir::tensorAlgebra;
4444

4545
// *********** For debug purpose *********//
46-
//#define COMET_DEBUG_MODE
46+
// #define COMET_DEBUG_MODE
4747
#include "comet/Utils/debug.h"
4848
#undef COMET_DEBUG_MODE
4949
// *********** For debug purpose *********//

lib/Dialect/IndexTree/IR/IndexTree.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -121,13 +121,13 @@ IndicesType Index_Tree::getIndices(mlir::Value v)
121121
return indices;
122122
}
123123

124-
Tensor *Index_Tree::getOrCreateTensor(mlir::Value v, FormatsType &formats)
124+
Tensor *Index_Tree::getOrCreateTensor(mlir::Value v, FormatsType &formats, BlocksType &blocks)
125125
{
126126
IndicesType indices = getIndices(v);
127127
void *vp = v.getAsOpaquePointer();
128128
if (valueToTensor.count(vp) == 0)
129129
{
130-
valueToTensor[vp] = std::make_unique<Tensor>(v, indices, formats);
130+
valueToTensor[vp] = std::make_unique<Tensor>(v, indices, formats, blocks);
131131
}
132132
else
133133
{
@@ -268,4 +268,4 @@ vector<mlir::Operation *> Index_Tree::getContainingTAOps()
268268
ops.push_back(e->getOperation());
269269
}
270270
return ops;
271-
}
271+
}

lib/Dialect/IndexTree/Transforms/Fusion.cpp

+45-4
Original file line numberDiff line numberDiff line change
@@ -441,8 +441,10 @@ mlir::Value IndexTreeKernelFusionPass::createReducedComputeLHS(
441441
mlir::indexTree::IndexTreeComputeLHSOp it_compute_lhs_op = llvm::dyn_cast<mlir::indexTree::IndexTreeComputeLHSOp>(
442442
lhs_op);
443443
ArrayAttr op_formats_ArrayAttr = it_compute_lhs_op.getAllFormats();
444+
ArrayAttr op_blocks_ArrayAttr = it_compute_lhs_op.getAllBlocks();
444445
ArrayAttr op_perms_ArrayAttr = it_compute_lhs_op.getAllPerms();
445446
std::vector<std::vector<std::string>> old_formats_strs = convertArrayAttrStrTo2DVector(op_formats_ArrayAttr);
447+
std::vector<std::vector<std::string>> old_blocks_strs = convertArrayAttrStrTo2DVector(op_blocks_ArrayAttr);
446448
std::vector<std::vector<int>> old_perms_ints = convertArrayAttrIntTo2DVector(op_perms_ArrayAttr);
447449

448450
/// Create the new formats
@@ -451,6 +453,13 @@ mlir::Value IndexTreeKernelFusionPass::createReducedComputeLHS(
451453
SmallVector<StringRef, 8> formats;
452454
formats.insert(formats.end(), old_formats_strs[0].begin() + rank_base, old_formats_strs[0].end());
453455
new_formats.push_back(builder.getStrArrayAttr(formats));
456+
457+
/// Create the new blocks
458+
/// i.g., convert [["D", "D"]] to [["D"]]
459+
SmallVector<Attribute, 8> new_blocks;
460+
SmallVector<StringRef, 8> blocks;
461+
blocks.insert(blocks.end(), old_blocks_strs[0].begin() + rank_base, old_blocks_strs[0].end());
462+
new_blocks.push_back(builder.getStrArrayAttr(blocks));
454463

455464
/// Create the new perms
456465
/// i.g., convert [[1, 0]] to [[0]]
@@ -467,7 +476,8 @@ mlir::Value IndexTreeKernelFusionPass::createReducedComputeLHS(
467476
mlir::UnrankedTensorType::get(builder.getF64Type()),
468477
tensors,
469478
builder.getArrayAttr(new_perms),
470-
builder.getArrayAttr(new_formats));
479+
builder.getArrayAttr(new_formats),
480+
builder.getArrayAttr(new_blocks));
471481

472482
return new_lhs_op;
473483
}
@@ -486,8 +496,10 @@ mlir::Value IndexTreeKernelFusionPass::createReducedComputeRHS(
486496
mlir::indexTree::IndexTreeComputeRHSOp it_compute_rhs_op = llvm::dyn_cast<mlir::indexTree::IndexTreeComputeRHSOp>(
487497
rhs_op);
488498
ArrayAttr op_formats_ArrayAttr = it_compute_rhs_op.getAllFormats();
499+
ArrayAttr op_blocks_ArrayAttr = it_compute_rhs_op.getAllBlocks();
489500
ArrayAttr op_perms_ArrayAttr = it_compute_rhs_op.getAllPerms();
490501
std::vector<std::vector<std::string>> old_formats_strs = convertArrayAttrStrTo2DVector(op_formats_ArrayAttr);
502+
std::vector<std::vector<std::string>> old_blocks_strs = convertArrayAttrStrTo2DVector(op_blocks_ArrayAttr);
491503
std::vector<std::vector<int>> old_perms_ints = convertArrayAttrIntTo2DVector(op_perms_ArrayAttr);
492504

493505
/// Locate the operand to be reduced
@@ -517,6 +529,23 @@ mlir::Value IndexTreeKernelFusionPass::createReducedComputeRHS(
517529
}
518530
new_formats.push_back(builder.getStrArrayAttr(formats));
519531
}
532+
533+
/// Create the new blocks
534+
/// Basically same algorithm as the formats
535+
SmallVector<Attribute, 8> new_blocks;
536+
for (uint32_t b_i = 0; b_i < old_blocks_strs.size(); ++b_i)
537+
{
538+
SmallVector<StringRef, 8> blocks;
539+
if (b_i == tensor_id)
540+
{ /// for the new reduced tensor
541+
blocks.insert(blocks.end(), old_formats_strs[b_i].begin() + rank_base, old_blocks_strs[b_i].end());
542+
}
543+
else
544+
{ /// for other remaining old operands
545+
blocks.insert(blocks.end(), old_blocks_strs[b_i].begin(), old_blocks_strs[b_i].end());
546+
}
547+
new_blocks.push_back(builder.getStrArrayAttr(blocks));
548+
}
520549

521550
/// Create the new perms
522551
/// i.g., convert [[1, 0], [0, 2]] to [[0], [0, 2]]
@@ -555,7 +584,8 @@ mlir::Value IndexTreeKernelFusionPass::createReducedComputeRHS(
555584
mlir::UnrankedTensorType::get(builder.getF64Type()),
556585
tensors,
557586
builder.getArrayAttr(new_perms),
558-
builder.getArrayAttr(new_formats));
587+
builder.getArrayAttr(new_formats),
588+
builder.getArrayAttr(new_blocks));
559589

560590
return new_rhs_op;
561591
}
@@ -730,14 +760,19 @@ mlir::Value IndexTreeKernelFusionPass::createResetComputeRHS(
730760
SmallVector<Attribute, 1> formats_rhs;
731761
SmallVector<StringRef, 1> empty_format;
732762
formats_rhs.push_back(builder.getStrArrayAttr(empty_format));
763+
764+
SmallVector<Attribute, 1> blocks_rhs;
765+
SmallVector<StringRef, 1> empty_block;
766+
blocks_rhs.push_back(builder.getStrArrayAttr(empty_block));
733767

734768
/// TODO(zpeng): What if the type is not F64?
735769
mlir::Value compute_rhs = builder.create<indexTree::IndexTreeComputeRHSOp>(
736770
loc,
737771
mlir::UnrankedTensorType::get(builder.getF64Type()),
738772
tensors_rhs,
739773
builder.getArrayAttr(indices_rhs),
740-
builder.getArrayAttr(formats_rhs));
774+
builder.getArrayAttr(formats_rhs),
775+
builder.getArrayAttr(blocks_rhs));
741776

742777
return compute_rhs;
743778
}
@@ -771,13 +806,19 @@ mlir::Value IndexTreeKernelFusionPass::createResetComputeLHS(
771806
SmallVector<Attribute, 1> formats_lhs;
772807
SmallVector<StringRef, 1> one_format(rank, "D");
773808
formats_lhs.push_back(builder.getStrArrayAttr(one_format));
809+
810+
/// Get blocks [["UNK"]]
811+
SmallVector<Attribute, 1> blocks_lhs;
812+
SmallVector<StringRef, 1> one_block(rank, "UNK");
813+
blocks_lhs.push_back(builder.getStrArrayAttr(one_block));
774814

775815
mlir::Value compute_lhs = builder.create<indexTree::IndexTreeComputeLHSOp>(
776816
loc,
777817
mlir::UnrankedTensorType::get(builder.getF64Type()),
778818
tensors_lhs,
779819
builder.getArrayAttr(indices_lhs),
780-
builder.getArrayAttr(formats_lhs));
820+
builder.getArrayAttr(formats_lhs),
821+
builder.getArrayAttr(blocks_lhs));
781822

782823
return compute_lhs;
783824
}

0 commit comments

Comments
 (0)