@@ -440,8 +440,10 @@ mlir::Value IndexTreeKernelFusionPass::createReducedComputeLHS(
440
440
mlir::indexTree::IndexTreeComputeLHSOp it_compute_lhs_op = llvm::dyn_cast<mlir::indexTree::IndexTreeComputeLHSOp>(
441
441
lhs_op);
442
442
ArrayAttr op_formats_ArrayAttr = it_compute_lhs_op.getAllFormats ();
443
+ ArrayAttr op_blocks_ArrayAttr = it_compute_lhs_op.getAllBlocks ();
443
444
ArrayAttr op_perms_ArrayAttr = it_compute_lhs_op.getAllPerms ();
444
445
std::vector<std::vector<std::string>> old_formats_strs = convertArrayAttrStrTo2DVector (op_formats_ArrayAttr);
446
+ std::vector<std::vector<std::string>> old_blocks_strs = convertArrayAttrStrTo2DVector (op_blocks_ArrayAttr);
445
447
std::vector<std::vector<int >> old_perms_ints = convertArrayAttrIntTo2DVector (op_perms_ArrayAttr);
446
448
447
449
// / Create the new formats
@@ -450,6 +452,13 @@ mlir::Value IndexTreeKernelFusionPass::createReducedComputeLHS(
450
452
SmallVector<StringRef, 8 > formats;
451
453
formats.insert (formats.end (), old_formats_strs[0 ].begin () + rank_base, old_formats_strs[0 ].end ());
452
454
new_formats.push_back (builder.getStrArrayAttr (formats));
455
+
456
+ // / Create the new blocks
457
+ /// e.g., convert [["D", "D"]] to [["D"]]
458
+ SmallVector<Attribute, 8 > new_blocks;
459
+ SmallVector<StringRef, 8 > blocks;
460
+ blocks.insert (blocks.end (), old_blocks_strs[0 ].begin () + rank_base, old_blocks_strs[0 ].end ());
461
+ new_blocks.push_back (builder.getStrArrayAttr (blocks));
453
462
454
463
// / Create the new perms
455
464
/// e.g., convert [[1, 0]] to [[0]]
@@ -466,7 +475,8 @@ mlir::Value IndexTreeKernelFusionPass::createReducedComputeLHS(
466
475
mlir::UnrankedTensorType::get (builder.getF64Type ()),
467
476
tensors,
468
477
builder.getArrayAttr (new_perms),
469
- builder.getArrayAttr (new_formats));
478
+ builder.getArrayAttr (new_formats),
479
+ builder.getArrayAttr (new_blocks));
470
480
471
481
return new_lhs_op;
472
482
}
@@ -485,8 +495,10 @@ mlir::Value IndexTreeKernelFusionPass::createReducedComputeRHS(
485
495
mlir::indexTree::IndexTreeComputeRHSOp it_compute_rhs_op = llvm::dyn_cast<mlir::indexTree::IndexTreeComputeRHSOp>(
486
496
rhs_op);
487
497
ArrayAttr op_formats_ArrayAttr = it_compute_rhs_op.getAllFormats ();
498
+ ArrayAttr op_blocks_ArrayAttr = it_compute_rhs_op.getAllBlocks ();
488
499
ArrayAttr op_perms_ArrayAttr = it_compute_rhs_op.getAllPerms ();
489
500
std::vector<std::vector<std::string>> old_formats_strs = convertArrayAttrStrTo2DVector (op_formats_ArrayAttr);
501
+ std::vector<std::vector<std::string>> old_blocks_strs = convertArrayAttrStrTo2DVector (op_blocks_ArrayAttr);
490
502
std::vector<std::vector<int >> old_perms_ints = convertArrayAttrIntTo2DVector (op_perms_ArrayAttr);
491
503
492
504
// / Locate the operand to be reduced
@@ -516,6 +528,23 @@ mlir::Value IndexTreeKernelFusionPass::createReducedComputeRHS(
516
528
}
517
529
new_formats.push_back (builder.getStrArrayAttr (formats));
518
530
}
531
+
532
+ // / Create the new blocks
533
+ // / Basically same algorithm as the formats
534
+ SmallVector<Attribute, 8 > new_blocks;
535
+ for (uint32_t b_i = 0 ; b_i < old_blocks_strs.size (); ++b_i)
536
+ {
537
+ SmallVector<StringRef, 8 > blocks;
538
+ if (b_i == tensor_id)
539
+ { // / for the new reduced tensor
540
+ blocks.insert (blocks.end (), old_blocks_strs[b_i].begin () + rank_base, old_blocks_strs[b_i].end ());
541
+ }
542
+ else
543
+ { // / for other remaining old operands
544
+ blocks.insert (blocks.end (), old_blocks_strs[b_i].begin (), old_blocks_strs[b_i].end ());
545
+ }
546
+ new_blocks.push_back (builder.getStrArrayAttr (blocks));
547
+ }
519
548
520
549
// / Create the new perms
521
550
/// e.g., convert [[1, 0], [0, 2]] to [[0], [0, 2]]
@@ -554,7 +583,8 @@ mlir::Value IndexTreeKernelFusionPass::createReducedComputeRHS(
554
583
mlir::UnrankedTensorType::get (builder.getF64Type ()),
555
584
tensors,
556
585
builder.getArrayAttr (new_perms),
557
- builder.getArrayAttr (new_formats));
586
+ builder.getArrayAttr (new_formats),
587
+ builder.getArrayAttr (new_blocks));
558
588
559
589
return new_rhs_op;
560
590
}
@@ -729,14 +759,19 @@ mlir::Value IndexTreeKernelFusionPass::createResetComputeRHS(
729
759
SmallVector<Attribute, 1 > formats_rhs;
730
760
SmallVector<StringRef, 1 > empty_format;
731
761
formats_rhs.push_back (builder.getStrArrayAttr (empty_format));
762
+
763
+ SmallVector<Attribute, 1 > blocks_rhs;
764
+ SmallVector<StringRef, 1 > empty_block;
765
+ blocks_rhs.push_back (builder.getStrArrayAttr (empty_block));
732
766
733
767
// / TODO(zpeng): What if the type is not F64?
734
768
mlir::Value compute_rhs = builder.create <indexTree::IndexTreeComputeRHSOp>(
735
769
loc,
736
770
mlir::UnrankedTensorType::get (builder.getF64Type ()),
737
771
tensors_rhs,
738
772
builder.getArrayAttr (indices_rhs),
739
- builder.getArrayAttr (formats_rhs));
773
+ builder.getArrayAttr (formats_rhs),
774
+ builder.getArrayAttr (blocks_rhs));
740
775
741
776
return compute_rhs;
742
777
}
@@ -770,13 +805,19 @@ mlir::Value IndexTreeKernelFusionPass::createResetComputeLHS(
770
805
SmallVector<Attribute, 1 > formats_lhs;
771
806
SmallVector<StringRef, 1 > one_format (rank, " D" );
772
807
formats_lhs.push_back (builder.getStrArrayAttr (one_format));
808
+
809
+ // / Get blocks [["UNK"]]
810
+ SmallVector<Attribute, 1 > blocks_lhs;
811
+ SmallVector<StringRef, 1 > one_block (rank, " UNK" );
812
+ blocks_lhs.push_back (builder.getStrArrayAttr (one_block));
773
813
774
814
mlir::Value compute_lhs = builder.create <indexTree::IndexTreeComputeLHSOp>(
775
815
loc,
776
816
mlir::UnrankedTensorType::get (builder.getF64Type ()),
777
817
tensors_lhs,
778
818
builder.getArrayAttr (indices_lhs),
779
- builder.getArrayAttr (formats_lhs));
819
+ builder.getArrayAttr (formats_lhs),
820
+ builder.getArrayAttr (blocks_lhs));
780
821
781
822
return compute_lhs;
782
823
}
0 commit comments