Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion include/tvm/meta_schedule/postproc.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,17 @@ class Postproc : public runtime::ObjectRef {
* \return The postprocessor created.
*/
TVM_DLL static Postproc RewriteTensorize(bool vectorize_init_loop = false);

/*!
* \brief Creates a postprocessor that verifies if the GPU code is correct
* \return The postprocessor created
*/
TVM_DLL static Postproc VerifyGPUCode();
/*!
* \brief Creates a postprocessor that rewrites the layout of input tensor
* \note Weight layout rewrite is supported so far, activation layout rewrite will be added.
* \return The postprocessor created
*/
TVM_DLL static Postproc RewriteLayout();
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Postproc, ObjectRef, PostprocNode);
};

Expand Down
65 changes: 42 additions & 23 deletions python/tvm/auto_scheduler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,45 +17,64 @@
# pylint: disable=unused-import, redefined-builtin
""" Namespace for TVM Auto-scheduler. """

from . import compute_dag
from . import dispatcher
from . import feature
from . import loop_state
from . import measure
from . import measure_record
from . import relay_integration
from . import search_policy
from . import search_task
from . import task_scheduler
from . import utils
from . import workload_registry
from . import (
compute_dag,
dispatcher,
feature,
loop_state,
measure,
measure_record,
relay_integration,
search_policy,
search_task,
task_scheduler,
utils,
workload_registry,
)

# Shortcut
from .compute_dag import ComputeDAG, LayoutRewriteOption, get_shape_from_rewritten_layout
from .compute_dag import (
ComputeDAG,
LayoutRewriteOption,
get_shape_from_rewritten_layout,
)
from .cost_model import RandomModel, XGBModel
from .dispatcher import DispatchContext, ApplyHistoryBest, ApplyHistoryBestOrSample
from .dispatcher import ApplyHistoryBest, ApplyHistoryBestOrSample, DispatchContext
from .measure import (
MeasureInput,
MeasureResult,
LocalBuilder,
LocalRPCMeasureContext,
LocalRunner,
MeasureInput,
MeasureResult,
RPCRunner,
LocalRPCMeasureContext,
register_task_input_check_func,
)
from .measure_record import RecordToFile, RecordReader, load_best_record, load_records, save_records
from .measure_record import (
RecordReader,
RecordToFile,
load_best_record,
load_records,
save_records,
)
from .relay_integration import (
extract_tasks,
is_auto_scheduler_enabled,
remove_index_check,
rewrite_compute_body,
is_auto_scheduler_enabled,
rewrite_tensor_shape,
)
from .search_task import SearchTask, TuningOptions, HardwareParams, create_task, auto_schedule
from .search_policy import (
EmptyPolicy,
SketchPolicy,
PreloadMeasuredStates,
PreloadCustomSketchRule,
PreloadMeasuredStates,
SketchPolicy,
)
from .search_task import (
HardwareParams,
SearchTask,
TuningOptions,
auto_schedule,
create_task,
)
from .task_scheduler import TaskScheduler
from .workload_registry import register_workload, make_workload_key
from .workload_registry import make_workload_key, register_workload
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/default_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ def postprocs() -> List[Postproc]:
M.DisallowDynamicLoop(),
M.RewriteParallelVectorizeUnroll(),
M.RewriteReductionBlock(),
M.RewriteLayout(),
]

@staticmethod
Expand Down
5 changes: 3 additions & 2 deletions python/tvm/meta_schedule/postproc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
# specific language governing permissions and limitations
# under the License.
"""The tvm.meta_schedule.postproc package."""
from .postproc import Postproc, PyPostproc
from .disallow_dynamic_loop import DisallowDynamicLoop
from .postproc import Postproc, PyPostproc
from .rewrite_cooperative_fetch import RewriteCooperativeFetch
from .rewrite_layout import RewriteLayout
from .rewrite_parallel_vectorize_unroll import RewriteParallelVectorizeUnroll
from .rewrite_reduction_block import RewriteReductionBlock
from .rewrite_tensorize import RewriteTensorize
from .rewrite_unbound_block import RewriteUnboundBlock
from .verify_gpu_code import VerifyGPUCode
from .rewrite_tensorize import RewriteTensorize
32 changes: 32 additions & 0 deletions python/tvm/meta_schedule/postproc/rewrite_layout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""A postprocessor that rewrites the layout of input tensor"""

from tvm._ffi.registry import register_object

from .. import _ffi_api
from .postproc import Postproc


@register_object("meta_schedule.RewriteLayout")
class RewriteLayout(Postproc):
"""A postprocessor that rewrites the layout of input tensor"""

def __init__(self) -> None:
self.__init_handle_by_constructor__(
_ffi_api.PostprocRewriteLayout, # type: ignore # pylint: disable=no-member
)
183 changes: 183 additions & 0 deletions src/meta_schedule/postproc/rewrite_layout.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "../utils.h"

namespace tvm {
namespace tir {

/*!
* \brief Collect the block and index where the buffer is read.
* \note The buffers are expected to be read by only one BufferLoad
*/
class BufferReadPosCollector : public StmtExprVisitor {
public:
explicit BufferReadPosCollector(const Array<Buffer>& buffers) {
for (const Buffer& buf : buffers) {
buffers_.insert(buf.get());
}
}

const std::unordered_map<const BufferNode*, std::pair<Block, int>>& GetBufferLocations() const {
return buffer_locs_;
}

const std::unordered_map<const BufferNode*, Optional<IndexMap>>& GetBufferIndexMap() const {
return buffer_index_maps_;
}

private:
void VisitStmt_(const ForNode* op) final {
loop_stack_.push_back(GetRef<For>(op));
StmtVisitor::VisitStmt_(op);
loop_stack_.pop_back();
}

void VisitStmt_(const BlockRealizeNode* op) final {
BlockRealize outer_block_realize = GetRef<BlockRealize>(op);
std::swap(outer_block_realize, cur_realize_);
StmtVisitor::VisitStmt_(op);
std::swap(cur_realize_, outer_block_realize);
}

void VisitExpr_(const BufferLoadNode* op) final {
const Buffer& buffer = op->buffer;
if (buffers_.count(buffer.get())) {
Map<Var, PrimExpr> subst_map;
for (size_t i = 0; i < cur_realize_->iter_values.size(); i++) {
const Var& var = cur_realize_->block->iter_vars[i]->var;
const PrimExpr& value = cur_realize_->iter_values[i];
subst_map.Set(var, value);
}
Array<PrimExpr> subst_indices;
for (const PrimExpr& e : op->indices) {
subst_indices.push_back(Substitute(e, subst_map));
}
buffer_index_maps_[buffer.get()] = SuggestIndexMap(/*buffer=*/buffer, //
/*indices=*/subst_indices, //
/*loops=*/loop_stack_, //
/*predicate=*/cur_realize_->predicate, //
/*analyzer=*/&analyzer_);
int buffer_index = GetReadBufferIndex(cur_realize_->block, buffer);
ICHECK(buffer_index != -1);
buffer_locs_[buffer.get()] = std::make_pair(cur_realize_->block, buffer_index);
}
}

static int GetReadBufferIndex(const Block& block, const Buffer& buffer) {
for (size_t i = 0; i < block->reads.size(); i++) {
if (block->reads[i]->buffer.same_as(buffer)) {
return i;
}
}
return -1;
}

private:
/*! \brief All interested buffer. */
std::unordered_set<const BufferNode*> buffers_;
/*! \brief The result mapping from buffer to its inner-most block and read index. */
std::unordered_map<const BufferNode*, std::pair<Block, int>> buffer_locs_;
/*! \brief The result mapping from buffer to its IndexMap. */
std::unordered_map<const BufferNode*, Optional<IndexMap>> buffer_index_maps_;

/*! \brief Loop stack for calculating IndexMap. */
Array<For> loop_stack_;
/*! \brief Arithmetic analyzer. */
arith::Analyzer analyzer_;
/*! \brief Current BlockRealize scope, used in recursive visit */
BlockRealize cur_realize_;
};

bool RewriteLayout(const Schedule& sch) {
std::vector<std::pair<StmtSRef, String>> results;
for (const auto& kv : sch->mod()->functions) {
const GlobalVar& g_var = kv.first;
const String& func_name = g_var->name_hint;
const auto* prim_func = kv.second.as<PrimFuncNode>();
// Only consider PrimFunc
if (prim_func == nullptr) {
continue;
}
// Only rewrite PrimFuncs with attr "layout_free_buffers"
Array<Integer> layout_free_buffer_index =
prim_func->GetAttr(attr::layout_free_buffers, Array<Integer>()).value();

Array<Buffer> layout_free_buffers;
for (const Integer& index : layout_free_buffer_index) {
const Var& param = prim_func->params[index->value];
layout_free_buffers.push_back(prim_func->buffer_map.at(param));
}
// Collect Buffer read positions
BufferReadPosCollector collector(layout_free_buffers);
collector(prim_func->body);
const auto& locations = collector.GetBufferLocations();
const auto& index_maps = collector.GetBufferIndexMap();
// Check all buffers are collected
if (locations.size() != layout_free_buffers.size() ||
index_maps.size() != layout_free_buffer_index.size()) {
return false;
}

for (const auto& kv : locations) {
const Buffer& buffer = GetRef<Buffer>(kv.first);
const Block& block = kv.second.first;
int buffer_index = kv.second.second;

// Get IndexMap
const Optional<IndexMap> index_map = index_maps.at(buffer.get());
if (!index_map.defined()) {
continue;
}

// Apply schedule
BlockRV block_rv = sch->GetBlock(block->name_hint, func_name);
BlockRV cached_block_rv = sch->CacheRead(block_rv, buffer_index, "global");
sch->TransformLayout(block_rv, buffer_index, BufferIndexType::kRead, index_map.value());
sch->Annotate(cached_block_rv, attr::meta_schedule_layout_rewrite_preproc, const_true());
}
}
return true;
}

} // namespace tir

namespace meta_schedule {
/*! \brief Layout Rewrite. */
class RewriteLayoutNode : public PostprocNode {
public:
// Inherited from PostprocNode
void InitializeWithTuneContext(const TuneContext& context) final {}

// Inherited from PostprocNode
bool Apply(const tir::Schedule& sch) final { return tir::RewriteLayout(sch); }

static constexpr const char* _type_key = "meta_schedule.RewriteLayout";
TVM_DECLARE_FINAL_OBJECT_INFO(RewriteLayoutNode, PostprocNode);
};

Postproc Postproc::RewriteLayout() {
auto n = make_object<RewriteLayoutNode>();
return Postproc(n);
}

TVM_REGISTER_NODE_TYPE(RewriteLayoutNode);
TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteLayout").set_body_typed(Postproc::RewriteLayout);

} // namespace meta_schedule
} // namespace tvm
Loading