diff --git a/src/transform/layout_inference.cc b/src/transform/layout_inference.cc
index dbbdb5cce..337312851 100644
--- a/src/transform/layout_inference.cc
+++ b/src/transform/layout_inference.cc
@@ -1090,112 +1090,114 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
       reducer_info = op->annotations.Get(attr::kReducerInfo)
                          ->as<Map<Var, ReducerInfo>>()
                          .value();
-
+    if (!result_.for_map.count(tvm::ffi::GetRef<For>(op))) {
+      return IRMutatorWithAnalyzer::VisitStmt_(op);
+    }
+    // the analyzer will be modified in PartitionLoop and VectorizeLoop
+    // we need to save its state to prevent conflicted bindings
+    auto saved_analyzer = analyzer_->Clone();
     For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
-    if (result_.for_map.count(tvm::ffi::GetRef<For>(op))) {
-      auto root = tvm::ffi::GetRef<For>(op);
-      // This check is a workaround to support T.Parallel for local buffers.
-      // For example:
-      // for i in T.Parallel(1024):
-      //   A_local[i] = A_global[i]
-      // Here, A_local is a register-local buffer held independently by each
-      // thread, so explicit thread binding is not required.
-      bool store_into_local = false;
-      PostOrderVisit(root, [&](const ObjectRef &obj) {
-        if (const auto *store = obj.as<BufferStoreNode>()) {
-          if (store->buffer.scope() == "local") {
-            store_into_local = true;
-          }
-          // if the case is like:
-          // for i in T.Parallel(1024):
-          //   A_local[i] = B_global[i]
-          //   A_frag[i] = A_global[i]
-          // exception will be raise in Parallel::LayoutInference
-        }
-      });
-      // This check if for the loop that only manuplates "local" buffers,
-      // for i in T.Parallel(1024):
-      //   A_local[i] = B_local[i]
-      // Though this might be illegal
-      // We use PostOrderVisit to detect whether the loop only manuplates
-      // "local" buffers, which indicates register usage and justifies skipping
-      // thread binding.
-      bool local_register_only = true;
-      PostOrderVisit(root, [&](const ObjectRef &obj) {
-        if (const auto *store = obj.as<BufferStoreNode>()) {
-          if (store->buffer.scope() != "local") {
-            local_register_only = false;
-          }
-        } else if (const auto *load = obj.as<BufferLoadNode>()) {
-          if (load->buffer.scope() != "local") {
-            local_register_only = false;
-          }
-        }
-      });
-
-      auto loop_layout = result_.for_map[root];
-      // FIXME: tell in-Parallel and out-of-Parallel `local`s apart
-      // NOTE(lei): a bit ugly, we should rethink about this part in future.
-      bool parallel_loop =
-          !skip_thread_partition_ && !local_register_only && !store_into_local;
-
-      if (parallel_loop) {
-        for_node =
-            PartitionLoop(for_node, thread_var_->var, analyzer_, loop_layout);
-      }
-      // If none thread bindings are provided, partition the loop
-      bool has_non_local = false;
-      PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
-        if (const auto *load = obj.as<BufferLoadNode>()) {
-          String scope = load->buffer.scope();
-          if (scope != "local" && scope != "local.fragment") {
-            has_non_local = true;
-          }
-        } else if (const auto *store = obj.as<BufferStoreNode>()) {
-          String scope = store->buffer.scope();
-          if (scope != "local" && scope != "local.fragment") {
-            has_non_local = true;
-          }
-        }
-      });
-      // Workaround: if reducer is presented, don't vectorize loop
-      // Best solution should be isolate reduction axis out of vectorization
-      bool has_reducer = false;
-      PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
-        if (!has_reducer)
-          if (const auto *store = obj.as<BufferStoreNode>()) {
-            has_reducer = reducer_info.count(store->buffer->data) != 0;
-          }
-      });
-
-      // If a cast operation exists, vectorization may still be required
-      bool has_cast_operations = false;
-      PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
-        if (const auto *cast = obj.as<CastNode>()) {
-          // Check if this is a non-reducer store with Cast operation
-          DataType src_type = cast->value.dtype();
-          DataType dst_type = cast->dtype;
-          bool src_ok = src_type.is_float() || src_type.is_bfloat() ||
-                        src_type.is_float8_e4m3() || src_type.is_float8_e5m2();
-          bool dst_ok = dst_type.is_float() || dst_type.is_bfloat() ||
-                        dst_type.is_float8_e4m3() || dst_type.is_float8_e5m2();
-          if (src_ok && dst_ok && TargetIsCuda(Target::Current())) {
-            has_cast_operations = true;
-          }
-        }
-      });
-
-      if ((has_non_local || has_cast_operations) && !has_reducer) {
-        for_node = VectorizeLoop(for_node, analyzer_);
-      }
-
-      if (result_.predicate_map.count(root) && parallel_loop) {
-        return IfThenElse(result_.predicate_map[root], for_node);
-      } else {
-        return for_node;
-      }
-    }
-    return for_node;
+    auto root = tvm::ffi::GetRef<For>(op);
+    // This check is a workaround to support T.Parallel for local buffers.
+    // For example:
+    // for i in T.Parallel(1024):
+    //   A_local[i] = A_global[i]
+    // Here, A_local is a register-local buffer held independently by each
+    // thread, so explicit thread binding is not required.
+    bool store_into_local = false;
+    PostOrderVisit(root, [&](const ObjectRef &obj) {
+      if (const auto *store = obj.as<BufferStoreNode>()) {
+        if (store->buffer.scope() == "local") {
+          store_into_local = true;
+        }
+        // if the case is like:
+        // for i in T.Parallel(1024):
+        //   A_local[i] = B_global[i]
+        //   A_frag[i] = A_global[i]
+        // exception will be raise in Parallel::LayoutInference
+      }
+    });
+    // This check if for the loop that only manuplates "local" buffers,
+    // for i in T.Parallel(1024):
+    //   A_local[i] = B_local[i]
+    // Though this might be illegal
+    // We use PostOrderVisit to detect whether the loop only manuplates
+    // "local" buffers, which indicates register usage and justifies skipping
+    // thread binding.
+    bool local_register_only = true;
+    PostOrderVisit(root, [&](const ObjectRef &obj) {
+      if (const auto *store = obj.as<BufferStoreNode>()) {
+        if (store->buffer.scope() != "local") {
+          local_register_only = false;
+        }
+      } else if (const auto *load = obj.as<BufferLoadNode>()) {
+        if (load->buffer.scope() != "local") {
+          local_register_only = false;
+        }
+      }
+    });
+
+    auto loop_layout = result_.for_map[root];
+    // FIXME: tell in-Parallel and out-of-Parallel `local`s apart
+    // NOTE(lei): a bit ugly, we should rethink about this part in future.
+    bool parallel_loop =
+        !skip_thread_partition_ && !local_register_only && !store_into_local;
+
+    if (parallel_loop) {
+      for_node =
+          PartitionLoop(for_node, thread_var_->var, analyzer_, loop_layout);
+    }
+    // If none thread bindings are provided, partition the loop
+    bool has_non_local = false;
+    PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
+      if (const auto *load = obj.as<BufferLoadNode>()) {
+        String scope = load->buffer.scope();
+        if (scope != "local" && scope != "local.fragment") {
+          has_non_local = true;
+        }
+      } else if (const auto *store = obj.as<BufferStoreNode>()) {
+        String scope = store->buffer.scope();
+        if (scope != "local" && scope != "local.fragment") {
+          has_non_local = true;
+        }
+      }
+    });
+    // Workaround: if reducer is presented, don't vectorize loop
+    // Best solution should be isolate reduction axis out of vectorization
+    bool has_reducer = false;
+    PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
+      if (!has_reducer)
+        if (const auto *store = obj.as<BufferStoreNode>()) {
+          has_reducer = reducer_info.count(store->buffer->data) != 0;
+        }
+    });
+
+    // If a cast operation exists, vectorization may still be required
+    bool has_cast_operations = false;
+    PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
+      if (const auto *cast = obj.as<CastNode>()) {
+        // Check if this is a non-reducer store with Cast operation
+        DataType src_type = cast->value.dtype();
+        DataType dst_type = cast->dtype;
+        bool src_ok = src_type.is_float() || src_type.is_bfloat() ||
+                      src_type.is_float8_e4m3() || src_type.is_float8_e5m2();
+        bool dst_ok = dst_type.is_float() || dst_type.is_bfloat() ||
+                      dst_type.is_float8_e4m3() || dst_type.is_float8_e5m2();
+        if (src_ok && dst_ok && TargetIsCuda(Target::Current())) {
+          has_cast_operations = true;
+        }
+      }
+    });
+
+    if ((has_non_local || has_cast_operations) && !has_reducer) {
+      for_node = VectorizeLoop(for_node, saved_analyzer.get());
+    }
+
+    if (result_.predicate_map.count(root) && parallel_loop) {
+      return IfThenElse(result_.predicate_map[root], for_node);
+    } else {
+      return for_node;
+    }
   }
 
   Stmt VisitStmt_(const AttrStmtNode *op) final {
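
For context, the local-buffer case that the store_into_local / local_register_only checks above are guarding looks roughly like the sketch below, which just expands the "for i in T.Parallel(1024): A_local[i] = A_global[i]" example from the comments into a full kernel. It is illustrative only, not part of the patch: the kernel name, sizes, and thread count are invented, and it assumes the usual T.Kernel / T.alloc_local / T.Parallel constructs from tilelang.language.

    # Hypothetical TileLang kernel (illustrative names and sizes).
    import tilelang.language as T

    @T.prim_func
    def parallel_local_copy(A: T.Tensor((1024,), "float32"),
                            B: T.Tensor((1024,), "float32")):
        with T.Kernel(1, threads=128) as bx:
            # "local" scope: a register buffer held independently by each thread,
            # so this T.Parallel loop is left without explicit thread binding
            # (store_into_local becomes true and PartitionLoop is skipped).
            A_local = T.alloc_local((16,), "float32")
            for i in T.Parallel(16):
                A_local[i] = A[i]
            # A loop that touches non-local buffers is still thread-partitioned
            # and, absent a reducer, vectorized.
            for i in T.Parallel(1024):
                B[i] = A[i]

Under the patch, only the second loop goes through PartitionLoop/VectorizeLoop, and VectorizeLoop now runs against the analyzer state cloned before the visit, which is what the "prevent conflicted bindings" comment in the added lines refers to.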