Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/postproc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@
# under the License.
"""The tvm.meta_schedule.postproc package."""
from .postproc import Postproc, PyPostproc
from .verify_gpu_code import VerifyGPUCode
31 changes: 31 additions & 0 deletions python/tvm/meta_schedule/postproc/verify_gpu_code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""A postprocessor that verifies if the GPU code is correct"""

from tvm._ffi.registry import register_object
from .. import _ffi_api
from .postproc import Postproc


@register_object("meta_schedule.VerifyGPUCode")
class VerifyGPUCode(Postproc):
"""A postprocessor that verifies if the GPU code is correct"""

def __init__(self) -> None:
self.__init_handle_by_constructor__(
_ffi_api.PostprocVerifyGPUCode, # type: ignore # pylint: disable=no-member
)
130 changes: 130 additions & 0 deletions src/meta_schedule/postproc/verify_gpu_code.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/tir/transform.h>

#include "../utils.h"

namespace tvm {
namespace meta_schedule {

/*! \brief Extract attribute from a target. */
Integer Extract(const Target& target, const char* name) {
ICHECK(target.defined());
if (Optional<Integer> v = target->GetAttr<Integer>(name)) {
return v.value();
}
LOG(FATAL) << "AttributedError: \"" << name << "\" is not defined in the target";
throw;
}

/*! \brief Verify the correctness of the generated GPU code. */
class VerifyGPUCodeNode : public PostprocNode {
public:
Map<String, PrimExpr> target_constraints_{nullptr};

void InitializeWithTuneContext(const TuneContext& context) final {
ICHECK(context->target.defined());
Target target = context->target.value();
this->target_constraints_ = Map<String, PrimExpr>{
{"max_shared_memory_per_block", Extract(target, "shared_memory_per_block")},
{"max_local_memory_per_block", Extract(target, "registers_per_block")},
{"max_threads_per_block", Extract(target, "max_threads_per_block")},
{"max_vthread", Integer(8)},
{"max_vector_bytes", Integer(16)}};
}

bool Verify(const IRModule& mod) const {
for (const auto& kv : mod->functions) {
if (const auto* prim_func = kv.second.as<tir::PrimFuncNode>()) {
if (!tir::VerifyGPUCode(GetRef<tir::PrimFunc>(prim_func), this->target_constraints_)) {
return false;
}
}
}
return true;
}

bool Apply(const tir::Schedule& sch) final {
IRModule mod = sch->mod();
for (const auto& kv : mod->functions) {
const GlobalVar& g_var = kv.first;
const BaseFunc& base_func = kv.second;
if (const auto* prim_func = base_func.as<tir::PrimFuncNode>()) {
IRModule lowered{nullptr};
try {
auto pass_list = Array<tvm::transform::Pass>();
// Phase 1
// First three passes are not needed in TIR schedule.
// pass_list.push_back(tir::transform::InjectPrefetch());
// pass_list.push_back(tir::transform::TextureFlatten());
// pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
pass_list.push_back(tir::transform::LowerCrossThreadReduction());
pass_list.push_back(tir::transform::LowerInitBlock());
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
pass_list.push_back(tir::transform::UnifyThreadBinding());
pass_list.push_back(tir::transform::CompactBufferAllocation());
pass_list.push_back(tir::transform::LowerMatchBuffer());
pass_list.push_back(tir::transform::FlattenBuffer());
pass_list.push_back(tir::transform::BF16Legalize());
pass_list.push_back(tir::transform::NarrowDataType(32));
pass_list.push_back(tir::transform::Simplify());

// Phase 2
pass_list.push_back(tir::transform::VectorizeLoop(true));
pass_list.push_back(tir::transform::InjectVirtualThread());
pass_list.push_back(tir::transform::InjectDoubleBuffer());
pass_list.push_back(tir::transform::StorageRewrite());
pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations());

// Convert Function to IRModule
transform::PassContext pass_ctx = transform::PassContext::Current();
tir::PrimFunc f = WithAttr(GetRef<tir::PrimFunc>(prim_func), "global_symbol",
runtime::String(g_var->name_hint));
bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
if (noalias) {
f = WithAttr(std::move(f), "tir.noalias", Bool(true));
}
IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(g_var->name_hint), f}}));
lowered = tvm::transform::Sequential(pass_list)(std::move(mod));
} catch (const dmlc::Error& e) {
return false;
}
if (!Verify(lowered)) {
return false;
}
}
}
return true;
}

static constexpr const char* _type_key = "meta_schedule.VerifyGPUCode";
TVM_DECLARE_FINAL_OBJECT_INFO(VerifyGPUCodeNode, PostprocNode);
};

Postproc Postproc::VerifyGPUCode() {
ObjectPtr<VerifyGPUCodeNode> n = make_object<VerifyGPUCodeNode>();
return Postproc(n);
}

TVM_REGISTER_NODE_TYPE(VerifyGPUCodeNode);
TVM_REGISTER_GLOBAL("meta_schedule.PostprocVerifyGPUCode").set_body_typed(Postproc::VerifyGPUCode);

} // namespace meta_schedule
} // namespace tvm
231 changes: 231 additions & 0 deletions tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring

import sys
import pytest
import tvm
from tvm import tir
from tvm.meta_schedule import TuneContext
from tvm.meta_schedule.postproc import VerifyGPUCode
from tvm.script import tir as T
from tvm.target import Target


def _target() -> Target:
return Target("nvidia/geforce-rtx-3080")


def _create_context(mod, target) -> TuneContext:
ctx = TuneContext(
mod=mod,
target=target,
postprocs=[
VerifyGPUCode(),
],
task_name="test",
)
for rule in ctx.postprocs:
rule.initialize_with_tune_context(ctx)
return ctx


# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant
# fmt: off

@tvm.script.ir_module
class Conv2dCuda0:
@T.prim_func
def main(a: T.handle, b: T.handle) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "T.noalias": True})
# var definition
threadIdx_x = T.env_thread("threadIdx.x")
threadIdx_y = T.env_thread("threadIdx.y")
blockIdx_x = T.env_thread("blockIdx.x")
blockIdx_y = T.env_thread("blockIdx.y")
blockIdx_z = T.env_thread("blockIdx.z")
A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32")
B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32")
# body
T.launch_thread(blockIdx_z, 196)
B_local = T.allocate([64], "float32", "local")
Apad_shared = T.allocate([512], "float32", "shared")
Apad_shared_local = T.allocate([8], "float32", "local")
T.launch_thread(blockIdx_y, 8)
T.launch_thread(blockIdx_x, 4)
T.launch_thread(threadIdx_y, 8)
T.launch_thread(threadIdx_x, 8)
for ff_c_init, nn_c_init in T.grid(8, 8):
T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True)
for rc_outer, ry, rx in T.grid(32, 3, 3):
for ax3_inner_outer in T.serial(0, 2):
T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4))
for rc_inner in T.serial(0, 8):
for ax3 in T.serial(0, 8):
T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True)
for ff_c, nn_c in T.grid(8, 8):
T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True)
for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8):
T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on


@tvm.script.ir_module
class Conv2dCuda1:
@T.prim_func
def main(a: T.handle, b: T.handle) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "T.noalias": True})
# var definition
threadIdx_x = T.env_thread("threadIdx.x")
threadIdx_y = T.env_thread("threadIdx.y")
blockIdx_x = T.env_thread("blockIdx.x")
blockIdx_y = T.env_thread("blockIdx.y")
blockIdx_z = T.env_thread("blockIdx.z")
A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32")
B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32")
# body
T.launch_thread(blockIdx_z, 196)
B_local = T.allocate([6400000], "float32", "local")
Apad_shared = T.allocate([512], "float32", "shared")
Apad_shared_local = T.allocate([8], "float32", "local")
T.launch_thread(blockIdx_y, 8)
T.launch_thread(blockIdx_x, 4)
T.launch_thread(threadIdx_y, 8)
T.launch_thread(threadIdx_x, 8)
for ff_c_init, nn_c_init in T.grid(8, 8):
T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True)
for rc_outer, ry, rx in T.grid(32, 3, 3):
for ax3_inner_outer in T.serial(0, 2):
T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4))
for rc_inner in T.serial(0, 8):
for ax3 in T.serial(0, 8):
T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True)
for ff_c, nn_c in T.grid(8, 8):
T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True)
for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8):
T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on


@tvm.script.ir_module
class Conv2dCuda2:
@T.prim_func
def main(a: T.handle, b: T.handle) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "T.noalias": True})
# var definition
threadIdx_x = T.env_thread("threadIdx.x")
threadIdx_y = T.env_thread("threadIdx.y")
blockIdx_x = T.env_thread("blockIdx.x")
blockIdx_y = T.env_thread("blockIdx.y")
blockIdx_z = T.env_thread("blockIdx.z")
A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32")
B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32")
# body
T.launch_thread(blockIdx_z, 196)
B_local = T.allocate([64], "float32", "local")
Apad_shared = T.allocate([512000], "float32", "shared")
Apad_shared_local = T.allocate([8], "float32", "local")
T.launch_thread(blockIdx_y, 8)
T.launch_thread(blockIdx_x, 4)
T.launch_thread(threadIdx_y, 8)
T.launch_thread(threadIdx_x, 8)
for ff_c_init, nn_c_init in T.grid(8, 8):
T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True)
for rc_outer, ry, rx in T.grid(32, 3, 3):
for ax3_inner_outer in T.serial(0, 2):
T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4))
for rc_inner in T.serial(0, 8):
for ax3 in T.serial(0, 8):
T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True)
for ff_c, nn_c in T.grid(8, 8):
T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True)
for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8):
T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on


@tvm.script.ir_module
class Conv2dCuda3:
@T.prim_func
def main(a: T.handle, b: T.handle) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "T.noalias": True})
# var definition
threadIdx_x = T.env_thread("threadIdx.x")
threadIdx_y = T.env_thread("threadIdx.y")
blockIdx_x = T.env_thread("blockIdx.x")
blockIdx_y = T.env_thread("blockIdx.y")
blockIdx_z = T.env_thread("blockIdx.z")
A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32")
B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32")
# body
T.launch_thread(blockIdx_z, 196)
B_local = T.allocate([64], "float32", "local")
Apad_shared = T.allocate([512], "float32", "shared")
Apad_shared_local = T.allocate([8], "float32", "local")
T.launch_thread(blockIdx_y, 8)
T.launch_thread(blockIdx_x, 4)
T.launch_thread(threadIdx_y, 8)
T.launch_thread(threadIdx_x, 800000)
for ff_c_init, nn_c_init in T.grid(8, 8):
T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True)
for rc_outer, ry, rx in T.grid(32, 3, 3):
for ax3_inner_outer in T.serial(0, 2):
T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4))
for rc_inner in T.serial(0, 8):
for ax3 in T.serial(0, 8):
T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True)
for ff_c, nn_c in T.grid(8, 8):
T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True)
for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8):
T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on


# fmt: on
# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant


def test_postproc_verify_gpu_0():
mod = Conv2dCuda0
ctx = _create_context(mod, target=_target())
sch = tir.Schedule(mod, debug_mask="all")
assert ctx.postprocs[0].apply(sch)


def test_postproc_verify_gpu_1():
mod = Conv2dCuda1
ctx = _create_context(mod, target=_target())
sch = tir.Schedule(mod, debug_mask="all")
assert not ctx.postprocs[0].apply(sch)


def test_postproc_verify_gpu_2():
mod = Conv2dCuda2
ctx = _create_context(mod, target=_target())
sch = tir.Schedule(mod, debug_mask="all")
assert not ctx.postprocs[0].apply(sch)


def test_postproc_verify_gpu_3():
mod = Conv2dCuda3
ctx = _create_context(mod, target=_target())
sch = tir.Schedule(mod, debug_mask="all")
assert not ctx.postprocs[0].apply(sch)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))