Skip to content

Commit 6681c7f

Browse files
committed
[MetaSchedule] Postproc: Rewrite-Layout
1 parent fb5e9c9 commit 6681c7f

File tree

7 files changed

+358
-26
lines changed

7 files changed

+358
-26
lines changed

include/tvm/meta_schedule/postproc.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,12 +150,17 @@ class Postproc : public runtime::ObjectRef {
150150
* \return The postprocessor created.
151151
*/
152152
TVM_DLL static Postproc RewriteTensorize(bool vectorize_init_loop = false);
153-
154153
/*!
155154
* \brief Creates a postprocessor that verifies if the GPU code is correct
156155
* \return The postprocessor created
157156
*/
158157
TVM_DLL static Postproc VerifyGPUCode();
158+
/*!
159+
* \brief Creates a postprocessor that rewrites the layout of input tensor
160+
* \note Weight layout rewrite is supported so far, activation layout rewrite will be added.
161+
* \return The postprocessor created
162+
*/
163+
TVM_DLL static Postproc RewriteLayout();
159164
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Postproc, ObjectRef, PostprocNode);
160165
};
161166

python/tvm/auto_scheduler/__init__.py

Lines changed: 42 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,45 +17,64 @@
1717
# pylint: disable=unused-import, redefined-builtin
1818
""" Namespace for TVM Auto-scheduler. """
1919

20-
from . import compute_dag
21-
from . import dispatcher
22-
from . import feature
23-
from . import loop_state
24-
from . import measure
25-
from . import measure_record
26-
from . import relay_integration
27-
from . import search_policy
28-
from . import search_task
29-
from . import task_scheduler
30-
from . import utils
31-
from . import workload_registry
20+
from . import (
21+
compute_dag,
22+
dispatcher,
23+
feature,
24+
loop_state,
25+
measure,
26+
measure_record,
27+
relay_integration,
28+
search_policy,
29+
search_task,
30+
task_scheduler,
31+
utils,
32+
workload_registry,
33+
)
3234

3335
# Shortcut
34-
from .compute_dag import ComputeDAG, LayoutRewriteOption, get_shape_from_rewritten_layout
36+
from .compute_dag import (
37+
ComputeDAG,
38+
LayoutRewriteOption,
39+
get_shape_from_rewritten_layout,
40+
)
3541
from .cost_model import RandomModel, XGBModel
36-
from .dispatcher import DispatchContext, ApplyHistoryBest, ApplyHistoryBestOrSample
42+
from .dispatcher import ApplyHistoryBest, ApplyHistoryBestOrSample, DispatchContext
3743
from .measure import (
38-
MeasureInput,
39-
MeasureResult,
4044
LocalBuilder,
45+
LocalRPCMeasureContext,
4146
LocalRunner,
47+
MeasureInput,
48+
MeasureResult,
4249
RPCRunner,
43-
LocalRPCMeasureContext,
4450
register_task_input_check_func,
4551
)
46-
from .measure_record import RecordToFile, RecordReader, load_best_record, load_records, save_records
52+
from .measure_record import (
53+
RecordReader,
54+
RecordToFile,
55+
load_best_record,
56+
load_records,
57+
save_records,
58+
)
4759
from .relay_integration import (
4860
extract_tasks,
61+
is_auto_scheduler_enabled,
4962
remove_index_check,
5063
rewrite_compute_body,
51-
is_auto_scheduler_enabled,
64+
rewrite_tensor_shape,
5265
)
53-
from .search_task import SearchTask, TuningOptions, HardwareParams, create_task, auto_schedule
5466
from .search_policy import (
5567
EmptyPolicy,
56-
SketchPolicy,
57-
PreloadMeasuredStates,
5868
PreloadCustomSketchRule,
69+
PreloadMeasuredStates,
70+
SketchPolicy,
71+
)
72+
from .search_task import (
73+
HardwareParams,
74+
SearchTask,
75+
TuningOptions,
76+
auto_schedule,
77+
create_task,
5978
)
6079
from .task_scheduler import TaskScheduler
61-
from .workload_registry import register_workload, make_workload_key
80+
from .workload_registry import make_workload_key, register_workload

python/tvm/meta_schedule/default_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ def postprocs() -> List[Postproc]:
262262
M.DisallowDynamicLoop(),
263263
M.RewriteParallelVectorizeUnroll(),
264264
M.RewriteReductionBlock(),
265+
M.RewriteLayout(),
265266
]
266267

267268
@staticmethod

python/tvm/meta_schedule/postproc/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""The tvm.meta_schedule.postproc package."""
18-
from .postproc import Postproc, PyPostproc
1918
from .disallow_dynamic_loop import DisallowDynamicLoop
19+
from .postproc import Postproc, PyPostproc
2020
from .rewrite_cooperative_fetch import RewriteCooperativeFetch
21+
from .rewrite_layout import RewriteLayout
2122
from .rewrite_parallel_vectorize_unroll import RewriteParallelVectorizeUnroll
2223
from .rewrite_reduction_block import RewriteReductionBlock
24+
from .rewrite_tensorize import RewriteTensorize
2325
from .rewrite_unbound_block import RewriteUnboundBlock
2426
from .verify_gpu_code import VerifyGPUCode
25-
from .rewrite_tensorize import RewriteTensorize
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""A postprocessor that rewrites the layout of input tensor"""
18+
19+
from tvm._ffi.registry import register_object
20+
21+
from .. import _ffi_api
22+
from .postproc import Postproc
23+
24+
25+
@register_object("meta_schedule.RewriteLayout")
26+
class RewriteLayout(Postproc):
27+
"""A postprocessor that rewrites the layout of input tensor"""
28+
29+
def __init__(self) -> None:
30+
self.__init_handle_by_constructor__(
31+
_ffi_api.PostprocRewriteLayout, # type: ignore # pylint: disable=no-member
32+
)
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
#include "../utils.h"
20+
21+
namespace tvm {
22+
namespace tir {
23+
24+
/*!
25+
* \brief Collect the block and index where the buffer is read.
26+
* \note The buffers are expected to be read by only one BufferLoad
27+
*/
28+
class BufferReadPosCollector : public StmtExprVisitor {
29+
public:
30+
explicit BufferReadPosCollector(const Array<Buffer>& buffers) {
31+
for (const Buffer& buf : buffers) {
32+
buffers_.insert(buf.get());
33+
}
34+
}
35+
36+
const std::unordered_map<const BufferNode*, std::pair<Block, int>>& GetBufferLocations() const {
37+
return buffer_locs_;
38+
}
39+
40+
const std::unordered_map<const BufferNode*, Optional<IndexMap>>& GetBufferIndexMap() const {
41+
return buffer_index_maps_;
42+
}
43+
44+
private:
45+
void VisitStmt_(const ForNode* op) final {
46+
loop_stack_.push_back(GetRef<For>(op));
47+
StmtVisitor::VisitStmt_(op);
48+
loop_stack_.pop_back();
49+
}
50+
51+
void VisitStmt_(const BlockRealizeNode* op) final {
52+
BlockRealize outer_block_realize = GetRef<BlockRealize>(op);
53+
std::swap(outer_block_realize, cur_realize_);
54+
StmtVisitor::VisitStmt_(op);
55+
std::swap(cur_realize_, outer_block_realize);
56+
}
57+
58+
void VisitExpr_(const BufferLoadNode* op) final {
59+
const Buffer& buffer = op->buffer;
60+
if (buffers_.count(buffer.get())) {
61+
Map<Var, PrimExpr> subst_map;
62+
for (size_t i = 0; i < cur_realize_->iter_values.size(); i++) {
63+
const Var& var = cur_realize_->block->iter_vars[i]->var;
64+
const PrimExpr& value = cur_realize_->iter_values[i];
65+
subst_map.Set(var, value);
66+
}
67+
Array<PrimExpr> subst_indices;
68+
for (const PrimExpr& e : op->indices) {
69+
subst_indices.push_back(Substitute(e, subst_map));
70+
}
71+
buffer_index_maps_[buffer.get()] = SuggestIndexMap(/*buffer=*/buffer, //
72+
/*indices=*/subst_indices, //
73+
/*loops=*/loop_stack_, //
74+
/*predicate=*/cur_realize_->predicate, //
75+
/*analyzer=*/&analyzer_);
76+
int buffer_index = GetReadBufferIndex(cur_realize_->block, buffer);
77+
ICHECK(buffer_index != -1);
78+
buffer_locs_[buffer.get()] = std::make_pair(cur_realize_->block, buffer_index);
79+
}
80+
}
81+
82+
static int GetReadBufferIndex(const Block& block, const Buffer& buffer) {
83+
for (size_t i = 0; i < block->reads.size(); i++) {
84+
if (block->reads[i]->buffer.same_as(buffer)) {
85+
return i;
86+
}
87+
}
88+
return -1;
89+
}
90+
91+
private:
92+
/*! \brief All interested buffer. */
93+
std::unordered_set<const BufferNode*> buffers_;
94+
/*! \brief The result mapping from buffer to its inner-most block and read index. */
95+
std::unordered_map<const BufferNode*, std::pair<Block, int>> buffer_locs_;
96+
/*! \brief The result mapping from buffer to its IndexMap. */
97+
std::unordered_map<const BufferNode*, Optional<IndexMap>> buffer_index_maps_;
98+
99+
/*! \brief Loop stack for calculating IndexMap. */
100+
Array<For> loop_stack_;
101+
/*! \brief Arithmetic analyzer. */
102+
arith::Analyzer analyzer_;
103+
/*! \brief Current BlockRealize scope, used in recursive visit */
104+
BlockRealize cur_realize_;
105+
};
106+
107+
bool RewriteLayout(const Schedule& sch) {
108+
std::vector<std::pair<StmtSRef, String>> results;
109+
for (const auto& kv : sch->mod()->functions) {
110+
const GlobalVar& g_var = kv.first;
111+
const String& func_name = g_var->name_hint;
112+
const auto* prim_func = kv.second.as<PrimFuncNode>();
113+
// Only consider PrimFunc
114+
if (prim_func == nullptr) {
115+
continue;
116+
}
117+
// Only rewrite PrimFuncs with attr "layout_free_buffers"
118+
Array<Integer> layout_free_buffer_index =
119+
prim_func->GetAttr(attr::layout_free_buffers, Array<Integer>()).value();
120+
121+
Array<Buffer> layout_free_buffers;
122+
for (const Integer& index : layout_free_buffer_index) {
123+
const Var& param = prim_func->params[index->value];
124+
layout_free_buffers.push_back(prim_func->buffer_map.at(param));
125+
}
126+
// Collect Buffer read positions
127+
BufferReadPosCollector collector(layout_free_buffers);
128+
collector(prim_func->body);
129+
const auto& locations = collector.GetBufferLocations();
130+
const auto& index_maps = collector.GetBufferIndexMap();
131+
// Check all buffers are collected
132+
if (locations.size() != layout_free_buffers.size() ||
133+
index_maps.size() != layout_free_buffer_index.size()) {
134+
return false;
135+
}
136+
137+
for (const auto& kv : locations) {
138+
const Buffer& buffer = GetRef<Buffer>(kv.first);
139+
const Block& block = kv.second.first;
140+
int buffer_index = kv.second.second;
141+
142+
// Get IndexMap
143+
const Optional<IndexMap> index_map = index_maps.at(buffer.get());
144+
if (!index_map.defined()) {
145+
continue;
146+
}
147+
148+
// Apply schedule
149+
BlockRV block_rv = sch->GetBlock(block->name_hint, func_name);
150+
BlockRV cached_block_rv = sch->CacheRead(block_rv, buffer_index, "global");
151+
sch->TransformLayout(block_rv, buffer_index, BufferIndexType::kRead, index_map.value());
152+
sch->Annotate(cached_block_rv, attr::meta_schedule_layout_rewrite_preproc, const_true());
153+
}
154+
}
155+
return true;
156+
}
157+
158+
} // namespace tir
159+
160+
namespace meta_schedule {
161+
/*! \brief Layout Rewrite. */
162+
class RewriteLayoutNode : public PostprocNode {
163+
public:
164+
// Inherited from PostprocNode
165+
void InitializeWithTuneContext(const TuneContext& context) final {}
166+
167+
// Inherited from PostprocNode
168+
bool Apply(const tir::Schedule& sch) final { return tir::RewriteLayout(sch); }
169+
170+
static constexpr const char* _type_key = "meta_schedule.RewriteLayout";
171+
TVM_DECLARE_FINAL_OBJECT_INFO(RewriteLayoutNode, PostprocNode);
172+
};
173+
174+
Postproc Postproc::RewriteLayout() {
175+
auto n = make_object<RewriteLayoutNode>();
176+
return Postproc(n);
177+
}
178+
179+
TVM_REGISTER_NODE_TYPE(RewriteLayoutNode);
180+
TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteLayout").set_body_typed(Postproc::RewriteLayout);
181+
182+
} // namespace meta_schedule
183+
} // namespace tvm

0 commit comments

Comments
 (0)