/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include "../utils.h"

namespace tvm {
namespace meta_schedule {

class CrossThreadReductionNode : public ScheduleRuleNode {
 public:
  // Inherited from ScheduleRuleNode
  void InitializeWithTuneContext(const TuneContext& context) final {
    ICHECK(context->target.defined());
    Target target = context->target.value();

    Optional<Integer> opt_max_threads_per_block = target->GetAttr<Integer>("max_threads_per_block");
    Optional<Integer> opt_warp_size = target->GetAttr<Integer>("thread_warp_size");

    if (!opt_max_threads_per_block.defined()) {
      LOG(WARNING) << "Target does not have attribute \"max_threads_per_block\", therefore the "
                      "rule CrossThreadReduction will not be applied";
    }
    if (!opt_warp_size.defined()) {
      LOG(WARNING) << "Target does not have attribute \"thread_warp_size\", therefore the rule "
                      "CrossThreadReduction will not be applied";
    }
    max_threads_per_block = opt_max_threads_per_block.value_or(Integer(-1))->value;
    warp_size = opt_warp_size.value_or(Integer(-1))->value;
  }
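
  // Illustrative sketch (added for exposition; not part of the original rule logic): both
  // attributes read above are ordinary target attributes, so a GPU-like target such as one
  // created from the "cuda" target kind is expected to define them, while a typical CPU target
  // defines neither and the rule then degrades to a no-op returning the original schedule.
  // Querying them directly would look roughly like:
  //
  //   Target target("cuda");  // example target kind, used here only for illustration
  //   Optional<Integer> threads = target->GetAttr<Integer>("max_threads_per_block");
  //   Optional<Integer> warp = target->GetAttr<Integer>("thread_warp_size");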

  // Inherited from ScheduleRuleNode
  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
    // Step 0. Check the conditions of this rule.
    if (max_threads_per_block == -1 || warp_size == -1) {
      return {sch};
    }
    const tir::StmtSRef& block_sref = sch->GetSRef(block_rv);
    if (!NeedsRFactorOrCrossThreadReduction(sch->state(), block_sref, max_threads_per_block,
                                            warp_size)) {
      return {sch};
    }

    // Step 1. Make a copy of the original schedule. The new copy is used for scheduling.
    tir::Schedule tmp_sch = sch->Copy();
    tmp_sch->Seed(sch->ForkSeed());

    // Step 2. Check the opportunity for block fusion. We say the block is "fusible" if it can be
    // computed at its consumers. We want to fuse as much as possible because it results in
    // significantly faster schedules.
    bool fusible = false;
    // `target_loop` is the loop at which the input block will be computed.
    tir::LoopRV target_loop{nullptr};
    // `target_block` is the consumer block that the input block will be computed at.
    tir::BlockRV target_block{nullptr};
    // `tgt_block_innermost_loop` is the innermost loop outside the target block.
    tir::LoopRV tgt_block_innermost_loop{nullptr};

    std::tie(fusible, target_loop, target_block, tgt_block_innermost_loop) =
        GetComputeTargetLoopAndBlock(tmp_sch, block_rv);

    // Step 3. Try block fusion.
    int n_candidate = static_cast<int>(thread_extents.size());
    Array<FloatImm> probs(n_candidate, FloatImm(DataType::Float(64), 1.0 / n_candidate));
    tir::ExprRV thread_extent = tmp_sch->SampleCategorical(thread_extents, probs);
    if (fusible) {
      ICHECK(target_block.defined());
      ICHECK(target_loop.defined());

      // Step 3.1.
      // - If the outer loops of `target_block` haven't been bound to "threadIdx.x", we should
      //   first bind the innermost outer loop of `target_block` to threadIdx. We may need to
      //   split the loop before binding.
      // - Otherwise, we search for the extent of "threadIdx.x" and use it as the split factor.
      if (!InThreadScope(tmp_sch, target_block)) {
        const Array<tir::LoopRV>& split_res =
            tmp_sch->Split(tgt_block_innermost_loop, {NullOpt, thread_extent});
        tmp_sch->Bind(split_res[1], "threadIdx.x");
        if (tgt_block_innermost_loop.same_as(target_loop)) {
          target_loop = split_res[0];
        }
      } else {
        thread_extent = GetThreadIdxExtentFromTrace(tmp_sch->trace().value());
      }
      // Step 3.2. Do the compute-at.
      tmp_sch->ComputeAt(block_rv, target_loop, /*preserve_unit_loops=*/true);
      // Step 3.3. Set the storage scope of the output buffer to shared memory.
      tmp_sch->SetScope(block_rv, /*buffer_index=*/0, /*storage_scope=*/"shared");
    }

    // Step 4. Reorder the loop axes if reduction loops are not innermost. After the reordering,
    // fuse all the reduction loops.
    size_t num_spatial_loops;
    tir::LoopRV fused_reduce_loop;
    ReorderAndFuseReductionLoops(tmp_sch, block_rv, &fused_reduce_loop, &num_spatial_loops);
    // Step 5. Split the fused reduction loop and bind the inner one to threadIdx.
    const Array<tir::LoopRV>& split_res =
        tmp_sch->Split(fused_reduce_loop, {NullOpt, thread_extent});
    tmp_sch->Bind(split_res[1], "threadIdx.x");

    return {tmp_sch, sch};
  }
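
  // Illustrative effect of the rule (a sketch under the assumption of a single, standalone
  // reduction loop; the actual structure depends on the workload, on block fusion, and on the
  // sampled thread extent `tx`):
  //
  //   before:                           after:
  //     for (i, 0, n)                     for (i, 0, n)
  //       for (k, 0, m)                     for (ko, 0, ceildiv(m, tx))
  //         C[i] = C[i] + A[i, k]             for (ki, 0, tx)            // bound to threadIdx.x
  //                                             C[i] = C[i] + A[i, ko * tx + ki]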

 private:
  /*!
   * \brief Check whether the input block is in the thread scope, i.e., at least one of its outer
   * loops is bound to threadIdx.
   * \param sch The TensorIR schedule
   * \param block The block to be checked
   * \return A boolean indicating whether the block is in the thread scope.
   */
  bool InThreadScope(const tir::Schedule& sch, const tir::BlockRV& block) {
    const Array<tir::LoopRV>& axes = sch->GetLoops(block);
    for (const tir::LoopRV& loop_rv : axes) {
      const tir::For& loop = sch->Get(loop_rv);
      runtime::ThreadScope thread_scope = tir::GetThreadScope(loop.get());
      if (tir::IsThreadIdx(thread_scope)) {
        return true;
      }
    }
    return false;
  }

  /*!
   * \brief Get the ExprRV that was used to define the extent of a given loop.
   * \param trace The trace of the schedule, where the extent is to be found
   * \param loop The loop whose extent is to be found
   * \param extent The output parameter that receives the result
   * \return Whether the lookup is successful.
   */
  bool GetLoopRVExtentSource(const tir::Trace& trace, const tir::LoopRV& loop,
                             tir::ExprRV* extent) {
    for (const tir::Instruction& inst : trace->insts) {
      if (inst->kind->name == "Split") {
        // Only the `Split` instruction that actually produced `loop` can tell us its extent;
        // skip any other `Split` in the trace.
        int i = std::find(inst->outputs.begin(), inst->outputs.end(), loop) - inst->outputs.begin();
        if (i == static_cast<int>(inst->outputs.size())) {
          continue;
        }
        CHECK(inst->inputs[1 + i].defined())
            << "ValueError: Extracting an extent which needs inference is not supported so far";
        *extent = Downcast<tir::ExprRV>(inst->inputs[1 + i]);
        return true;
      }
    }
    return false;
  }
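
  // Illustrative trace entry (a sketch; the variable names are hypothetical): a split such as
  //   l1, l2 = sch.split(loop=l0, factors=[None, v0])
  // is recorded with inputs = [l0, None, v0] and outputs = [l1, l2], so the factor defining
  // outputs[i] is inputs[1 + i], which is exactly the indexing relied upon above.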

  /*!
   * \brief Get the ExprRV extent of "threadIdx.x" in the given schedule trace.
   * \param trace The trace of the schedule, where the extent is to be found
   * \return The extent of "threadIdx.x" in the input schedule
   */
  tir::ExprRV GetThreadIdxExtentFromTrace(const tir::Trace& trace) {
    tir::ExprRV extent{nullptr};
    for (const tir::Instruction& inst : trace->insts) {
      if (inst->kind->name == "Bind" && Downcast<String>(inst->attrs[0]) == "threadIdx.x") {
        if (GetLoopRVExtentSource(trace, Downcast<tir::LoopRV>(inst->inputs[0]), &extent)) {
          return extent;
        }
      }
    }
    CHECK(false) << "ValueError: Unable to get the extent of \"threadIdx.x\"";
    throw;
  }

  /*!
   * \brief Get the compute-at target loop and the first block under the target loop.
   * \param sch The TensorIR schedule
   * \param block_rv The block whose compute-at target loop is queried
   * \return A tuple consisting of
   * 1. a boolean indicating whether the block can be computed at some target loop (a.k.a. fusible);
   * 2. the compute-at target loop when fusible, or a null loop random variable;
   * 3. the first block under the target loop when fusible, or a null block random variable;
   * 4. the innermost loop outside the target block when fusible, or a null loop random variable.
   */
  std::tuple<bool, tir::LoopRV, tir::BlockRV, tir::LoopRV> GetComputeTargetLoopAndBlock(
      const tir::Schedule& sch, const tir::BlockRV& block_rv) {
    // Step 1. Get all the consumers of the input block.
    Array<tir::BlockRV> consumers = sch->GetConsumers(block_rv);

    // Step 2. If the block has no consumer or the first consumer needs multi-level tiling, it is
    // not fusible.
    if (consumers.empty() || tir::NeedsMultiLevelTiling(sch->state(), sch->GetSRef(consumers[0]))) {
      return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr},
                             tir::LoopRV{nullptr});
    }

    // Step 3. Calculate the lowest common ancestor of all the consumers.
    // - If the lowest common ancestor is a block:
    //   - if there is only one consumer, the target block is that consumer;
    //   - if there are multiple consumers, they must not share a common loop, and the case is not
    //     fusible;
    // - If the lowest common ancestor is a loop, the target block is also the first consumer.
    const tir::StmtSRef& lca_sref =
        tir::GetSRefLowestCommonAncestor(tir::BlockRVs2StmtSRefs(sch, consumers));
    if (consumers.size() > 1 && lca_sref->StmtAs<tir::BlockNode>() != nullptr) {
      return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr},
                             tir::LoopRV{nullptr});
    }

    // Step 4. Get the outer loops of the target block, and get the compute-at position index.
    Array<tir::LoopRV> tgt_block_loops = sch->GetLoops(consumers[0]);
    int pos = GetComputePosition(sch, sch->GetLoops(block_rv), tgt_block_loops, lca_sref);

    // Step 5. A negative position index means not fusible, and vice-versa.
    if (pos < 0) {
      return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr},
                             tir::LoopRV{nullptr});
    } else {
      return std::make_tuple(true, tgt_block_loops[pos], consumers[0], tgt_block_loops.back());
    }
  }

  /*!
   * \brief Get the compute-at position index of the input block, according to
   * 1. the loops outside the input block;
   * 2. the loops outside the target block;
   * 3. the lowest common ancestor of all the consumers of the input block.
   * \param sch The TensorIR schedule
   * \param block_loops The loops outside the input block
   * \param tgt_block_loops The loops outside the target block
   * \param lca_sref The lowest common ancestor of all the consumers of the input block
   * \return The compute-at position index of the input block
   */
  int GetComputePosition(const tir::Schedule& sch, const Array<tir::LoopRV>& block_loops,
                         const Array<tir::LoopRV>& tgt_block_loops, const tir::StmtSRef& lca_sref) {
    int n_block_loop = static_cast<int>(block_loops.size());
    int n_tgt_block_loop = static_cast<int>(tgt_block_loops.size());

    for (int i = 0; i < n_block_loop && i < n_tgt_block_loop; ++i) {
      if (tir::GetLoopIterType(sch->GetSRef(block_loops[i])) != tir::IterVarType::kDataPar) {
        return i - 1;
      } else if (sch->GetSRef(tgt_block_loops[i]).same_as(lca_sref)) {
        // If the lowest common ancestor is a loop, the compute location of the input block should
        // not be deeper than the LCA loop.
        return i;
      }
    }
    return std::min(n_block_loop, n_tgt_block_loop) - 1;
  }

 public:
  /*! \brief The maximum number of threads allowed in a thread block */
  int max_threads_per_block;
  /*! \brief The number of threads per warp */
  int warp_size;
  /*! \brief Candidates of thread axis extent (values are required to be positive). */
  Array<Integer> thread_extents;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("max_threads_per_block", &max_threads_per_block);
    v->Visit("warp_size", &warp_size);
    v->Visit("thread_extents", &thread_extents);
  }

  static constexpr const char* _type_key = "meta_schedule.CrossThreadReduction";
  TVM_DECLARE_FINAL_OBJECT_INFO(CrossThreadReductionNode, ScheduleRuleNode);
};

ScheduleRule ScheduleRule::CrossThreadReduction(Array<Integer> thread_extents) {
  for (const Integer& extent : thread_extents) {
    CHECK(extent->value > 0) << "ValueError: The candidates of thread extent must be positive";
  }
  ObjectPtr<CrossThreadReductionNode> n = make_object<CrossThreadReductionNode>();
  n->thread_extents = std::move(thread_extents);
  return ScheduleRule(n);
}

TVM_REGISTER_NODE_TYPE(CrossThreadReductionNode);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleCrossThreadReduction")
    .set_body_typed(ScheduleRule::CrossThreadReduction);
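
// Illustrative construction (a sketch, not a set of registered defaults): the rule is created
// with the candidate thread extents that `Apply` samples from, e.g.
//
//   ScheduleRule rule = ScheduleRule::CrossThreadReduction(
//       /*thread_extents=*/{Integer(4), Integer(8), Integer(16), Integer(32)});
//
// Every candidate must be positive; the CHECK in the constructor above rejects anything else.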

}  // namespace meta_schedule
}  // namespace tvm