// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include <bfs.h>
#include <device_lower/lower2device.h>
#include <device_lower/pass/allocation.h>
#include <expr_evaluator.h>
#include <expr_simplifier.h>
#include <id_model/utils.h>
#include <instrumentation.h>
#include <ir/iostream.h>
#include <ir/utils.h>
#include <kernel_ir.h>
#include <kernel_ir_dispatch.h>
#include <unordered_set>
namespace nvfuser {
namespace {
// True if a given domain is a loop domain of a given tensor and its
// loop is partitioned with respect to the memory type of the tensor
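// For example (a hypothetical case), a Local tensor whose loop domain has a
// TIDx-parallelized ID: each thread owns its own copy of the data for that
// ID, so the ID is partitioned and does not contribute to the allocation.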
bool isPartitionedLoop(const TensorView* tv, IterDomain* id) {
// False if id is not a loop ID
if (std::find(tv->getLoopDomain().begin(), tv->getLoopDomain().end(), id) ==
tv->getLoopDomain().end()) {
return false;
}
// If the memory of this domain is partitioned with respect to the
// parallel type of the domain, there's no allocation for the domain
return ir_utils::isMemoryPartitionedAcross(
tv->getMemoryType(), id->getParallelType());
}
bool isSizeOneDomain(IterDomain* id) {
return id->isBroadcast() || id->extent()->isOneInt();
}
// True if a given domain of a tensor *may* require allocation
bool mayRequireAllocation(const TensorView* tv, IterDomain* id) {
// Conditions to consider:
// - Fully partitioned
// - Size one: Allocation is done based on the promotion ID, but as
// long as the original ID has size one, its allocation should
// remain size one.
// - Reduction: Check the original ID, not the promotion, which may
// be a reduction ID even though the original ID is not a reduction
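// For example (hypothetical IDs), a broadcast ID bB2{1} or a reduction ID
// rS3{8} in the allocation domain does not add to the allocation size, while
// a regular iteration ID iS4{8} does, unless its loop is partitioned as above.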
return !isPartitionedLoop(tv, id) && !isSizeOneDomain(id) &&
!id->isReduction() && !id->isStride();
}
// Get the allocation stride of a given allocation domain
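// For example (hypothetical domains), with an allocation domain
// [iS0, rS1, iS2] where rS1 is a reduction, alloc_dim == 2 maps to
// alloc_stride[1] because the reduction domain has no entry in the
// alloc_stride array.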
Val* getStrideOfGlobalMemoryTensor(TensorView* tv, int64_t alloc_dim) {
NVF_ERROR(tv->getMemoryType() == MemoryType::Global);
// Allocation domains can include reduction domains, but
// alloc_stride arrays do not.
const auto& alloc_dom = tv->getMaybeAllocationDomain();
int64_t stride_dim = -1;
for (const auto i : arange(alloc_dim + 1)) {
if (alloc_dom.at(i)->isReduction()) {
continue;
}
++stride_dim;
}
NVF_ERROR(stride_dim != -1);
return IrBuilder::getItemExpr(
IrBuilder::getAttrExpr(IrBuilder::metadataExpr(tv), "alloc_stride"),
stride_dim);
}
// Preparing allocation info for indexing. Because of broadcasting,
// just looking at the loop groups of a tensor may not be enough to
// determine the allocation of the tensor. For example, this happens
// when a tensor is broadcast and inlined, where the original
// pre-broadcast tensor may not have corresponding domains. If that
// missing domain is annotated with ParallelType::Unroll, which
// affects all inner loops, just looking at the inlined tensor itself
// would miss the unrolling. Since unrolling changes allocation
// shapes, missing unroll can result in incorrect allocations.
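//
// For example (a hypothetical schedule):
//   tv0 (Local): loop domain [I0], inlined into tv1
//   tv1        : loop domain [I0, U1], where U1 is a broadcast-derived domain
//                parallelized with ParallelType::Unroll
// tv0 has no domain mapped to U1, so inspecting only tv0's loop groups would
// miss the unrolling that enlarges its allocation.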
//
// TODO: Refactor this and the allocation lowering pass
class AllocationDomainSetup : private kir::IrVisitor {
public:
using IrVisitor::dispatch;
// Set allocation domain info for all tensors
void setup(const std::vector<Expr*>& exprs) {
// Find out correct allocation domains for all consumer
// tensors. Input tensors are handled after this
for (auto expr : exprs) {
dispatch(expr);
}
// Make sure all tensors have allocation domains
for (TensorView* producer_tv : used_as_producer) {
auto it = tv_alloc_info_map.find(producer_tv);
if (it != tv_alloc_info_map.end()) {
continue;
}
// Not yet set. This must be an input tensor or it must be aliased via
// aliasTensorProducer, in which case it will not be allocated.
NVF_ERROR(
producer_tv->isFusionInput() ||
GpuLower::current()->getTensorProducerAlias(producer_tv) !=
nullptr,
"Expected a fusion input or aliased tensor but found: ",
producer_tv->toString());
// For fusion input, we can just use getMaybeAllocationDomain.
auto alloc_info = getAllocationDomainInfo(
producer_tv,
producer_tv->getMaybeAllocationDomain(),
producer_tv->domain()->contiguity());
tv_alloc_info_map.emplace(producer_tv, alloc_info);
}
}
void dispatch(Expr* expr) override {
if (ir_utils::isTvOp(expr)) {
for (auto out_tv : ir_utils::filterByType<TensorView>(expr->outputs())) {
// Note that since we are dealing with a Kernel IR, a single
// tensor may show up as a consumer multiple times, e.g.,
// zero initialization and actual definition. Using the last
// expr should give us correct allocation info. See
// IndexingTest.InlinedUnroll for a concrete
// example. Specifically, the initialization expression of t2
// doesn't have an unrolling loop, so the allocation info
// obtained from that expression would fail to give the
// correct allocation domains.
auto [alloc_domains, contiguity] =
getAllocationDomainsAndContiguity(out_tv, for_loops_);
auto alloc_info =
getAllocationDomainInfo(out_tv, alloc_domains, contiguity);
tv_alloc_info_map[out_tv] = alloc_info;
}
for (auto in_tv : ir_utils::filterByType<TensorView>(expr->inputs())) {
used_as_producer.insert(in_tv);
}
} else {
IrVisitor::dispatch(expr);
}
}
// Get the allocation domains and contiguity of a given tensor
//
// TODO: Ideally, all tensors should have their correct allocation
// domains, but that isn't always the case at this moment. The logic
// here is duplicated in multiple locations and should be cleaned up.
std::pair<std::vector<IterDomain*>, std::vector<std::optional<bool>>>
getAllocationDomainsAndContiguity(
TensorView* tv,
const std::vector<ForLoop*>& for_loops) {
std::vector<IterDomain*> allocation_domains;
std::vector<std::optional<bool>> contiguity;
// In general, if the tensor has an allocation domain set, it
// should be used with no change. However, set allocation domains
// are not always the right allocation domains. For example,
// AliasTest.NotAllOutputAlias_Reduction has a tensor, tv6, that
// is a Local tensor with CA position of 4 but has an allocation
// domain that's just a permutation of its logical domain. Such
// invalid allocations need to be ignored. There doesn't seem to
// be any clear condition for when the set domain can be used, so
// it needs to be inferred. Here's what seems to be working
// reasonably well.
bool use_set_allocation_domain = false;
if (tv->hasAllocation()) {
// Honor the allocation domain if the tensor is global or Hopper MMA's
// output
if (tv->getMemoryType() == MemoryType::Global ||
(tv->definition()->isA<MmaOp>() &&
isHopper(tv->definition()->as<MmaOp>()->macro()))) {
use_set_allocation_domain = true;
} else if (tv->getMemoryType() == MemoryType::Shared) {
// If it's a shared memory tensor, the set domain is likely
// valid if Swizzle or Bulk is used. Also, if the allocation
// domain is just a permutation of the loop domain, use the
// set allocation domain. This seems to happen only with
// AllocationDomainTest.TransposedIntermediate.
if (std::any_of(
tv->getAllocationDomain().begin(),
tv->getAllocationDomain().end(),
[](IterDomain* allocation_domain) {
return dynamic_cast<Swizzle*>(
allocation_domain->definition()) != nullptr ||
allocation_domain->getParallelType() ==
ParallelType::Bulk;
}) ||
std::is_permutation(
tv->getLoopDomain().begin(),
tv->getLoopDomain().end(),
tv->getAllocationDomain().begin())) {
use_set_allocation_domain = true;
}
// Honor the set allocation domain if the tensor is used by a
// TMA store or MmaOp
if (std::ranges::any_of(tv->uses(), [](Expr* expr) {
return ir_utils::isCpAsyncBulkStore(expr) || expr->isA<MmaOp>();
})) {
use_set_allocation_domain = true;
}
}
}
// Allocation position is not always the same as the CA
// position. See also lower_utils::getAllocPosInfo.
int64_t allocation_pos =
lower_utils::getAllocPosInfo(tv, for_loops).alloc_pos;
if (use_set_allocation_domain) {
if (tv->getMemoryType() == MemoryType::Global) {
// For global memory tensors we always allocate the entire tensor
// TODO: do we really want to treat global memory tensors differently?
// need to think about this more.
allocation_domains = tv->getAllocationDomain();
contiguity = tv->domain()->contiguity();
} else {
std::unordered_set<IterDomain*> exclude_ca_ids;
for (auto i : arange(allocation_pos)) {
auto ca_id = tv->axis(i);
if (!ir_utils::isMemorySharedAcross(
tv->getMemoryType(), ca_id->getParallelType())) {
exclude_ca_ids.insert(ca_id);
}
}
for (auto i : arange(tv->getAllocationDomain().size())) {
auto id = tv->getAllocationDomain()[i];
if (exclude_ca_ids.find(id) == exclude_ca_ids.end()) {
if (ir_utils::isMemoryPartitionedAcross(
tv->getMemoryType(), id->getParallelType())) {
continue;
}
allocation_domains.push_back(id);
contiguity.push_back(tv->domain()->contiguity()[i]);
} else {
exclude_ca_ids.erase(id);
}
}
NVF_ERROR(
exclude_ca_ids.empty(),
"The non-allocating compute-at IDs are not found in the allocation domain. ",
"It is unclear how to allocate the tensor: ",
tv->toString(),
" allocation domain: ",
ir_utils::toString(tv->getAllocationDomain()));
}
} else {
// If allocation domain is not set, assume that:
// - Global: logical domains
// - Local/Shared: loop domains to the right of the CA position
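// For example (hypothetical), a Local tensor with loop domain [I0, I1, I2]
// and allocation position 1 gets [I1, I2] as its allocation domains, whereas
// a Shared tensor keeps a TIDx-parallelized ID even when it is to the left of
// the allocation position, because shared memory is shared across threads.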
if (tv->getMemoryType() == MemoryType::Global) {
allocation_domains = tv->getLogicalDomain();
contiguity = tv->domain()->contiguity();
} else {
for (const auto i : arange(tv->nDims())) {
auto loop_id = tv->getLoopDomain().at(i);
auto pt = loop_id->getParallelType();
// If the position is left of the inlining position, no need to
// allocate the domain unless it's shared. For example, if this
// is a Shared tensor and the domain is parallelized with TID,
// even if it's to the left of the CA position, since the domain
// is shared, it must be allocated.
if (i < allocation_pos &&
!ir_utils::isMemorySharedAcross(tv->getMemoryType(), pt)) {
continue;
}
allocation_domains.push_back(loop_id);
}
// Assume Local and Shared are always fully contiguous
contiguity =
std::vector<std::optional<bool>>(allocation_domains.size(), true);
}
if (auto indexed_alloc_dom =
patchAllocationOfIndexedProducerTensor(tv, allocation_domains);
indexed_alloc_dom.has_value()) {
allocation_domains = indexed_alloc_dom.value();
// Make sure the original allocation domains are fully contiguous
NVF_ERROR(std::all_of(contiguity.begin(), contiguity.end(), [](auto b) {
return b.has_value() && b.value();
}));
// Set the new allocation domains fully contiguous
contiguity =
std::vector<std::optional<bool>>(allocation_domains.size(), true);
}
// reorderAllocationDomains and
// patchAllocationOfTransposedSmemTensor assume unallocated IDs
// are removed
std::vector<IterDomain*> actual_allocation_ids;
std::vector<std::optional<bool>> actual_contiguity;
for (auto [i, id] : enumerate(allocation_domains)) {
if (mayRequireAllocation(tv, id)) {
actual_allocation_ids.push_back(id);
actual_contiguity.push_back(contiguity.at(i));
}
}
std::swap(allocation_domains, actual_allocation_ids);
std::swap(contiguity, actual_contiguity);
if (auto reordered_domains =
reorderAllocationDomains(tv, allocation_domains);
reordered_domains.has_value()) {
allocation_domains = reordered_domains.value();
NVF_ERROR(
std::all_of(
contiguity.begin(),
contiguity.end(),
[](auto b) { return b.has_value() && b.value(); }),
tv->toString());
}
// WAR for transpose
if (auto transposed_smem_alloc_dom =
patchAllocationOfTransposedSmemTensor(
tv,
allocation_domains,
GpuLower::current()->idModel().idGraph(IdMappingMode::EXACT));
transposed_smem_alloc_dom.has_value()) {
allocation_domains = transposed_smem_alloc_dom.value();
// Make sure the original allocation domains are fully contiguous
NVF_ERROR(std::all_of(contiguity.begin(), contiguity.end(), [](auto b) {
return b.has_value() && b.value();
}));
// Set the new allocation domains fully contiguous
contiguity =
std::vector<std::optional<bool>>(allocation_domains.size(), true);
}
}
NVF_ERROR(allocation_domains.size() == contiguity.size());
return {allocation_domains, contiguity};
}
// Get allocation info necessary for allocation and indexing. Loop promotion
// is considered. Strides are also calculated.
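//
// As an illustrative sketch (hypothetical extents), for allocation domains
// [I0, I1, I2] with extents [4, 8, 16] and contiguity [true, false, true] on
// a Global tensor, strides are computed from innermost to outermost:
//   stride(I2) = 1
//   stride(I1) = alloc_stride[1]   (non-contiguous, taken from metadata)
//   stride(I0) = alloc_stride[1] * 8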
AllocationDomainInfo getAllocationDomainInfo(
TensorView* tv,
std::vector<IterDomain*> allocation_domains,
std::vector<std::optional<bool>> contiguity) {
const IdModel& id_model = GpuLower::current()->idModel();
std::vector<IterDomain*> promoted_allocation_domains;
promoted_allocation_domains.reserve(allocation_domains.size());
// Loop promotion may affect allocations. Promotions of intermediate
// domains may not be defined correctly. Only consider loop domains
// for now.
for (const auto& allocation_domain : allocation_domains) {
bool is_loop = std::find(
tv->getLoopDomain().begin(),
tv->getLoopDomain().end(),
allocation_domain) != tv->getLoopDomain().end();
IterDomain* promotion_domain = nullptr;
if (is_loop) {
promotion_domain = getLoopPromotion(allocation_domain, id_model);
} else {
promotion_domain = allocation_domain;
}
promoted_allocation_domains.push_back(promotion_domain);
}
// Compute the strides from innermost to outermost domains
std::vector<Val*> strides(allocation_domains.size(), nullptr);
Val* cur_contig_stride = tv->fusion()->oneVal();
for (const auto i : arange(allocation_domains.size())) {
auto dim = allocation_domains.size() - i - 1;
auto allocation_domain = allocation_domains.at(dim);
auto promotion_domain = promoted_allocation_domains.at(dim);
if (!mayRequireAllocation(tv, allocation_domain)) {
continue;
}
const std::optional<bool> contig_flag = contiguity.at(dim);
// A broadcast domain doesn't have a contiguity flag, but it must
// have already been filtered out
NVF_ERROR(contig_flag.has_value());
if (contig_flag.value()) {
strides[dim] = cur_contig_stride;
cur_contig_stride = SimplifyingIrBuilder::mulExpr(
cur_contig_stride, promotion_domain->extent());
} else {
// Assume that the tensor should always be a Global memory
// tensor if it has non-contig allocation domains
NVF_ERROR(tv->getMemoryType() == MemoryType::Global);
strides[dim] = getStrideOfGlobalMemoryTensor(tv, (int64_t)dim);
cur_contig_stride = SimplifyingIrBuilder::mulExpr(
strides[dim], promotion_domain->extent());
}
}
// Filter out non-allocated domains
std::vector<IterDomain*> actual_allocation_domains;
std::vector<Val*> actual_strides;
std::vector<bool> actual_contiguity;
for (const auto i : arange(allocation_domains.size())) {
auto allocation_domain = allocation_domains.at(i);
auto promotion_domain = promoted_allocation_domains.at(i);
if (!mayRequireAllocation(tv, allocation_domain)) {
continue;
}
auto stride = strides.at(i);
NVF_ERROR(stride != nullptr);
actual_allocation_domains.push_back(promotion_domain);
actual_strides.push_back(stride);
auto contig = contiguity.at(i);
NVF_ERROR(contig.has_value());
actual_contiguity.push_back(contig.value());
}
NVF_ERROR(actual_allocation_domains.size() == actual_strides.size());
NVF_ERROR(actual_allocation_domains.size() == actual_contiguity.size());
return AllocationDomainInfo{
actual_allocation_domains, actual_strides, actual_contiguity};
}
// Reorder non-logical allocation domains to follow the ordering of
// the set allocation domain. This is necessary when an allocation
// domain includes a vectorized loop iter domain since it must be at the
// innermost position but that may not be the case in the loop
// domain. It is also necessary when the tensor is a producer of a
// vectorized store. Not strictly necessary otherwise, but this should also
// minimize the deviation from the old indexing scheme which always
// uses the logical domain to index.
//
// Returns reordered allocation domains if reordering is done.
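//
// For example (a hypothetical case), if the set allocation domain is [I0, I1]
// and the inferred allocation domains are [I1o, I0, I1i] with I1 split into
// I1o * I1i, replaying the split on [I0, I1] yields the ordering
// [I0, I1o, I1i], which is returned instead.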
std::optional<std::vector<IterDomain*>> reorderAllocationDomains(
const TensorView* tv,
const std::vector<IterDomain*>& allocation_domains) const {
// Use getMaybeAllocationDomain instead of getLogicalDomain. When
// this tv is a producer of a vectorized store, the consumer
// tensor should be a global memory tensor and this is likely a
// cache tensor created by cacheBefore. The consumer tensor may
// have a reordered allocation domain and that dictates the actual
// allocation ordering of this producer local tensor as well. If
// getLogicalDomain is used, DistributedTransformerTest.Backward
// fails at the result validation.
auto exprs = DependencyCheck::getAllExprsBetween(
{tv->getMaybeAllocationDomain().begin(),
tv->getMaybeAllocationDomain().end()},
{allocation_domains.begin(), allocation_domains.end()});
if (exprs.empty()) {
return std::nullopt;
}
// Replay exprs from the maybe-allocation domain to get the
// non-reordered domains
auto ordered_domains = tv->getMaybeAllocationDomain();
for (auto expr : exprs) {
// Find the position to insert the outputs.
int64_t insertion_pos = -1;
for (auto inp : expr->inputs()) {
auto it =
std::find(ordered_domains.begin(), ordered_domains.end(), inp);
if (it == ordered_domains.end()) {
continue;
}
// Insert right after the input
int64_t pos = std::distance(ordered_domains.begin(), it) + 1;
if (insertion_pos == -1 || pos > insertion_pos) {
insertion_pos = pos;
}
}
NVF_ERROR(
insertion_pos >= 0,
"Failed to replay: ",
expr->toString(),
" in ",
tv->toString());
// Insert the outputs
for (auto out : expr->outputs()) {
ordered_domains.insert(
ordered_domains.begin() + insertion_pos, out->as<IterDomain>());
++insertion_pos;
}
// Delete the inputs
for (auto inp : expr->inputs()) {
auto it =
std::find(ordered_domains.begin(), ordered_domains.end(), inp);
if (it == ordered_domains.end()) {
continue;
}
ordered_domains.erase(it);
}
}
// At this point, all domains of allocation_domains must exist in
// domains.
for (auto alloc_dom : allocation_domains) {
auto it =
std::find(ordered_domains.begin(), ordered_domains.end(), alloc_dom);
NVF_ERROR(
it != ordered_domains.end(),
"Missing allocation domain: ",
alloc_dom->toString(),
", domains: ",
toDelimitedString(ordered_domains));
}
// Pick only the allocation domains from the ordered domains
std::vector<IterDomain*> reordered_allocation_domains;
reordered_allocation_domains.reserve(allocation_domains.size());
for (auto dom : ordered_domains) {
auto it =
std::find(allocation_domains.begin(), allocation_domains.end(), dom);
if (it == allocation_domains.end()) {
continue;
}
reordered_allocation_domains.push_back(dom);
}
// If it's the same order, just return nullopt to indicate that
// nothing needs to be reordered
if (reordered_allocation_domains == allocation_domains) {
return std::nullopt;
}
return reordered_allocation_domains;
}
// Transpose with shared memory may need to change the ordering of
// allocation domains when shared memory is used as an input to
// vectorized stores. The transpose scheduler stages data to shared
// memory for vectorized stores to global memory. The layout of the
// shared memory staging buffer needs to be compatible with the
// vectorized stores. More specifically, here's a typical pattern of
// the transpose scheduler:
//
// t0_g: [I0, I1]
// t1_l = transpose(0, 1); // [I1, I0]
// t2_s = t1_l; // [I1, I0]
// t3_g = t2_s; // [I1, I0]
//
// t0, t1, t2:
// split I0 by 32 -> I/32a, 32a
// split I1 by 32 -> I/32b, 32b
// merge 32a and 32b -> 32a*32b
// split 32a*32b by 4 -> 32a*32b/4, 4
// -> loop domain: [I0/32a, I1/32b, 32a*32b/4, 4]
// t3:
// split I0 by 32 -> I/32a, 32a
// split I1 by 32 -> I/32b, 32b
// merge 32b and 32a -> 32b*32a
// split 32b*32a by 4 -> 32b*32a/4, 4
// -> loop domain: [I0/32a, I1/32b, 32b*32a/4, 4]
//
// Notice that t2 has 32a*32b, whereas t3 has 32b*32a. When the innermost
// domain of t3 is vectorized, this means that 32a must be the
// innermost in the allocation domain of t2. However, the inferred
// allocation domain has [..., 32a*32b/4, 4], so 32a is not the
// innermost.
//
// When a given tensor is found to have this pattern, allocation
// domains ordered in the same way as those of the vectorized global
// memory tensor are returned. In the case of the above example,
// [32b, 32a] is returned.
std::optional<std::vector<IterDomain*>> patchAllocationOfTransposedSmemTensor(
const TensorView* tv,
const std::vector<IterDomain*>& allocation_domains,
const ValGraph& exact_graph) const {
// First, do pattern matching to see if this tensor is a shared
// memory tensor transpose. Pattern matching conditions include:
//
// - Shared memory tensor
// - BID/DID should not be used with allocation domains
// - Consumer tensor must be a global memory tensor with vectorization
// - There must be a merge op whose two outputs are the dominating
// domains of the allocation domains
// - The consumer tensor also has a merge but with the inner and
// outer reversed
if (allocation_domains.empty()) {
return std::nullopt;
}
if (tv->getMemoryType() != MemoryType::Shared) {
return std::nullopt;
}
// No BID/DID parallel type should be used
if (std::any_of(
allocation_domains.begin(),
allocation_domains.end(),
[](IterDomain* id) -> bool {
return isParallelTypeDeviceDim(id->getParallelType()) ||
isParallelTypeBlockDim(id->getParallelType());
})) {
return std::nullopt;
}
// Can there be multiple stores with a single smem buffer?
if (tv->uses().size() != 1) {
return std::nullopt;
}
auto ls_op = dynamic_cast<LoadStoreOp*>(tv->uses().front());
if (ls_op == nullptr) {
return std::nullopt;
}
auto consumer = ls_op->out()->as<TensorView>();
if (consumer->getMemoryType() != MemoryType::Global) {
return std::nullopt;
}
IterDomain* consumer_vectorized_domain = nullptr;
if (auto it = std::find_if(
consumer->getLoopDomain().begin(),
consumer->getLoopDomain().end(),
[](IterDomain* loop_id) {
return loop_id->getParallelType() == ParallelType::Vectorize;
});
it != consumer->getLoopDomain().end()) {
consumer_vectorized_domain = *it;
} else {
return std::nullopt;
}
// This may be naive, but assume a simple pattern where all allocation
// domains are derived from a common merge.
// First, find the closest merge
auto getOriginatingMerge = [](IterDomain* id) -> Merge* {
while (id != nullptr) {
auto def = id->definition();
if (auto merge = dynamic_cast<Merge*>(def)) {
return merge;
} else if (auto split = dynamic_cast<Split*>(def)) {
id = split->in();
} else {
// Unsupported op
return nullptr;
}
}
return nullptr;
};
Merge* producer_common_merge =
getOriginatingMerge(allocation_domains.front());
if (producer_common_merge == nullptr) {
return std::nullopt;
}
// Test if all allocation domains and the merge output are
// equivalent
auto producer_merge_dep_exprs = DependencyCheck::getAllExprsBetween(
{producer_common_merge->out()},
{allocation_domains.begin(), allocation_domains.end()});
std::unordered_set<IterDomain*> equiv_domain_set(
allocation_domains.begin(), allocation_domains.end());
// Traverse back from the allocation domains to the merge output
// and see if they are equivalent
for (auto it = producer_merge_dep_exprs.rbegin();
it != producer_merge_dep_exprs.rend();
++it) {
Expr* expr = *it;
for (auto out : expr->outputs()) {
auto it = equiv_domain_set.find(out->as<IterDomain>());
if (it == equiv_domain_set.end() &&
mayRequireAllocation(tv, out->as<IterDomain>())) {
// missing dependency
return std::nullopt;
}
if (it != equiv_domain_set.end()) {
equiv_domain_set.erase(it);
}
}
for (auto input : expr->inputs()) {
equiv_domain_set.insert(input->as<IterDomain>());
}
}
// If they are equivalent, the merge output should be the only
// remaining domain
if (!(equiv_domain_set.size() == 1 &&
*(equiv_domain_set.begin()) == producer_common_merge->out())) {
// Not all allocation domains are used, meaning the merge output
// is not equivalent to the allocation domains
return std::nullopt;
}
// Look for a reverse merge in the consumer that uses the same
// inputs but outer and inner are reversed
IterDomain* merge_outer = producer_common_merge->outer();
const ValGroup& merge_outer_group = exact_graph.toGroup(merge_outer);
IterDomain* merge_inner = producer_common_merge->inner();
const ValGroup& merge_inner_group = exact_graph.toGroup(merge_inner);
const ExprGroups& merge_outer_uses = exact_graph.getUses(merge_outer_group);
ExprGroup reverse_merge;
for (const auto& merge_outer_use : merge_outer_uses) {
Merge* merge = dynamic_cast<Merge*>(merge_outer_use->front());
if (merge == nullptr) {
continue;
}
if (exact_graph.toGroup(merge->outer()) == merge_inner_group &&
exact_graph.toGroup(merge->inner()) == merge_outer_group) {
reverse_merge = merge_outer_use;
break;
}
}
if (reverse_merge.get() == nullptr) {
return std::nullopt;
}
ValGroup reverse_merge_output =
exact_graph.outputGroups(reverse_merge).at(0);
// Look for a matching merge in the consumer
const auto consumer_all_ids = consumer->domain()->allIDs();
IterDomain* consumer_merge_out = nullptr;
for (auto consumer_id : consumer_all_ids) {
if (reverse_merge_output->has(consumer_id)) {
consumer_merge_out = consumer_id;
break;
}
}
if (consumer_merge_out == nullptr) {
return std::nullopt;
}
// If the vectorized loop ID depends on consumer_merge_out, the
// producer tensor needs to use the memory layout that works for
// the vectorized store of the consumer tensor.
if (!DependencyCheck::isDependencyOf(
consumer_merge_out, consumer_vectorized_domain)) {
return std::nullopt;
}
std::vector<IterDomain*> patched_allocation_domains{
merge_inner, merge_outer};
return patched_allocation_domains;
}
// If a producer tensor is accessed through supplied indices, the
// indexed logical IDs need to be entirely allocated.
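//
// For example (a hypothetical case), if a producer of an indexSelect op has
// an indexed logical ID I1 and the inferred allocation IDs are [I0, I1o, I1i]
// with I1 split into I1o * I1i, the allocation IDs are patched to [I0, I1] so
// that the indexed ID is allocated in its entirety.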
std::optional<std::vector<IterDomain*>> patchAllocationOfIndexedProducerTensor(
const TensorView* tv,
const std::vector<IterDomain*>& allocation_ids) const {
VectorOfUniqueEntries<Val*> indexed_logical_ids;
for (auto use_expr : tv->uses()) {
auto indexed_id = ir_utils::getIndexedProducerID(use_expr);
if (indexed_id == nullptr ||
std::find(
tv->getLogicalDomain().begin(),
tv->getLogicalDomain().end(),
indexed_id) == tv->getLogicalDomain().end()) {
continue;
}
// This indexed_id is indirectly accessed and needs to be
// allocated entirely.
// If it's already in the allocation ID set, nothing further
// needs to be done
if (std::find(allocation_ids.begin(), allocation_ids.end(), indexed_id) !=
allocation_ids.end()) {
continue;
}
indexed_logical_ids.pushBack(indexed_id);
}
if (indexed_logical_ids.empty()) {
return std::nullopt;
}
// The IDs in indexed_logical_ids are not in the current allocation ID
// list. Find the allocation IDs that are equivalent to the
// indexed IDs. The indexed IDs should be reachable from the
// allocation IDs, and those allocation IDs used in the traversal
// path should be the ones that should be replaced with the
// indexed IDs.
// In order to retain the original ordering of allocation IDs,
// each indexed logical ID is examined one by one. Specifically,
// for each of them, we find the corresponding IDs in the current
// allocation ID vector and replace them with the indexed logical
// ID.
auto patched_allocation_ids = allocation_ids;
for (auto indexed_logical_id : indexed_logical_ids) {
auto [path, all_visited] = getExprsBetween<IRBFS>(
{patched_allocation_ids.begin(), patched_allocation_ids.end()},
{indexed_logical_id},
/*require_all_to_visited=*/false);
NVF_ERROR(
all_visited,
"Failed to infer valid allocation IDs. Indexed logical IDs need to be entirely allocated but not found in the inferred allocation ID set. Indexed logical ID: ",
indexed_logical_id->toString(),
". Allocation IDs: ",
toDelimitedString(patched_allocation_ids));
auto dependent_allocation_ids = getInputsOfExprPath<IRBFS>(path);
// Insert indexed_logical_id at the innermost position of
// dependent_allocation_ids.
int num_dependent_allocation_ids = 0;
std::vector<IterDomain*> patched_allocation_ids_next;
for (auto id : allocation_ids) {
if (std::find(
dependent_allocation_ids.begin(),
dependent_allocation_ids.end(),
id) != dependent_allocation_ids.end()) {
++num_dependent_allocation_ids;
if (num_dependent_allocation_ids ==
std::ssize(dependent_allocation_ids)) {
patched_allocation_ids_next.push_back(
indexed_logical_id->as<IterDomain>());
}
} else {
patched_allocation_ids_next.push_back(id);
}
}
std::swap(patched_allocation_ids, patched_allocation_ids_next);
}
return patched_allocation_ids;
}
std::unordered_map<TensorView*, AllocationDomainInfo> tv_alloc_info_map;
std::unordered_set<TensorView*> used_as_producer;
};
} // namespace
namespace {
enum class CircularBufferWaitType { ReadAfterWrite, WriteAfterRead };
// This function creates a ForLoop whose range is based on the stage depth.
// It is used for mbarrier initialization and invalidation.
ForLoop* createStageDepthForLoop(ForLoop* circular_buffer_loop) {
int64_t stage_depth =
GpuLower::current()
->circularBufferInfo()
.getCircularBufferOptionsFor(circular_buffer_loop->iter_domain())
.stage;
return ir_utils::createRangeLoop(stage_depth);
}
// This helper function initializes the mbarrier for all circular buffer
// stages.
//
// Expected result:
// for (unsigned i = 0; i < stages; ++i) {
// if (warp_id == 0 && electSync()) {
// mbarrier::init(...);
// }
// }
Expr* initializeMbarrier(
ForLoop* circular_buffer_loop,
TensorView* all_mbarriers,
CircularBufferWaitType wait_type) {
NVF_ERROR(circular_buffer_loop != nullptr);
ForLoop* loop = createStageDepthForLoop(circular_buffer_loop);
int64_t stage_depth =
GpuLower::current()
->circularBufferInfo()
.getCircularBufferOptionsFor(circular_buffer_loop->iter_domain())
.stage;
// We use mbarrier[0:stage_depth] for RAW, and
// mbarrier[stage_depth:2*stage_depth] for WAR.
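// For example, with stage_depth == 3, RAW waits use mbarrier[0..2] and WAR
// waits use mbarrier[3..5].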
Val* mbarrier_index = wait_type == CircularBufferWaitType::ReadAfterWrite
? loop->index()
: SimplifyingIrBuilder::addExpr(loop->index(), stage_depth);
// Get mbarrier for this circular buffer stage.
kir::TensorIndex* stage_mbarrier =
IrBuilder::create<kir::TensorIndex>(all_mbarriers, mbarrier_index);
auto circular_buffered_tvs =
GpuLower::current()->circularBufferInfo().getCircularBufferTvs(
circular_buffer_loop);
Val* num_of_arrives = nullptr;
if (wait_type == CircularBufferWaitType::ReadAfterWrite) {
// The RAW mbarrier is used to wait for the completion of the TMA
// loads of the circular buffer tensors. The number of arrivals is the
// number of TMA loads issued for the circular buffer tensors.
int64_t num_of_tvs_loaded_by_tma = std::count_if(
circular_buffered_tvs.begin(),
circular_buffered_tvs.end(),
[](const TensorView* tv) {
return ir_utils::isCpAsyncBulkLoad(tv->definition());
});
num_of_arrives =
IrBuilder::create<Val>(num_of_tvs_loaded_by_tma, DataType::UInt32);
} else {
// The WAR mbarrier is used to wait for the completion of the reads of
// the circular buffer tensor. The number of arrivals is the number of
// compute threads in the CTA.
num_of_arrives = SimplifyingIrBuilder::maybeCastExpr(
DataType::UInt32,
GpuLower::current()
->parallelDimensionMap()
.getNumComputeThreadsEachBlock());
}
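// For example (hypothetical), with two circular buffer tensors loaded by TMA,
// the RAW mbarrier expects 2 arrivals per stage, whereas the WAR mbarrier
// expects one arrival per compute thread in the CTA.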
// Initialize the mbarrier for each circular buffer stage. Use the thread
// count from the MBarrierInit created in the allocation pass. The wait
// condition for the mbarrier is all threads in the CTA arriving and the
// expected number of transaction bytes being transferred.
kir::MBarrierInit* mbarrier_init =
IrBuilder::create<kir::MBarrierInit>(stage_mbarrier, num_of_arrives);
Expr* pred_mbarrier_init = mbarrier_init->withPredicate(
IrBuilder::create<kir::Predicate>(PredicateType::ElectSync));
loop->body().push_back(pred_mbarrier_init);
return loop;
}
// This helper function invalidates the mbarrier for all circular buffer
// stages after the TMA memory operations.
//
// Expected result:
// for (unsigned i = 0; i < stages; ++i) {
// if (warp_id == 0 && electSync()) {
// mbarrier::inval(...);
// }
// }
Expr* invalidateMbarrier(
ForLoop* circular_buffer_loop,
TensorView* all_mbarriers,
CircularBufferWaitType wait_type) {
NVF_ERROR(circular_buffer_loop != nullptr);
ForLoop* loop = createStageDepthForLoop(circular_buffer_loop);
int64_t stage_depth =
GpuLower::current()
->circularBufferInfo()
.getCircularBufferOptionsFor(circular_buffer_loop->iter_domain())
.stage;
// We use mbarrier[0:stage_depth] for RAW, and
// mbarrier[stage_depth:2*stage_depth] for WAR.
Val* mbarrier_index = wait_type == CircularBufferWaitType::ReadAfterWrite
? loop->index()
: SimplifyingIrBuilder::addExpr(loop->index(), stage_depth);
// Get mbarrier for this circular buffer stage.
kir::TensorIndex* stage_mbarrier =
IrBuilder::create<kir::TensorIndex>(all_mbarriers, mbarrier_index);
// Invalidate the mbarrier for each circular buffer stage.
kir::MBarrierInvalidate* mbarrier_inval =
IrBuilder::create<kir::MBarrierInvalidate>(stage_mbarrier);
Expr* pred_mbarrier_inval = mbarrier_inval->withPredicate(
IrBuilder::create<kir::Predicate>(PredicateType::ElectSync));
loop->body().push_back(pred_mbarrier_inval);
return loop;
}
class AllocationInserter : public kir::ExprMutator {
private:
using kir::ExprMutator::handle;
// An expanded version of BasicAllocInfo in lower_utils.h that helps track
// additional information
struct AllocationInformation {
// The for loop that the initialization of this allocation must be
// placed in, nullptr if not within a loop
ForLoop* init_for_loop = nullptr;
// The expression that the initialization of this allocation must
// be placed before
Expr* init_place_before = nullptr;
// Keep track of the actual allocation loop. This can be different
// from init_for_loop only with unswitched shared memory allocations,
// which are moved to outer loops to avoid duplicated allocations