
Commit 725badb

jinhongyiiylc authored and committed
[MetaSchedule][M4a] Mutator: Mutate Parallel (apache#10096)
1 parent a41af15 commit 725badb

File tree

8 files changed (+518, -0 lines changed)


include/tvm/tir/schedule/instruction.h

Lines changed: 3 additions & 0 deletions
@@ -121,6 +121,9 @@ class InstructionKindNode : public runtime::Object {
     // not visited: f_attrs_from_json
   }
 
+  /*! \brief Checks if the instruction kind is EnterPostproc */
+  bool IsPostproc() const;
+
   static constexpr const char* _type_key = "tir.InstructionKind";
   TVM_DECLARE_FINAL_OBJECT_INFO(InstructionKindNode, runtime::Object);
 };
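The new `IsPostproc()` query is what lets the mutator below cut a trace off at the point where postprocessing instructions begin. Purely as an illustration (plain Python over instruction-kind names, not the actual tvm.tir.schedule API), the truncation step later used in MutateParallelNode::Apply amounts to:

# Illustration only: instruction kinds as plain strings, not real Instruction objects.
def keep_until_postproc(inst_kinds):
    """Copy instructions until the first EnterPostproc, mirroring the
    `inst->kind->IsPostproc()` check in MutateParallelNode::Apply."""
    kept = []
    for kind in inst_kinds:
        if kind == "EnterPostproc":
            break
        kept.append(kind)
    return kept

print(keep_until_postproc(["GetBlock", "Annotate", "EnterPostproc", "Parallel"]))
# -> ['GetBlock', 'Annotate']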

python/tvm/meta_schedule/mutator/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -21,4 +21,5 @@
 """
 from .mutator import Mutator, PyMutator
 from .mutate_compute_location import MutateComputeLocation
+from .mutate_parallel import MutateParallel
 from .mutate_unroll import MutateUnroll
python/tvm/meta_schedule/mutator/mutate_parallel.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Mutator that mutates the parallel extent"""
+from tvm._ffi.registry import register_object
+
+from .. import _ffi_api
+from .mutator import Mutator
+
+
+@register_object("meta_schedule.MutateParallel")
+class MutateParallel(Mutator):
+    """Mutator that mutates the parallel extent"""
+
+    def __init__(self, max_jobs_per_core: int) -> None:
+        """Mutator that mutates the parallel extent"""
+        self.__init_handle_by_constructor__(
+            _ffi_api.MutatorMutateParallel,  # type: ignore # pylint: disable=no-member
+            max_jobs_per_core,
+        )
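For reference, constructing the new mutator from Python is a one-liner; the value below is an arbitrary example, and -1 (per the C++ documentation in this commit) disables parallelism:

from tvm.meta_schedule.mutator import MutateParallel

# Cap CPU parallelism at num_cores * 16 when this mutator rewrites the
# `meta_schedule_parallel` annotation; max_jobs_per_core=-1 disables parallelism.
mutator = MutateParallel(max_jobs_per_core=16)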
src/meta_schedule/mutator/mutate_parallel.cc

Lines changed: 312 additions & 0 deletions
@@ -0,0 +1,312 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <algorithm>
+#include <unordered_map>
+
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Check if the instruction is annotation with `meta_schedule_parallel`
+ * \param inst The instruction to be checked
+ * \return Whether the instruction is annotation with `meta_schedule_parallel`
+ */
+bool IsAnnotateWithParallel(const Instruction& inst) {
+  static const InstructionKind& inst_annotate = InstructionKind::Get("Annotate");
+  if (!inst->kind.same_as(inst_annotate)) {
+    return false;
+  }
+  ICHECK_EQ(inst->attrs.size(), 1);
+  String ann_key = Downcast<String>(inst->attrs[0]);
+  return ann_key == attr::meta_schedule_parallel;
+}
+
+/*!
+ * \brief Replace the annotation value
+ * \param inst The instruction to be replaced
+ * \param ann_val The new annotation value
+ * \return The replaced instruction
+ */
+Instruction ReplaceAnnValue(Instruction inst, int64_t ann_val) {
+  ICHECK_EQ(inst->inputs.size(), 2);
+  return Instruction(/*kind=*/inst->kind,                             //
+                     /*inputs=*/{inst->inputs[0], Integer(ann_val)},  //
+                     /*attrs=*/inst->attrs,
+                     /*outputs=*/inst->outputs);
+}
+
+/*!
+ * \brief Get the output of the instruction Get-Block
+ * \param inst The instruction to be checked
+ * \return The output of the instruction Get-Block
+ */
+const BlockRVNode* GetInstGetBlockOutput(const Instruction& inst) {
+  static const InstructionKind& inst_get_block = InstructionKind::Get("GetBlock");
+  if (!inst->kind.same_as(inst_get_block)) {
+    return nullptr;
+  }
+  ICHECK_EQ(inst->outputs.size(), 1);
+  const BlockRVNode* block = TVM_TYPE_AS(block, inst->outputs[0], BlockRVNode);
+  return block;
+}
+
+/*!
+ * \brief Analyze the parallel structure
+ * \param self The schedule state
+ * \param block_name The name of the root block
+ * \param func_name The name of the PrimFunc
+ * \param limit The uplimit of the parallelism
+ * \return The parallel structure
+ */
+std::vector<std::vector<int64_t>> AnalyzeParallel(const ScheduleState& self,
+                                                  const String& block_name, const String& func_name,
+                                                  int64_t limit) {
+  Array<StmtSRef> block_srefs = tir::GetBlocks(self, block_name, func_name);
+  ICHECK_EQ(block_srefs.size(), 1);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_srefs[0]);
+  ScopeBlockLoopInfo info = GetScopeBlockLoopInfo(GetRef<Block>(block));
+  std::vector<std::vector<int64_t>> results;
+  results.reserve(info.realizes.size());
+  for (const BlockRealize& realize : info.realizes) {
+    // Step 1. Extract static loop extents for spatial loops
+    std::vector<int64_t> loop_extents;
+    const ForNode* loop = nullptr;
+    for (const StmtSRefNode* loop_sref = self->stmt2ref.at(realize->block.get())->parent;
+         (loop = loop_sref->StmtAs<ForNode>()) != nullptr;  //
+         loop_sref = loop_sref->parent) {
+      int64_t loop_extent = -1;
+      if (const auto* ext = GetLoopIntExtent(loop)) {
+        if (!info.non_spatial_vars.count(loop->loop_var.get())) {
+          loop_extent = *ext;
+        }
+      }
+      if (loop_extent != -1) {
+        loop_extents.push_back(loop_extent);
+      } else {
+        loop_extents.clear();
+      }
+    }
+    // Step 2. Take the prefix product of loop extents
+    if (!loop_extents.empty()) {
+      results.emplace_back();
+      std::vector<int64_t>& result = results.back();
+      result.reserve(loop_extents.size());
+      int64_t prod_extent = 1;
+      for (auto it = loop_extents.rbegin(); it != loop_extents.rend(); ++it) {
+        result.push_back(prod_extent *= *it);
+        if (prod_extent >= limit) {
+          break;
+        }
+      }
+    }
+  }
+  return results;
+}
+
+/*!
+ * \brief Get the number of parallelizable loops for each subtree
+ * \param loop_extent_prods The parallel structure for each subtree
+ * \param limit The uplimit of the parallelism
+ * \return The number of parallelizable loops for each subtree
+ */
+std::vector<int> GetNumFusedLoops(const std::vector<std::vector<int64_t>>& loop_extent_prods,
+                                  int64_t limit) {
+  std::vector<int> results;
+  results.reserve(loop_extent_prods.size());
+  for (const std::vector<int64_t>& prods : loop_extent_prods) {
+    int n = prods.size();
+    int i = std::upper_bound(prods.begin(), prods.end(), limit) - prods.begin();
+    if (i > 0 && prods[i - 1] == limit) {
+      --i;
+    }
+    if (i != n) {
+      ++i;
+    }
+    results.push_back(i);
+  }
+  return results;
+}
+
+}  // namespace tir
+}  // namespace tvm
+
+namespace tvm {
+namespace meta_schedule {
+
+using tir::Instruction;
+using tir::Trace;
+
+/*! \brief Create a Mutator that mutates the parallel extent */
+class MutateParallelNode : public MutatorNode {
+ public:
+  /*!
+   * \brief The maximum number of jobs to be launched per CPU core.
+   * It sets the uplimit of CPU parallelism, i.e. `num_cores * max_jobs_per_core`.
+   * Use -1 to disable parallelism.
+   */
+  int64_t max_jobs_per_core;
+  /*! \brief The number of cores in CPU. */
+  int max_parallel_extent_;
+  /*! \brief JSON representation of the workload */
+  std::string json_mod_;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("max_jobs_per_core", &max_jobs_per_core);
+    // `max_parallel_extent_` is not visited.
+    // `json_mod` is not visited.
+  }
+
+  static constexpr const char* _type_key = "meta_schedule.MutateParallel";
+  TVM_DECLARE_FINAL_OBJECT_INFO(MutateParallelNode, MutatorNode);
+
+ public:
+  struct Candidate;
+  // Inherit from `MutatorNode`
+  void InitializeWithTuneContext(const TuneContext& context) final {
+    Target target = context->target.value();
+    this->max_parallel_extent_ = GetTargetNumCores(target) * this->max_jobs_per_core;
+    this->json_mod_ = SaveJSON(context->mod.value());
+  }
+  // Inherit from `MutatorNode`
+  Optional<Trace> Apply(const Trace& trace, TRandState* rand_state) final;
+};
+
+/*! \brief The candidate to be mutated */
+struct MutateParallelNode::Candidate {
+  /*! \brief The annotation instruction */
+  Instruction inst;
+  /*! \brief The current parallel extent */
+  int64_t parallel_extent;
+  /*! \brief The name of the root block */
+  String block_name;
+  /*! \brief The name of the PrimFunc */
+  String func_name;
+};
+
+/*!
+ * \brief Get an instruction that annotates the maximum parallel extent
+ * \param trace The trace to be mutated
+ * \param rand_state The random state
+ * \param candidate The candidate to be mutated
+ * \return Whether a decision is found
+ */
+bool FindParallelDecision(const Trace& trace, TRandState* rand_state,
+                          MutateParallelNode::Candidate* candidate) {
+  using tir::BlockRVNode;
+  using tir::InstructionNode;
+  std::unordered_map<const BlockRVNode*, const InstructionNode*> get_block_insts;
+  std::vector<const InstructionNode*> ann_insts;
+  get_block_insts.reserve(trace->insts.size());
+  ann_insts.reserve(trace->insts.size());
+  for (const Instruction& inst : trace->insts) {
+    if (tir::IsAnnotateWithParallel(inst)) {
+      ann_insts.push_back(inst.get());
+    }
+    if (const BlockRVNode* block_rv = tir::GetInstGetBlockOutput(inst)) {
+      get_block_insts[block_rv] = inst.get();
+    }
+  }
+  int n_ann_insts = ann_insts.size();
+  if (n_ann_insts == 0) {
+    return false;
+  }
+  const InstructionNode* ann_inst = ann_insts[tir::SampleInt(rand_state, 0, n_ann_insts)];
+  ICHECK_EQ(ann_inst->inputs.size(), 2);
+  const InstructionNode* get_block_inst =
+      get_block_insts.at(Downcast<tir::BlockRV>(ann_inst->inputs[0]).get());
+  ICHECK_EQ(get_block_inst->attrs.size(), 2);
+  candidate->inst = GetRef<Instruction>(ann_inst);
+  candidate->parallel_extent = Downcast<IntImm>(ann_inst->inputs[1])->value;
+  candidate->block_name = Downcast<String>(get_block_inst->attrs[0]);
+  candidate->func_name = Downcast<String>(get_block_inst->attrs[1]);
+  return true;
+}
+
+Optional<Trace> MutateParallelNode::Apply(const Trace& trace, TRandState* rand_state) {
+  // Step 1. Find a parallel decision.
+  Candidate candidate;
+  if (!FindParallelDecision(trace, rand_state, &candidate)) {
+    return NullOpt;
+  }
+  // Step 2. Replay the instructions to recover loop extents
+  tir::Schedule sch = tir::Schedule::Traced(                  //
+      /*mod=*/Downcast<IRModule>(LoadJSON(this->json_mod_)),  //
+      /*rand_state=*/ForkSeed(rand_state),                    //
+      /*debug_mode=*/0,
+      /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
+  trace->ApplyToSchedule(sch, /*remove_postproc=*/true);
+  // Step 3. Find all possible parallel plans
+  std::vector<std::vector<int64_t>> loop_extent_prods = tir::AnalyzeParallel(
+      sch->state(), candidate.block_name, candidate.func_name, this->max_parallel_extent_);
+  std::unordered_map<int64_t, std::vector<int>> limit2plan;
+  std::map<std::vector<int>, int64_t> plan2limit;
+  for (const std::vector<int64_t>& prods : loop_extent_prods) {
+    for (int64_t limit : prods) {
+      if (limit <= this->max_parallel_extent_ && !limit2plan.count(limit)) {
+        std::vector<int> plan = tir::GetNumFusedLoops(loop_extent_prods, limit);
+        limit2plan[limit] = plan;
+        plan2limit[plan] = limit;
+      }
+    }
+  }
+  // Step 4. Remove the original plan and remove it
+  std::vector<int> original_plan =
+      tir::GetNumFusedLoops(loop_extent_prods, candidate.parallel_extent);
+  auto it = plan2limit.find(original_plan);
+  if (it != plan2limit.end()) {
+    plan2limit.erase(it);
+  }
+  // Step 5. Pick a new plan
+  int n_plans = plan2limit.size();
+  if (n_plans == 0) {
+    return NullOpt;
+  }
+  it = plan2limit.begin();
+  for (int i = 0, n = tir::SampleInt(rand_state, 0, n_plans); i < n; ++i) {
+    ++it;
+  }
+  int64_t limit = it->second;
+  // Step 6. Assemble a new trace
+  Array<Instruction> insts;
+  insts.reserve(trace->insts.size());
+  for (const Instruction& inst : trace->insts) {
+    if (inst.same_as(candidate.inst)) {
+      insts.push_back(tir::ReplaceAnnValue(candidate.inst, limit));
+    } else if (inst->kind->IsPostproc()) {
+      break;
+    } else {
+      insts.push_back(inst);
+    }
+  }
+  return Trace(insts, trace->decisions);
+}
+
+Mutator Mutator::MutateParallel(int64_t max_jobs_per_core) {
+  ObjectPtr<MutateParallelNode> n = make_object<MutateParallelNode>();
+  n->max_jobs_per_core = max_jobs_per_core;
+  return Mutator(n);
+}
+
+TVM_REGISTER_NODE_TYPE(MutateParallelNode);
+TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateParallel").set_body_typed(Mutator::MutateParallel);
+
+}  // namespace meta_schedule
+}  // namespace tvm
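To make the numeric core of the mutator concrete, here is a small, self-contained Python sketch (not TVM code; the loop extents and the parallelism cap are made-up numbers) of what `AnalyzeParallel` (Step 2) and `GetNumFusedLoops` compute: cumulative products of the outer spatial loop extents, and the number of outer loops to fuse so the fused extent hits a sampled limit exactly, overshooting by one loop when no prefix matches.

from bisect import bisect_right

def prefix_products(extents_outer_to_inner, limit):
    """Mirror of AnalyzeParallel Step 2: cumulative products of the spatial
    loop extents, stopping once the product reaches `limit`."""
    prods, acc = [], 1
    for ext in extents_outer_to_inner:
        acc *= ext
        prods.append(acc)
        if acc >= limit:
            break
    return prods

def num_fused_loops(prods, limit):
    """Mirror of GetNumFusedLoops: how many outer loops to fuse so the fused
    extent equals `limit`, or the smallest extent exceeding it."""
    i = bisect_right(prods, limit)       # std::upper_bound
    if i > 0 and prods[i - 1] == limit:  # exact fit
        i -= 1
    if i != len(prods):                  # otherwise take one more loop
        i += 1
    return i

# Spatial loops of extents 4, 8, 16 (outermost first), parallelism cap 64.
prods = prefix_products([4, 8, 16], limit=64)  # -> [4, 32, 512]
print(num_fused_loops(prods, 32))              # -> 2: fusing 4*8 == 32 exactly
print(num_fused_loops(prods, 64))              # -> 3: no exact fit, overshoot to 512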

src/tir/schedule/analysis.h

Lines changed: 20 additions & 0 deletions
@@ -91,6 +91,26 @@ StmtSRef GetSRefTreeRoot(const StmtSRef& sref);
 StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, bool require_stage_pipeline,
                       bool require_subtree_compact_dataflow);
 
+/*!
+ * \brief The information of a block scope, including the leaf blocks,
+ * as well as the loop types (spatial, reduction) for each loop in the scope.
+ */
+struct ScopeBlockLoopInfo {
+  /*! \brief A list of the leaf blocks, from left to right */
+  std::vector<BlockRealize> realizes;
+  /*! \brief The loop vars bound to spatial block iters */
+  std::unordered_set<const VarNode*> spatial_vars;
+  /*! \brief The loop vars bound to non-spatial block iters */
+  std::unordered_set<const VarNode*> non_spatial_vars;
+};
+
+/*!
+ * \brief Inspect the scope of the given sref
+ * \param scope_block The root block of the scope
+ * \return The information of the scope
+ */
+ScopeBlockLoopInfo GetScopeBlockLoopInfo(const Block& scope_block);
+
 /*!
  * \brief Checks whether the block is a complete block under the scope
  * \param self The schedule state
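For intuition, consider a hypothetical 128x128 matmul written in TVMScript (illustration only, not part of this commit). `GetScopeBlockLoopInfo` would record the realize of the "matmul" block in `realizes`, put the loop vars `i` and `j` (bound to spatial block iters) into `spatial_vars`, and put `k` (bound to a reduction iter) into `non_spatial_vars`, so only `i` and `j` count toward the parallelizable prefix in `AnalyzeParallel`.

# A hypothetical workload, for illustration only.
from tvm.script import tir as T

@T.prim_func
def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128), "float32")
    B = T.match_buffer(b, (128, 128), "float32")
    C = T.match_buffer(c, (128, 128), "float32")
    for i, j, k in T.grid(128, 128, 128):
        with T.block("matmul"):
            # `vi`/`vj` are spatial block iters, `vk` is a reduction iter,
            # so `i`/`j` land in `spatial_vars` and `k` in `non_spatial_vars`.
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]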
