Skip to content

Commit 3db4f2a

Browse files
committed
Merge remote-tracking branch 'origin/main' into dont_promote_broadcast_only_loop_groups
2 parents 935655e + d525fc3 commit 3db4f2a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+427
-349
lines changed

csrc/alias_analysis.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ std::pair<bool, std::optional<bool>> mergeContiguity(
149149
PairwiseLogicalDomainMap(in, out).mapProducerToConsumer();
150150

151151
Layout preferred_out_layout;
152-
for (const auto i : c10::irange(preferred_in_layout.size())) {
152+
for (const auto i : arange(preferred_in_layout.size())) {
153153
IterDomain* in_alloc_id = preferred_in_layout.allocation_domain[i];
154154
IterDomain* out_root_id = getOrDefault(in_logical_to_out_root, in_alloc_id);
155155
if (out_root_id == nullptr) {
@@ -176,7 +176,7 @@ void AliasFinder::handle(const ViewOp* view) {
176176
}
177177

178178
LinkedHashMap<IterDomain*, std::optional<bool>> allocation_to_contiguity;
179-
for (const auto i : c10::irange(out_root_layout->size())) {
179+
for (const auto i : arange(out_root_layout->size())) {
180180
if (!out_root_layout->contiguity[i].has_value() &&
181181
!out_root_layout->allocation_domain[i]->isBroadcast()) {
182182
// TODO(#1126): Due to #1126, `out_root` materializes an expanded
@@ -352,7 +352,7 @@ void AliasFinder::handle(const BroadcastOp* bcast) {
352352

353353
// Put new, broadcast dimensions to the end.
354354
const std::vector<IterDomain*> out_logical = out->getLogicalDomain();
355-
for (const auto i : c10::irange(out_logical.size())) {
355+
for (const auto i : arange(out_logical.size())) {
356356
if (bcast->isBroadcastDim(i)) {
357357
out_layout->allocation_domain.push_back(out_logical[i]);
358358
out_layout->contiguity.emplace_back(std::nullopt);
@@ -550,7 +550,7 @@ bool Layout::isCompliantWith(const Layout& required) const {
550550
return false;
551551
}
552552

553-
for (const auto i : c10::irange(allocation_domain.size())) {
553+
for (const auto i : arange(allocation_domain.size())) {
554554
if (!contiguityIsCompliant(contiguity[i], required.contiguity[i])) {
555555
return false;
556556
}

csrc/codegen.cpp

+19-22
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ class ArgumentBuilder {
5757
//! Build an argument list where each argument has its own line
5858
ArgumentBuilder(int indent_level, const char* tab) {
5959
std::stringstream ss;
60-
for (const auto i : c10::irange(indent_level)) {
60+
for (const auto i : arange(indent_level)) {
6161
(void)i; // Suppress unused variable warning
6262
ss << tab;
6363
}
@@ -335,7 +335,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
335335
// Generate parameter declarations
336336
kernel_params_.reserve(kernel_->parameters().size());
337337
unsigned int duplicate_counter = 0;
338-
for (auto i : c10::irange(kernel_->parameters().size())) {
338+
for (auto i : arange(kernel_->parameters().size())) {
339339
std::stringstream var_name_ss;
340340
auto param = kernel_->parameters().at(i);
341341
kernel_params_.insert(param);
@@ -557,7 +557,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
557557
}
558558

559559
std::ostream& indent() {
560-
for (const auto i : c10::irange(block_nest_level_)) {
560+
for (const auto i : arange(block_nest_level_)) {
561561
(void)i; // Suppress unused variable warning
562562
code_ << kTab;
563563
}
@@ -817,7 +817,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
817817
}
818818
auto dtype = std::get<StructType>(sop->output(0)->dtype().type);
819819
code_ << dtype.name << "{ ";
820-
for (auto i : c10::irange(sop->inputs().size())) {
820+
for (auto i : arange(sop->inputs().size())) {
821821
if (i > 0) {
822822
code_ << ", ";
823823
}
@@ -906,7 +906,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
906906
// Generate other datatypes in double
907907
}
908908
code_ << "(" << gen(rop->input(0));
909-
for (auto inp_i : c10::irange(1, rop->inputs().size())) {
909+
for (auto inp_i : arange(1, rop->inputs().size())) {
910910
code_ << ", " << gen(rop->input(inp_i));
911911
}
912912
code_ << ");\n";
@@ -2106,8 +2106,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
21062106
ArgumentBuilder func_args(block_nest_level_ + 1, kTab);
21072107

21082108
// Append arguments for each reduction
2109-
for (const auto i :
2110-
c10::irange(grouped_grop->numHorizontallyGroupedExprs())) {
2109+
for (const auto i : arange(grouped_grop->numHorizontallyGroupedExprs())) {
21112110
NVF_ERROR(
21122111
grouped_grop->reduction_buffers().at(i)->buffer()->isA<TensorView>());
21132112
const auto work_buffer =
@@ -2221,7 +2220,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
22212220
for (const auto& index_values : index_val_sets) {
22222221
NVF_ERROR(loop_indices.size() == index_values.size());
22232222
std::unordered_map<const Val*, int64_t> index_val_map;
2224-
for (const auto i : c10::irange(loop_indices.size())) {
2223+
for (const auto i : arange(loop_indices.size())) {
22252224
auto loop_index = loop_indices.at(i);
22262225
auto index_val = index_values.at(i);
22272226
index_val_map.emplace(loop_index, index_val);
@@ -2280,15 +2279,14 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
22802279
ArgumentBuilder write_preds;
22812280

22822281
for (const auto expr_index :
2283-
c10::irange(grouped_grop->numHorizontallyGroupedExprs())) {
2282+
arange(grouped_grop->numHorizontallyGroupedExprs())) {
22842283
const auto data_type = grouped_grop->outputs().at(expr_index)->dtype();
22852284
NVF_ERROR(grouped_grop->reduction_buffers()
22862285
.at(expr_index)
22872286
->buffer()
22882287
->isA<TensorView>());
22892288

2290-
for (const auto& group_index :
2291-
c10::irange(index_replacement_maps.size())) {
2289+
for (const auto& group_index : arange(index_replacement_maps.size())) {
22922290
// Set the index replacement map with the concrete values of
22932291
// indices of grouped loops.
22942292
index_replacement_map_ = index_replacement_maps.at(group_index);
@@ -2422,13 +2420,12 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
24222420
auto init_vals = grouped_gwop->initVals();
24232421

24242422
for (const auto expr_index :
2425-
c10::irange(grouped_gwop->numHorizontallyGroupedExprs())) {
2423+
arange(grouped_gwop->numHorizontallyGroupedExprs())) {
24262424
const auto& output = output_vals.at(expr_index);
24272425
const auto& input = input_vals.at(expr_index);
24282426
const auto& init = init_vals.at(expr_index);
24292427

2430-
for (const auto& group_index :
2431-
c10::irange(index_replacement_maps.size())) {
2428+
for (const auto& group_index : arange(index_replacement_maps.size())) {
24322429
// Set the index replacement map with the concrete values of
24332430
// indices of grouped loops.
24342431
index_replacement_map_ = index_replacement_maps.at(group_index);
@@ -2442,7 +2439,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
24422439
std::to_string(group_index));
24432440

24442441
// Setup arguments for avg, var, and N
2445-
for (const auto i : c10::irange(3)) {
2442+
for (const auto i : arange(3)) {
24462443
out_args[i].arg(gen(output.get(i)));
24472444
in_args[i].arg(gen(input.get(i)));
24482445
init_args[i].arg(gen(init.get(i)));
@@ -2589,7 +2586,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
25892586
func_args.arg(genComputeBlockDim());
25902587

25912588
// global buf
2592-
for (const auto i : c10::irange(3)) {
2589+
for (const auto i : arange(3)) {
25932590
const auto work_buffer = grouped_gwop->reduction_buffers()[i]
25942591
.at(0)
25952592
->buffer()
@@ -3005,7 +3002,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
30053002
grouped_rop->writePredicate());
30063003
}
30073004

3008-
for (const auto i : c10::irange(num_grouped_exprs)) {
3005+
for (const auto i : arange(num_grouped_exprs)) {
30093006
NVF_ERROR(grouped_rop->output(i)->isA<kir::TensorIndex>());
30103007

30113008
const auto output = grouped_rop->output(i)->as<kir::TensorIndex>();
@@ -3267,7 +3264,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
32673264
// Indentation for the PTX code
32683265
int utility_block_nest_level = 1;
32693266
std::function<std::ostream&()> indent_utility = [&]() -> std::ostream& {
3270-
for (auto _ : c10::irange(utility_block_nest_level)) {
3267+
for (auto _ : arange(utility_block_nest_level)) {
32713268
(void)_;
32723269
utilities << kTab;
32733270
}
@@ -3294,7 +3291,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
32943291
if (!asm_->options().immediate_inputs.empty()) {
32953292
utilities << "template <";
32963293
bool first = true;
3297-
for (auto in_i : c10::irange((int64_t)inputs.size())) {
3294+
for (auto in_i : arange((int64_t)inputs.size())) {
32983295
if (asm_->options().immediate_inputs.count(in_i)) {
32993296
if (!first) {
33003297
utilities << ", ";
@@ -3306,7 +3303,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
33063303
utilities << ">\n";
33073304
}
33083305
utilities << "__device__ __inline__ void " << utility_name_no_ns << "(";
3309-
for (auto out_i : c10::irange(outputs.size())) {
3306+
for (auto out_i : arange(outputs.size())) {
33103307
if (out_i > 0) {
33113308
utilities << ", ";
33123309
}
@@ -3315,7 +3312,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
33153312
if (!outputs.empty()) {
33163313
utilities << ", ";
33173314
}
3318-
for (auto in_i : c10::irange((int64_t)inputs.size())) {
3315+
for (auto in_i : arange((int64_t)inputs.size())) {
33193316
if (asm_->options().immediate_inputs.count(in_i)) {
33203317
continue;
33213318
}
@@ -3427,7 +3424,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
34273424
auto reg_dtype = get_type_or_index_type(register_);
34283425
if (std::holds_alternative<ArrayType>(reg_dtype.type)) {
34293426
for (auto i :
3430-
c10::irange(std::get<ArrayType>(reg_dtype.type).size)) {
3427+
arange(std::get<ArrayType>(reg_dtype.type).size)) {
34313428
if (i > 0) {
34323429
next_line();
34333430
}

csrc/compute_at.cpp

+1-3
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
#include <scheduler/tools/inlining.h>
1616
#include <transform_iter.h>
1717

18-
#include <c10/util/irange.h>
19-
2018
namespace nvfuser {
2119

2220
// Simple selector that only propagates across tensor views in the provided
@@ -66,7 +64,7 @@ std::set<T> set_intersection(const std::set<T>& set1, const std::set<T>& set2) {
6664
std::deque<std::deque<TensorView*>> tvChains(
6765
std::deque<std::deque<Val*>> val_chains) {
6866
std::deque<std::deque<TensorView*>> tv_chains(val_chains.size());
69-
for (const auto i : c10::irange(val_chains.size())) {
67+
for (const auto i : arange(val_chains.size())) {
7068
auto tv_iterable = ir_utils::filterByType<TensorView>(val_chains[i]);
7169
tv_chains[i] =
7270
std::deque<TensorView*>(tv_iterable.begin(), tv_iterable.end());

csrc/compute_at_map.cpp

+11-11
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ void IterDomainGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) {
216216
first->toString(),
217217
"\nand\n",
218218
second->toString());
219-
for (auto out_i : c10::irange(first_ids.size())) {
219+
for (auto out_i : arange(first_ids.size())) {
220220
exact_nodes_.mapEntries(first_ids[out_i], second_ids[out_i]);
221221
permissive_nodes_.mapEntries(first_ids[out_i], second_ids[out_i]);
222222
permissive_resize_nodes_.mapEntries(first_ids[out_i], second_ids[out_i]);
@@ -393,7 +393,7 @@ void IterDomainGraph::build(Fusion* fusion) {
393393
// p->f, c->c
394394
std::unordered_map<IterDomain*, IterDomain*> c2f_root_map;
395395
for (const auto i :
396-
c10::irange(first_output_tv->getMaybeRootDomain().size())) {
396+
arange(first_output_tv->getMaybeRootDomain().size())) {
397397
c2f_root_map.insert(std::make_pair(
398398
c_tv->getMaybeRootDomain()[i],
399399
first_output_tv->getMaybeRootDomain()[i]));
@@ -504,7 +504,7 @@ void IterDomainGraph::build(Fusion* fusion) {
504504

505505
for (auto& dset : permissive_disjoint_sets.disjointSets()) {
506506
auto& vec = dset->vector();
507-
for (auto i : c10::irange(vec.size())) {
507+
for (auto i : arange(vec.size())) {
508508
auto id1 = vec[i];
509509
permissive_nodes_.mapEntries(id1, vec[0]);
510510

@@ -513,7 +513,7 @@ void IterDomainGraph::build(Fusion* fusion) {
513513
// or p_id is swizzle output.
514514
mapMaybeSwizzleOp(permissive_nodes_, id1);
515515

516-
for (auto j : c10::irange(i + 1, vec.size())) {
516+
for (auto j : arange(i + 1, vec.size())) {
517517
auto id2 = vec[j];
518518
if (p_ids.count(id1) && c_ids.count(id2)) {
519519
if (idIsAComputeAtLeafDomain(id1, p_tv, c_tv) &&
@@ -538,11 +538,11 @@ void IterDomainGraph::build(Fusion* fusion) {
538538
// permissive-resize mappings.
539539
for (auto& dset : permissive_resize_disjoint_sets.disjointSets()) {
540540
auto& vec = dset->vector();
541-
for (auto i : c10::irange(vec.size())) {
541+
for (auto i : arange(vec.size())) {
542542
auto id1 = vec[i];
543543
permissive_resize_nodes_.mapEntries(id1, vec[0]);
544544
mapMaybeSwizzleOp(permissive_resize_nodes_, id1);
545-
for (auto j : c10::irange(i + 1, vec.size())) {
545+
for (auto j : arange(i + 1, vec.size())) {
546546
auto id2 = vec[j];
547547
if (p_ids.count(id1) && c_ids.count(id2)) {
548548
consumers_.at(id1).pushBack(id2);
@@ -651,7 +651,7 @@ void IterDomainGraph::build(Fusion* fusion) {
651651
for (auto prop_forward : {true, false}) {
652652
std::unordered_set<Expr*> visited_exprs;
653653

654-
for (auto logical_id_i : c10::irange(logical_id_order.size())) {
654+
for (auto logical_id_i : arange(logical_id_order.size())) {
655655
auto first_logical_id = prop_forward
656656
? logical_id_order[logical_id_i]
657657
: logical_id_order[logical_id_order.size() - 1 - logical_id_i];
@@ -881,8 +881,8 @@ void ComputeAtMap::allocateIndexVariables() {
881881
// Allocate index variable for each stage of the circular buffered loop.
882882
circular_buffered_loop_index_variable_map_[loop_disjoint_set.get()] =
883883
std::make_unique<CircularBufferIndices>();
884-
for (auto i : c10::irange(
885-
static_cast<int>(CircularBufferLoopStage::EndOfStages))) {
884+
for (auto i :
885+
arange(static_cast<int>(CircularBufferLoopStage::EndOfStages))) {
886886
auto stage = static_cast<CircularBufferLoopStage>(i);
887887
circular_buffered_loop_index_variable_map_[loop_disjoint_set.get()]
888888
->emplace(stage, IrBuilder::create<Val>(DataType::Index));
@@ -1260,7 +1260,7 @@ bool ComputeAtMap::areExactExprs(Expr* expr_1, Expr* expr_2) {
12601260
expr_1->outputs().size() == expr_2->outputs().size(),
12611261
"Expr traversal doesn't support variable number of inputs and outputs.");
12621262

1263-
for (auto input_i : c10::irange(expr_1->inputs().size())) {
1263+
for (auto input_i : arange(expr_1->inputs().size())) {
12641264
if (expr_1->inputs()[input_i]->isA<IterDomain>() &&
12651265
!areMapped(
12661266
expr_1->inputs()[input_i]->as<IterDomain>(),
@@ -1271,7 +1271,7 @@ bool ComputeAtMap::areExactExprs(Expr* expr_1, Expr* expr_2) {
12711271
}
12721272
}
12731273

1274-
for (auto output_i : c10::irange(expr_1->outputs().size())) {
1274+
for (auto output_i : arange(expr_1->outputs().size())) {
12751275
if (expr_1->outputs()[output_i]->isA<IterDomain>() &&
12761276
!areMapped(
12771277
expr_1->outputs()[output_i]->as<IterDomain>(),

csrc/contiguity.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ OrderedIdInformation::OrderedIdInformation(
2222
}
2323

2424
// Grab allocation ids and initialize them.
25-
for (const auto alloc_i : c10::irange(alloc_domain.size())) {
25+
for (const auto alloc_i : arange(alloc_domain.size())) {
2626
auto alloc_id = alloc_domain[alloc_i]->as<IterDomain>();
2727

2828
// Initialize id_to_alloc_ids to map allocs to themselves
@@ -508,7 +508,7 @@ void ContigIDs::build(const std::vector<IterDomain*>& ids) {
508508
" != ",
509509
alloc_contiguity_.size());
510510

511-
for (const auto alloc_domain_i : c10::irange(alloc_domain_.size())) {
511+
for (const auto alloc_domain_i : arange(alloc_domain_.size())) {
512512
auto alloc_domain_id = alloc_domain_.at(alloc_domain_i);
513513
if (alloc_domain_id->isBroadcast()) {
514514
NVF_ERROR(!alloc_contiguity_.at(alloc_domain_i).has_value());
@@ -599,7 +599,7 @@ void ContigIDs::handle(Merge* merge) {
599599
bool is_indexing_pass = !ignore_consistent_ordering_;
600600

601601
IterDomain* last_alloc = nullptr;
602-
for (auto alloc_id_i : c10::irange(alloc_domain_.size())) {
602+
for (auto alloc_id_i : arange(alloc_domain_.size())) {
603603
auto alloc_id = alloc_domain_[alloc_id_i];
604604
if (alloc_id->isBroadcast()) {
605605
NVF_ERROR(!alloc_contiguity_.at(alloc_id_i).has_value());

0 commit comments

Comments
 (0)