diff --git a/python/tvm/meta_schedule/postproc/__init__.py b/python/tvm/meta_schedule/postproc/__init__.py
index 6ee052ecba68..2e8fa1e777d6 100644
--- a/python/tvm/meta_schedule/postproc/__init__.py
+++ b/python/tvm/meta_schedule/postproc/__init__.py
@@ -16,3 +16,4 @@
 # under the License.
 """The tvm.meta_schedule.postproc package."""
 from .postproc import Postproc, PyPostproc
+from .verify_gpu_code import VerifyGPUCode
diff --git a/python/tvm/meta_schedule/postproc/verify_gpu_code.py b/python/tvm/meta_schedule/postproc/verify_gpu_code.py
new file mode 100644
index 000000000000..501e4423196c
--- /dev/null
+++ b/python/tvm/meta_schedule/postproc/verify_gpu_code.py
@@ -0,0 +1,31 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""A postprocessor that verifies that the generated GPU code is correct"""
+
+from tvm._ffi.registry import register_object
+from .. import _ffi_api
+from .postproc import Postproc
+
+
+@register_object("meta_schedule.VerifyGPUCode")
+class VerifyGPUCode(Postproc):
+    """A postprocessor that verifies that the generated GPU code is correct"""
+
+    def __init__(self) -> None:
+        self.__init_handle_by_constructor__(
+            _ffi_api.PostprocVerifyGPUCode,  # type: ignore # pylint: disable=no-member
+        )
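The Python class above is a thin FFI handle; the actual check lives in the C++ node below, which lowers each PrimFunc and runs TIR's GPU code verification over it. For reference, the same underlying analysis is exposed in Python as tvm.tir.analysis.verify_gpu_code. The following is a minimal, illustrative sketch and not part of this patch; the kernel name and the constraint values are made up for the example:

# Illustrative only -- invokes the same analysis the postprocessor relies on,
# directly on an already low-level PrimFunc.
import tvm
from tvm.script import tir as T


@T.prim_func
def copy_kernel(a: T.handle, b: T.handle) -> None:
    threadIdx_x = T.env_thread("threadIdx.x")
    blockIdx_x = T.env_thread("blockIdx.x")
    A = T.match_buffer(a, [1024], dtype="float32")
    B = T.match_buffer(b, [1024], dtype="float32")
    T.launch_thread(blockIdx_x, 4)
    T.launch_thread(threadIdx_x, 256)
    T.store(B.data, blockIdx_x * 256 + threadIdx_x, T.load("float32", A.data, blockIdx_x * 256 + threadIdx_x), True)


# Hypothetical constraint values; the postprocessor reads the real ones from the target.
constraints = {
    "max_shared_memory_per_block": 49152,
    "max_local_memory_per_block": 65536,
    "max_threads_per_block": 1024,
    "max_vthread": 8,
    "max_vector_bytes": 16,
}
assert tvm.tir.analysis.verify_gpu_code(copy_kernel, constraints)

The postprocessor builds a constraint map of exactly this shape automatically from the tuning target, as the C++ implementation below shows.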
diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc
new file mode 100644
index 000000000000..edf13e36bef4
--- /dev/null
+++ b/src/meta_schedule/postproc/verify_gpu_code.cc
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/tir/transform.h>
+
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+/*! \brief Extract an attribute from a target. */
+Integer Extract(const Target& target, const char* name) {
+  ICHECK(target.defined());
+  if (Optional<Integer> v = target->GetAttr<Integer>(name)) {
+    return v.value();
+  }
+  LOG(FATAL) << "AttributeError: \"" << name << "\" is not defined in the target";
+  throw;
+}
+
+/*! \brief Verify the correctness of the generated GPU code. */
+class VerifyGPUCodeNode : public PostprocNode {
+ public:
+  Map<String, PrimExpr> target_constraints_{nullptr};
+
+  void InitializeWithTuneContext(const TuneContext& context) final {
+    ICHECK(context->target.defined());
+    Target target = context->target.value();
+    this->target_constraints_ = Map<String, PrimExpr>{
+        {"max_shared_memory_per_block", Extract(target, "shared_memory_per_block")},
+        {"max_local_memory_per_block", Extract(target, "registers_per_block")},
+        {"max_threads_per_block", Extract(target, "max_threads_per_block")},
+        {"max_vthread", Integer(8)},
+        {"max_vector_bytes", Integer(16)}};
+  }
+
+  bool Verify(const IRModule& mod) const {
+    for (const auto& kv : mod->functions) {
+      if (const auto* prim_func = kv.second.as<tir::PrimFuncNode>()) {
+        if (!tir::VerifyGPUCode(GetRef<tir::PrimFunc>(prim_func), this->target_constraints_)) {
+          return false;
+        }
+      }
+    }
+    return true;
+  }
+
+  bool Apply(const tir::Schedule& sch) final {
+    IRModule mod = sch->mod();
+    for (const auto& kv : mod->functions) {
+      const GlobalVar& g_var = kv.first;
+      const BaseFunc& base_func = kv.second;
+      if (const auto* prim_func = base_func.as<tir::PrimFuncNode>()) {
+        IRModule lowered{nullptr};
+        try {
+          auto pass_list = Array<tvm::transform::Pass>();
+          // Phase 1
+          // The first three passes are not needed for TIR schedules.
+          // pass_list.push_back(tir::transform::InjectPrefetch());
+          // pass_list.push_back(tir::transform::TextureFlatten());
+          // pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
+          pass_list.push_back(tir::transform::LowerCrossThreadReduction());
+          pass_list.push_back(tir::transform::LowerInitBlock());
+          pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
+          pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
+          pass_list.push_back(tir::transform::UnifyThreadBinding());
+          pass_list.push_back(tir::transform::CompactBufferAllocation());
+          pass_list.push_back(tir::transform::LowerMatchBuffer());
+          pass_list.push_back(tir::transform::FlattenBuffer());
+          pass_list.push_back(tir::transform::BF16Legalize());
+          pass_list.push_back(tir::transform::NarrowDataType(32));
+          pass_list.push_back(tir::transform::Simplify());
+
+          // Phase 2
+          pass_list.push_back(tir::transform::VectorizeLoop(true));
+          pass_list.push_back(tir::transform::InjectVirtualThread());
+          pass_list.push_back(tir::transform::InjectDoubleBuffer());
+          pass_list.push_back(tir::transform::StorageRewrite());
+          pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations());
+
+          // Convert the function to an IRModule
+          transform::PassContext pass_ctx = transform::PassContext::Current();
+          tir::PrimFunc f = WithAttr(GetRef<tir::PrimFunc>(prim_func), "global_symbol",
+                                     runtime::String(g_var->name_hint));
+          bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
+          if (noalias) {
+            f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+          }
+          IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(g_var->name_hint), f}}));
+          lowered = tvm::transform::Sequential(pass_list)(std::move(mod));
+        } catch (const dmlc::Error& e) {
+          return false;
+        }
+        if (!Verify(lowered)) {
+          return false;
+        }
+      }
+    }
+    return true;
+  }
+
+  static constexpr const char* _type_key = "meta_schedule.VerifyGPUCode";
+  TVM_DECLARE_FINAL_OBJECT_INFO(VerifyGPUCodeNode, PostprocNode);
+};
+
+Postproc Postproc::VerifyGPUCode() {
+  ObjectPtr<VerifyGPUCodeNode> n = make_object<VerifyGPUCodeNode>();
+  return Postproc(n);
+}
+
+TVM_REGISTER_NODE_TYPE(VerifyGPUCodeNode);
+TVM_REGISTER_GLOBAL("meta_schedule.PostprocVerifyGPUCode").set_body_typed(Postproc::VerifyGPUCode);
+
+}  // namespace meta_schedule
+}  // namespace tvm
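On the C++ side, Extract reads each resource limit from the target's attributes and raises a fatal error when an attribute is missing, while InitializeWithTuneContext assembles the limits into the constraint map (max_vthread and max_vector_bytes are hard-coded). The equivalent map can be sketched from Python via Target.attrs; illustrative only, not part of this patch:

# Illustrative only -- mirrors VerifyGPUCodeNode::InitializeWithTuneContext.
from tvm.target import Target

target = Target("nvidia/geforce-rtx-3080")
constraints = {
    "max_shared_memory_per_block": target.attrs["shared_memory_per_block"],
    "max_local_memory_per_block": target.attrs["registers_per_block"],
    "max_threads_per_block": target.attrs["max_threads_per_block"],
    "max_vthread": 8,  # hard-coded in the postprocessor
    "max_vector_bytes": 16,  # hard-coded in the postprocessor
}

Note that Apply lowers each block-based PrimFunc through the standard Phase 1/Phase 2 pass pipeline before verifying, so the check runs on the same flattened TIR that code generation would see; a lowering failure is treated the same as a verification failure.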
diff --git a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py
new file mode 100644
index 000000000000..bebfec6122b3
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py
@@ -0,0 +1,231 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
+
+import sys
+import pytest
+import tvm
+from tvm import tir
+from tvm.meta_schedule import TuneContext
+from tvm.meta_schedule.postproc import VerifyGPUCode
+from tvm.script import tir as T
+from tvm.target import Target
+
+
+def _target() -> Target:
+    return Target("nvidia/geforce-rtx-3080")
+
+
+def _create_context(mod, target) -> TuneContext:
+    ctx = TuneContext(
+        mod=mod,
+        target=target,
+        postprocs=[
+            VerifyGPUCode(),
+        ],
+        task_name="test",
+    )
+    for rule in ctx.postprocs:
+        rule.initialize_with_tune_context(ctx)
+    return ctx
+
+
+# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant
+# fmt: off
+
+@tvm.script.ir_module
+class Conv2dCuda0:
+    @T.prim_func
+    def main(a: T.handle, b: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # var definition
+        threadIdx_x = T.env_thread("threadIdx.x")
+        threadIdx_y = T.env_thread("threadIdx.y")
+        blockIdx_x = T.env_thread("blockIdx.x")
+        blockIdx_y = T.env_thread("blockIdx.y")
+        blockIdx_z = T.env_thread("blockIdx.z")
+        A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32")
+        B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32")
+        # body
+        T.launch_thread(blockIdx_z, 196)
+        B_local = T.allocate([64], "float32", "local")
+        Apad_shared = T.allocate([512], "float32", "shared")
+        Apad_shared_local = T.allocate([8], "float32", "local")
+        T.launch_thread(blockIdx_y, 8)
+        T.launch_thread(blockIdx_x, 4)
+        T.launch_thread(threadIdx_y, 8)
+        T.launch_thread(threadIdx_x, 8)
+        for ff_c_init, nn_c_init in T.grid(8, 8):
+            T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True)
+        for rc_outer, ry, rx in T.grid(32, 3, 3):
+            for ax3_inner_outer in T.serial(0, 2):
+                T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4))
+            for rc_inner in T.serial(0, 8):
+                for ax3 in T.serial(0, 8):
+                    T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True)
+                for ff_c, nn_c in T.grid(8, 8):
+                    T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True)
+        for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8):
+            T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on
+
+
+@tvm.script.ir_module
+class Conv2dCuda1:
+    @T.prim_func
+    def main(a: T.handle, b: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # var definition
+        threadIdx_x = T.env_thread("threadIdx.x")
+        threadIdx_y = T.env_thread("threadIdx.y")
+        blockIdx_x = T.env_thread("blockIdx.x")
+        blockIdx_y = T.env_thread("blockIdx.y")
+        blockIdx_z = T.env_thread("blockIdx.z")
+        A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32")
+        B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32")
+        # body
+        T.launch_thread(blockIdx_z, 196)
+        B_local = T.allocate([6400000], "float32", "local")  # too large for max_local_memory_per_block
+        Apad_shared = T.allocate([512], "float32", "shared")
+        Apad_shared_local = T.allocate([8], "float32", "local")
+        T.launch_thread(blockIdx_y, 8)
+        T.launch_thread(blockIdx_x, 4)
+        T.launch_thread(threadIdx_y, 8)
+        T.launch_thread(threadIdx_x, 8)
+        for ff_c_init, nn_c_init in T.grid(8, 8):
+            T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True)
+        for rc_outer, ry, rx in T.grid(32, 3, 3):
+            for ax3_inner_outer in T.serial(0, 2):
+                T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4))
+            for rc_inner in T.serial(0, 8):
+                for ax3 in T.serial(0, 8):
+                    T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True)
+                for ff_c, nn_c in T.grid(8, 8):
+                    T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True)
+        for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8):
+            T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on
+
+
+@tvm.script.ir_module
+class Conv2dCuda2:
+    @T.prim_func
+    def main(a: T.handle, b: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # var definition
+        threadIdx_x = T.env_thread("threadIdx.x")
+        threadIdx_y = T.env_thread("threadIdx.y")
+        blockIdx_x = T.env_thread("blockIdx.x")
+        blockIdx_y = T.env_thread("blockIdx.y")
+        blockIdx_z = T.env_thread("blockIdx.z")
+        A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32")
+        B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32")
+        # body
+        T.launch_thread(blockIdx_z, 196)
+        B_local = T.allocate([64], "float32", "local")
+        Apad_shared = T.allocate([512000], "float32", "shared")  # too large for max_shared_memory_per_block
"float32", "shared") + Apad_shared_local = T.allocate([8], "float32", "local") + T.launch_thread(blockIdx_y, 8) + T.launch_thread(blockIdx_x, 4) + T.launch_thread(threadIdx_y, 8) + T.launch_thread(threadIdx_x, 8) + for ff_c_init, nn_c_init in T.grid(8, 8): + T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True) + for rc_outer, ry, rx in T.grid(32, 3, 3): + for ax3_inner_outer in T.serial(0, 2): + T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4)) + for rc_inner in T.serial(0, 8): + for ax3 in T.serial(0, 8): + T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True) + for ff_c, nn_c in T.grid(8, 8): + T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True) + for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8): + T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on + + +@tvm.script.ir_module +class Conv2dCuda3: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "T.noalias": True}) + # var definition + threadIdx_x = T.env_thread("threadIdx.x") + threadIdx_y = T.env_thread("threadIdx.y") + blockIdx_x = T.env_thread("blockIdx.x") + blockIdx_y = T.env_thread("blockIdx.y") + blockIdx_z = T.env_thread("blockIdx.z") + A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32") + B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32") + # body + T.launch_thread(blockIdx_z, 196) + B_local = T.allocate([64], "float32", "local") + Apad_shared = T.allocate([512], "float32", "shared") + Apad_shared_local = T.allocate([8], "float32", "local") + T.launch_thread(blockIdx_y, 8) + T.launch_thread(blockIdx_x, 4) + T.launch_thread(threadIdx_y, 8) + T.launch_thread(threadIdx_x, 800000) + for ff_c_init, nn_c_init in T.grid(8, 8): + T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True) + for rc_outer, ry, rx in T.grid(32, 3, 3): + for ax3_inner_outer in T.serial(0, 2): + T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4)) + for rc_inner in T.serial(0, 8): + for ax3 in T.serial(0, 8): + T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True) + for ff_c, nn_c in T.grid(8, 8): + T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True) + for 
+        for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8):
+            T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on
+
+
+# fmt: on
+# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant
+
+
+def test_postproc_verify_gpu_0():
+    mod = Conv2dCuda0
+    ctx = _create_context(mod, target=_target())
+    sch = tir.Schedule(mod, debug_mask="all")
+    assert ctx.postprocs[0].apply(sch)
+
+
+def test_postproc_verify_gpu_1():
+    mod = Conv2dCuda1
+    ctx = _create_context(mod, target=_target())
+    sch = tir.Schedule(mod, debug_mask="all")
+    assert not ctx.postprocs[0].apply(sch)
+
+
+def test_postproc_verify_gpu_2():
+    mod = Conv2dCuda2
+    ctx = _create_context(mod, target=_target())
+    sch = tir.Schedule(mod, debug_mask="all")
+    assert not ctx.postprocs[0].apply(sch)
+
+
+def test_postproc_verify_gpu_3():
+    mod = Conv2dCuda3
+    ctx = _create_context(mod, target=_target())
+    sch = tir.Schedule(mod, debug_mask="all")
+    assert not ctx.postprocs[0].apply(sch)
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))
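For quick iteration outside of pytest, the same check can be driven by hand. A minimal sketch, assuming the fixtures and helpers from the test file above are in scope (it mirrors test_postproc_verify_gpu_1):

# Illustrative only -- exercise the postprocessor on one fixture directly.
from tvm import tir

ctx = _create_context(Conv2dCuda1, target=_target())
sch = tir.Schedule(Conv2dCuda1, debug_mask="all")
# Conv2dCuda1 allocates 6,400,000 floats in "local" scope, which exceeds
# max_local_memory_per_block for the RTX 3080 target, so apply() returns False.
assert not ctx.postprocs[0].apply(sch)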