@@ -441,8 +441,10 @@ mlir::Value IndexTreeKernelFusionPass::createReducedComputeLHS(
441
441
mlir::indexTree::IndexTreeComputeLHSOp it_compute_lhs_op = llvm::dyn_cast<mlir::indexTree::IndexTreeComputeLHSOp>(
442
442
lhs_op);
443
443
ArrayAttr op_formats_ArrayAttr = it_compute_lhs_op.getAllFormats ();
444
+ ArrayAttr op_blocks_ArrayAttr = it_compute_lhs_op.getAllBlocks ();
444
445
ArrayAttr op_perms_ArrayAttr = it_compute_lhs_op.getAllPerms ();
445
446
std::vector<std::vector<std::string>> old_formats_strs = convertArrayAttrStrTo2DVector (op_formats_ArrayAttr);
447
+ std::vector<std::vector<std::string>> old_blocks_strs = convertArrayAttrStrTo2DVector (op_blocks_ArrayAttr);
446
448
std::vector<std::vector<int >> old_perms_ints = convertArrayAttrIntTo2DVector (op_perms_ArrayAttr);
447
449
448
450
// / Create the new formats
@@ -451,6 +453,13 @@ mlir::Value IndexTreeKernelFusionPass::createReducedComputeLHS(
451
453
SmallVector<StringRef, 8 > formats;
452
454
formats.insert (formats.end (), old_formats_strs[0 ].begin () + rank_base, old_formats_strs[0 ].end ());
453
455
new_formats.push_back (builder.getStrArrayAttr (formats));
456
+
457
+ // / Create the new blocks
458
+ // / e.g., convert [["D", "D"]] to [["D"]]
459
+ SmallVector<Attribute, 8 > new_blocks;
460
+ SmallVector<StringRef, 8 > blocks;
461
+ blocks.insert (blocks.end (), old_blocks_strs[0 ].begin () + rank_base, old_blocks_strs[0 ].end ());
462
+ new_blocks.push_back (builder.getStrArrayAttr (blocks));
454
463
455
464
// / Create the new perms
456
465
// / i.g., convert [[1, 0]] to [[0]]
@@ -467,7 +476,8 @@ mlir::Value IndexTreeKernelFusionPass::createReducedComputeLHS(
467
476
mlir::UnrankedTensorType::get (builder.getF64Type ()),
468
477
tensors,
469
478
builder.getArrayAttr (new_perms),
470
- builder.getArrayAttr (new_formats));
479
+ builder.getArrayAttr (new_formats),
480
+ builder.getArrayAttr (new_blocks));
471
481
472
482
return new_lhs_op;
473
483
}
@@ -486,8 +496,10 @@ mlir::Value IndexTreeKernelFusionPass::createReducedComputeRHS(
486
496
mlir::indexTree::IndexTreeComputeRHSOp it_compute_rhs_op = llvm::dyn_cast<mlir::indexTree::IndexTreeComputeRHSOp>(
487
497
rhs_op);
488
498
ArrayAttr op_formats_ArrayAttr = it_compute_rhs_op.getAllFormats ();
499
+ ArrayAttr op_blocks_ArrayAttr = it_compute_rhs_op.getAllBlocks ();
489
500
ArrayAttr op_perms_ArrayAttr = it_compute_rhs_op.getAllPerms ();
490
501
std::vector<std::vector<std::string>> old_formats_strs = convertArrayAttrStrTo2DVector (op_formats_ArrayAttr);
502
+ std::vector<std::vector<std::string>> old_blocks_strs = convertArrayAttrStrTo2DVector (op_blocks_ArrayAttr);
491
503
std::vector<std::vector<int >> old_perms_ints = convertArrayAttrIntTo2DVector (op_perms_ArrayAttr);
492
504
493
505
// / Locate the operand to be reduced
@@ -517,6 +529,23 @@ mlir::Value IndexTreeKernelFusionPass::createReducedComputeRHS(
517
529
}
518
530
new_formats.push_back (builder.getStrArrayAttr (formats));
519
531
}
532
+
533
+ // / Create the new blocks
534
+ // / Basically same algorithm as the formats
535
+ SmallVector<Attribute, 8 > new_blocks;
536
+ for (uint32_t b_i = 0 ; b_i < old_blocks_strs.size (); ++b_i)
537
+ {
538
+ SmallVector<StringRef, 8 > blocks;
539
+ if (b_i == tensor_id)
540
+ { // / for the new reduced tensor
541
+ blocks.insert (blocks.end (), old_blocks_strs[b_i].begin () + rank_base, old_blocks_strs[b_i].end ());
542
+ }
543
+ else
544
+ { // / for other remaining old operands
545
+ blocks.insert (blocks.end (), old_blocks_strs[b_i].begin (), old_blocks_strs[b_i].end ());
546
+ }
547
+ new_blocks.push_back (builder.getStrArrayAttr (blocks));
548
+ }
520
549
521
550
// / Create the new perms
522
551
// / i.g., convert [[1, 0], [0, 2]] to [[0], [0, 2]]
@@ -555,7 +584,8 @@ mlir::Value IndexTreeKernelFusionPass::createReducedComputeRHS(
555
584
mlir::UnrankedTensorType::get (builder.getF64Type ()),
556
585
tensors,
557
586
builder.getArrayAttr (new_perms),
558
- builder.getArrayAttr (new_formats));
587
+ builder.getArrayAttr (new_formats),
588
+ builder.getArrayAttr (new_blocks));
559
589
560
590
return new_rhs_op;
561
591
}
@@ -730,14 +760,19 @@ mlir::Value IndexTreeKernelFusionPass::createResetComputeRHS(
730
760
SmallVector<Attribute, 1 > formats_rhs;
731
761
SmallVector<StringRef, 1 > empty_format;
732
762
formats_rhs.push_back (builder.getStrArrayAttr (empty_format));
763
+
764
+ SmallVector<Attribute, 1 > blocks_rhs;
765
+ SmallVector<StringRef, 1 > empty_block;
766
+ blocks_rhs.push_back (builder.getStrArrayAttr (empty_block));
733
767
734
768
// / TODO(zpeng): What if the type is not F64?
735
769
mlir::Value compute_rhs = builder.create <indexTree::IndexTreeComputeRHSOp>(
736
770
loc,
737
771
mlir::UnrankedTensorType::get (builder.getF64Type ()),
738
772
tensors_rhs,
739
773
builder.getArrayAttr (indices_rhs),
740
- builder.getArrayAttr (formats_rhs));
774
+ builder.getArrayAttr (formats_rhs),
775
+ builder.getArrayAttr (blocks_rhs));
741
776
742
777
return compute_rhs;
743
778
}
@@ -771,13 +806,19 @@ mlir::Value IndexTreeKernelFusionPass::createResetComputeLHS(
771
806
SmallVector<Attribute, 1 > formats_lhs;
772
807
SmallVector<StringRef, 1 > one_format (rank, " D" );
773
808
formats_lhs.push_back (builder.getStrArrayAttr (one_format));
809
+
810
+ // / Get blocks [["UNK"]]
811
+ SmallVector<Attribute, 1 > blocks_lhs;
812
+ SmallVector<StringRef, 1 > one_block (rank, " UNK" );
813
+ blocks_lhs.push_back (builder.getStrArrayAttr (one_block));
774
814
775
815
mlir::Value compute_lhs = builder.create <indexTree::IndexTreeComputeLHSOp>(
776
816
loc,
777
817
mlir::UnrankedTensorType::get (builder.getF64Type ()),
778
818
tensors_lhs,
779
819
builder.getArrayAttr (indices_lhs),
780
- builder.getArrayAttr (formats_lhs));
820
+ builder.getArrayAttr (formats_lhs),
821
+ builder.getArrayAttr (blocks_lhs));
781
822
782
823
return compute_lhs;
783
824
}
0 commit comments