diff --git a/docs/reference/MIGraphX-dev-env-vars.rst b/docs/reference/MIGraphX-dev-env-vars.rst index f9c9f309ab3..da52ce656b8 100644 --- a/docs/reference/MIGraphX-dev-env-vars.rst +++ b/docs/reference/MIGraphX-dev-env-vars.rst @@ -221,6 +221,13 @@ Model performance tunable variables change the compilation behavior of a model. | Default: No tuning is done for composable kernels. + * - | ``MIGRAPHX_REWRITE_LRN`` + | Turns on LRN-to-pooling lowering in the rewrite_pooling pass. + + - | ``1``: Turns on LRN-to-pooling lowering. + | ``0``: Returns to default behavior. + + | Default: LRN-to-pooling lowering is turned off. Matching ********** diff --git a/src/include/migraphx/rewrite_pooling.hpp b/src/include/migraphx/rewrite_pooling.hpp index ebef9834786..dd69bb7ca93 100644 --- a/src/include/migraphx/rewrite_pooling.hpp +++ b/src/include/migraphx/rewrite_pooling.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -38,6 +38,7 @@ struct module; */ struct MIGRAPHX_EXPORT rewrite_pooling { + bool rewrite_lrn = false; std::string name() const { return "rewrite_pooling"; } void apply(module& m) const; }; diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index 6100043e802..6edf6689397 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -30,6 +30,8 @@ #include #include +#include +#include #include namespace migraphx { @@ -55,6 +57,81 @@ static void replace_with_reduce(module& m, instruction_ref ins) } } +static void lower_lrn_to_pooling(module& m, instruction_ref ins) +{ + auto v = ins->get_operator().to_value(); + + float alpha = v.at("alpha").to(); + float beta = v.at("beta").to(); + float k = v.at("bias").to(); + int size = v.at("size").to(); + + auto x = ins->inputs().at(0); + const auto& xshape = x->get_shape(); + const auto& lens = xshape.lens(); + + if(lens.size() != 4 or size <= 0 or size > lens[1]) + { + return; + } + + auto x2 = m.insert_instruction(ins, make_op("mul"), x, x); + + std::vector perm = {0, 2, 3, 1}; + auto transpose1 = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), x2); + + auto transposed_shape = transpose1->get_shape(); + const auto& transposed_lens = transposed_shape.lens(); + + int64_t channel_dim = lens[1]; + std::vector calculated_pads(2); + calculate_padding(0, calculated_pads, channel_dim, 1, 1, size, true); + + auto avg = m.insert_instruction( + ins, + make_op("pooling", + {{"mode", op::pooling_mode::average}, + {"lengths", std::vector{1, size}}, + {"stride", std::vector{1, 1}}, + {"padding", std::vector{0, calculated_pads[0], 0, calculated_pads[1]}}, + {"dilations", std::vector{1, 1}}, + {"count_include_pad", true}}), + transpose1); + + auto avg_shape = avg->get_shape(); + const auto& avg_lens = avg_shape.lens(); + + if(avg_lens.size() != 4 or avg_lens[3] != transposed_lens[3]) + { + return; + } + + std::vector inv_perm = {0, 3, 1, 2}; + auto transpose2 = + m.insert_instruction(ins, make_op("transpose", {{"permutation", inv_perm}}), avg); + + auto final_shape = transpose2->get_shape(); + const auto& final_lens = final_shape.lens(); + + if(final_lens != lens) + return; + + auto 
k_lit = m.add_literal(k); + auto a_lit = m.add_literal(alpha); + auto b_lit = m.add_literal(beta); + + auto k_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), k_lit); + auto a_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), a_lit); + auto b_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), b_lit); + + auto alpha_avg = m.insert_instruction(ins, make_op("mul"), a_mb, transpose2); + auto den = m.insert_instruction(ins, make_op("add"), k_mb, alpha_avg); + auto denpow = m.insert_instruction(ins, make_op("pow"), den, b_mb); + auto y = m.insert_instruction(ins, make_op("div"), x, denpow); + + m.replace_instruction(ins, y); +} + static void replace_dilations_with_gather_pooling(module& m, instruction_ref ins) { // TODO remove this when MIOpen supports dilated pooling @@ -143,10 +220,16 @@ void rewrite_pooling::apply(module& m) const { for(auto ins : iterator_for(m)) { - if(ins->name() != "pooling") - continue; if(ins->inputs().empty()) continue; + if(rewrite_lrn and ins->name() == "lrn") + { + lower_lrn_to_pooling(m, ins); + continue; + } + if(ins->name() != "pooling") + continue; + auto&& s = ins->inputs().front()->get_shape(); auto&& op = any_cast(ins->get_operator()); bool same_kernel_as_shape = std::equal( diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index 85b3698c8ab..5844c934259 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -84,6 +84,7 @@ namespace gpu { MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_REWRITE_DOT) +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_REWRITE_LRN) #ifndef _WIN32 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK) #endif @@ -203,7 +204,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti insert_pad{{"convolution"}}, dead_code_elimination{}, inline_module{}, - rewrite_pooling{}, + 
rewrite_pooling{.rewrite_lrn = enabled(MIGRAPHX_REWRITE_LRN{})}, dead_code_elimination{}, rewrite_gelu{options.fast_math}, optimize_module{}, diff --git a/test/rewrite_pooling_test.cpp b/test/rewrite_pooling_test.cpp index 1ccafed179e..5377dc99715 100644 --- a/test/rewrite_pooling_test.cpp +++ b/test/rewrite_pooling_test.cpp @@ -34,6 +34,9 @@ #include +#include + +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_REWRITE_LRN); static void opt_pooling(migraphx::module& m) { migraphx::rewrite_pooling rp; @@ -309,6 +312,77 @@ TEST_CASE(rewrite_pooling_dialtions_test5) test_rewrite(migraphx::op::pooling_mode::max); } +TEST_CASE(test_lower_lrn_to_pooling) +{ + migraphx::module m1; + { + migraphx::shape input_shape{migraphx::shape::float_type, {1, 64, 55, 55}}; + auto input1 = m1.add_parameter("x", input_shape); + auto lrn1 = m1.add_instruction( + migraphx::make_op("lrn", + {{"alpha", 0.0001f}, {"beta", 0.75f}, {"bias", 1.0f}, {"size", 4}}), + input1); + m1.add_return({lrn1}); + } + // Apply the pass directly with the rewrite_lrn flag enabled + migraphx::rewrite_pooling rp{.rewrite_lrn = true}; + migraphx::dead_code_elimination dce; + rp.apply(m1); + dce.apply(m1); + + migraphx::module m2; + { + migraphx::shape input_shape{migraphx::shape::float_type, {1, 64, 55, 55}}; + auto input2 = m2.add_parameter("x", input_shape); + + auto x2 = m2.add_instruction(migraphx::make_op("mul"), input2, input2); + + auto transpose1 = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", std::vector{0, 2, 3, 1}}}), + x2); + + int64_t lrn_size = 4; + int64_t pad_left = (lrn_size - 1) / 2; + int64_t pad_right = lrn_size - 1 - pad_left; + std::vector expected_pads = {pad_left, pad_right}; + + auto avg = m2.add_instruction( + migraphx::make_op( + "pooling", + {{"mode", migraphx::op::pooling_mode::average}, + {"lengths", std::vector{1, lrn_size}}, + {"stride", std::vector{1, 1}}, + {"padding", std::vector{0, expected_pads[0], 0, expected_pads[1]}}, + {"dilations", std::vector{1, 1}}, + 
{"count_include_pad", true}}), + transpose1); + + auto transpose2 = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", std::vector{0, 3, 1, 2}}}), + avg); + + auto k_lit = m2.add_literal(1.0f); + auto a_lit = m2.add_literal(0.0001f); + auto b_lit = m2.add_literal(0.75f); + + auto k_mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_shape.lens()}}), k_lit); + auto a_mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_shape.lens()}}), a_lit); + auto b_mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_shape.lens()}}), b_lit); + + auto alpha_avg = m2.add_instruction(migraphx::make_op("mul"), a_mb, transpose2); + auto den = m2.add_instruction(migraphx::make_op("add"), k_mb, alpha_avg); + auto denpow = m2.add_instruction(migraphx::make_op("pow"), den, b_mb); + auto y = m2.add_instruction(migraphx::make_op("div"), input2, denpow); + + m2.add_return({y}); + } + + EXPECT(m1 == m2); +} + TEST_CASE(rewrite_avgpool_rank3_dil_test) { // 1D case 1, input is 3D diff --git a/test/verify/test_lrn.cpp b/test/verify/test_lrn.cpp new file mode 100644 index 00000000000..2eae1a202e5 --- /dev/null +++ b/test/verify/test_lrn.cpp @@ -0,0 +1,50 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +template +struct test_lrn : verify_program> +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter( + "x", migraphx::shape{migraphx::shape::float_type, {1, ChannelSize, 28, 28}}); + mm->add_instruction( + migraphx::make_op( + "lrn", {{"alpha", 0.0001}, {"beta", 0.75}, {"bias", 1.0}, {"size", LrnSize}}), + x); + return p; + } +}; + +template struct test_lrn<32, 6>; +template struct test_lrn<32, 5>; +template struct test_lrn<31, 8>; +template struct test_lrn<31, 5>; diff --git a/test/verify/test_relu_lrn.cpp b/test/verify/test_relu_lrn.cpp index feec2ab21c6..d201bd77317 100644 --- a/test/verify/test_relu_lrn.cpp +++ b/test/verify/test_relu_lrn.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal