|
| 1 | +/* |
| 2 | + * Licensed to the Apache Software Foundation (ASF) under one |
| 3 | + * or more contributor license agreements. See the NOTICE file |
| 4 | + * distributed with this work for additional information |
| 5 | + * regarding copyright ownership. The ASF licenses this file |
| 6 | + * to you under the Apache License, Version 2.0 (the |
| 7 | + * "License"); you may not use this file except in compliance |
| 8 | + * with the License. You may obtain a copy of the License at |
| 9 | + * |
| 10 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 11 | + * |
| 12 | + * Unless required by applicable law or agreed to in writing, |
| 13 | + * software distributed under the License is distributed on an |
| 14 | + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 15 | + * KIND, either express or implied. See the License for the |
| 16 | + * specific language governing permissions and limitations |
| 17 | + * under the License. |
| 18 | + */ |
| 19 | +#include "../utils.h" |
| 20 | + |
| 21 | +namespace tvm { |
| 22 | +namespace meta_schedule { |
| 23 | + |
| 24 | +using tir::Instruction; |
| 25 | +using tir::InstructionKind; |
| 26 | +using tir::Trace; |
| 27 | + |
| 28 | +/*! \brief A mutator that mutates the compute-at location decision of SampleComputeLocation */ |
| 29 | +class MutateComputeLocationNode : public MutatorNode { |
| 30 | + public: |
| 31 | + /*! \brief JSON representation of the workload */ |
| 32 | + std::string json_mod_; |
| 33 | + |
| 34 | + void VisitAttrs(tvm::AttrVisitor* v) {} |
| 35 | + static constexpr const char* _type_key = "meta_schedule.MutateComputeLocation"; |
| 36 | + TVM_DECLARE_FINAL_OBJECT_INFO(MutateComputeLocationNode, MutatorNode); |
| 37 | + |
| 38 | + public: |
| 39 | + // Inherit from `MutatorNode` |
| 40 | + void InitializeWithTuneContext(const TuneContext& context) final { |
| 41 | + this->json_mod_ = SaveJSON(context->mod.value()); |
| 42 | + } |
| 43 | + // Inherit from `MutatorNode` |
| 44 | + Optional<Trace> Apply(const Trace& trace, TRandState* rand_state) final; |
| 45 | + |
| 46 | + private: |
| 47 | + struct Candidate { |
| 48 | + /*! \brief The SampleComputeLocation instruction */ |
| 49 | + Instruction inst; |
| 50 | + /*! \brief The candidate compute-at locations */ |
| 51 | + std::vector<int> locs; |
| 52 | + |
| 53 | + explicit Candidate(Instruction inst, std::vector<int> locs) |
| 54 | + : inst(std::move(inst)), locs(std::move(locs)) {} |
| 55 | + }; |
| 56 | + |
| 57 | + std::vector<Candidate> FindCandidates(const Trace& trace, TRandState* rand_state); |
| 58 | +}; |
| 59 | + |
| 60 | +/*! |
| 61 | + * \brief Find all appearances of instruction `SampleComputeLocation` whose decision can be mutated |
| 62 | + * to at lease one other value |
| 63 | + * \param trace The trace from which to find the instructions |
| 64 | + * \return All the candidate instructions together with the candidate compute-at locations |
| 65 | + */ |
| 66 | +std::vector<MutateComputeLocationNode::Candidate> MutateComputeLocationNode::FindCandidates( |
| 67 | + const Trace& trace, TRandState* rand_state) { |
| 68 | + tir::Schedule sch = tir::Schedule::Traced( // |
| 69 | + /*mod=*/Downcast<IRModule>(LoadJSON(this->json_mod_)), // |
| 70 | + /*rand_state=*/ForkSeed(rand_state), // |
| 71 | + /*debug_mode=*/0, // |
| 72 | + /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone); |
| 73 | + |
| 74 | + static InstructionKind inst_sample_compute_location = |
| 75 | + InstructionKind::Get("SampleComputeLocation"); |
| 76 | + std::vector<MutateComputeLocationNode::Candidate> candidates; |
| 77 | + |
| 78 | + auto f_decision_provider = [&](const tir::Instruction& inst, // |
| 79 | + const Array<ObjectRef>& inputs, // |
| 80 | + const Array<ObjectRef>& attrs, // |
| 81 | + const ObjectRef& decision) -> ObjectRef { |
| 82 | + if (inst->kind.same_as(inst_sample_compute_location)) { |
| 83 | + // Step 1. Extract the instruction input and the old decision. |
| 84 | + ICHECK_EQ(inputs.size(), 1); |
| 85 | + tir::StmtSRef block_sref = sch->GetSRef(Downcast<tir::BlockRV>(inputs[0])); |
| 86 | + int old_decision = Downcast<Integer>(decision)->value; |
| 87 | + |
| 88 | + // Step 2. Collect all the compute_at locations. |
| 89 | + Array<tir::StmtSRef> location_srefs; |
| 90 | + std::vector<int> location_indices; |
| 91 | + std::tie(location_srefs, location_indices) = CollectComputeLocation(sch->state(), block_sref); |
| 92 | + // Step 3. Remove the old decision. |
| 93 | + auto it = std::find(location_indices.begin(), location_indices.end(), old_decision); |
| 94 | + if (it != location_indices.end()) { |
| 95 | + location_srefs.erase(location_srefs.begin() + (it - location_indices.begin())); |
| 96 | + location_indices.erase(it); |
| 97 | + } |
| 98 | + ICHECK_EQ(location_srefs.size(), location_indices.size()); |
| 99 | + // Step 4. Add a new candidate if there are at least one remaining compute-at position. |
| 100 | + if (!location_srefs.empty()) { |
| 101 | + candidates.emplace_back(inst, std::move(location_indices)); |
| 102 | + } |
| 103 | + } |
| 104 | + return decision; |
| 105 | + }; |
| 106 | + trace->ApplyToSchedule(sch, // |
| 107 | + /*remove_postproc=*/true, // |
| 108 | + /*decision_provider=*/f_decision_provider); |
| 109 | + return candidates; |
| 110 | +} |
| 111 | + |
| 112 | +Optional<Trace> MutateComputeLocationNode::Apply(const Trace& trace, TRandState* rand_state) { |
| 113 | + std::vector<Candidate> candidates = FindCandidates(trace, rand_state); |
| 114 | + if (candidates.empty()) { |
| 115 | + return NullOpt; |
| 116 | + } |
| 117 | + const Candidate& candidate = candidates[tir::SampleInt(rand_state, 0, candidates.size())]; |
| 118 | + int loc = candidate.locs[tir::SampleInt(rand_state, 0, candidate.locs.size())]; |
| 119 | + return trace->WithDecision(candidate.inst, Integer(loc), /*remove_postproc=*/true); |
| 120 | +} |
| 121 | + |
| 122 | +Mutator Mutator::MutateComputeLocation() { |
| 123 | + return Mutator(make_object<MutateComputeLocationNode>()); |
| 124 | +} |
| 125 | + |
| 126 | +TVM_REGISTER_NODE_TYPE(MutateComputeLocationNode); |
| 127 | +TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateComputeLocation") |
| 128 | + .set_body_typed(Mutator::MutateComputeLocation); |
| 129 | + |
| 130 | +} // namespace meta_schedule |
| 131 | +} // namespace tvm |
0 commit comments