From 165869d7510236abad875a3b76fa3f2bf4696c05 Mon Sep 17 00:00:00 2001
From: aarushjain29
Date: Wed, 10 Sep 2025 19:24:31 +0000
Subject: [PATCH 01/64] Adding lower_lrn_to_pooling function

---
 src/rewrite_pooling.cpp | 87 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 87 insertions(+)

diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp
index 6100043e802..09d8632b4bf 100644
--- a/src/rewrite_pooling.cpp
+++ b/src/rewrite_pooling.cpp
@@ -29,6 +29,7 @@
 #include
 #include
 #include
+#include
 
 #include

@@ -55,6 +56,86 @@ static void replace_with_reduce(module& m, instruction_ref ins)
     }
 }
 
+static void lower_lrn_to_pooling(module& m, instruction_ref ins)
+{
+    // Extract params from lrn operator
+    auto v = ins->get_operator().to_value();
+    // ONNX-style names: alpha, beta, bias(k), size; optional axis (default 1)
+    float alpha = v.at("alpha").to<float>();
+    float beta  = v.at("beta").to<float>();
+    float k     = v.at("bias").to<float>();
+    int size    = v.at("size").to<int>();
+    int axis    = 1; // LRN default axis
+
+    auto x             = ins->inputs().at(0);
+    const auto& xshape = x->get_shape();
+    auto lens          = xshape.lens(); // e.g., NCHW
+    const int64_t rank = static_cast<int64_t>(lens.size());
+    int64_t caxis      = axis < 0 ? axis + rank : axis;
+    if(rank < 2 || caxis < 0 || caxis >= rank) return; // conservative guard
+    if(size <= 0 || (size % 2) == 0) return; // LRN requires odd > 0
+
+    const int half = size / 2;
+
+    // 1) x^2
+    auto x2 = m.insert_instruction(ins, make_op("mul"), x, x);
+
+    // 2) Move channel axis to the last dim
+    std::vector<int64_t> perm(rank);
+    std::iota(perm.begin(), perm.end(), 0);
+    std::swap(perm[static_cast<std::size_t>(caxis)], perm.back());
+    auto moved      = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), x2);
+    auto moved_lens = moved->get_shape().lens(); // [D0, D1, ..., C]
+
+    // 3) Reshape to (B,1,1,C) so we can pool over "width"=C
+    std::size_t B = 1;
+    for(std::size_t i = 0; i + 1 < moved_lens.size(); ++i) B *= moved_lens[i];
+    const int64_t C = static_cast<int64_t>(moved_lens.back());
+    auto pooled_in = m.insert_instruction(
+        ins,
+        make_op("reshape", {{"dims", std::vector<int64_t>{static_cast<int64_t>(B), 1, 1, C}}}),
+        moved);
+
+    // 4) Average pool with symmetric padding on width (include-pad semantics)
+    //    kernel: (1, size), stride: (1,1), padding: (0, half)
+    auto avg = m.insert_instruction(
+        ins,
+        make_op("pooling",
+                {{"mode", op::pooling_mode::average},
+                 {"lengths", std::vector{1, size}},
+                 {"stride", std::vector{1, 1}},
+                 {"padding", std::vector{0, half}},
+                 {"count_include_pad", true}}),
+        pooled_in);
+
+    // 5) Reshape back to original "moved" shape
+    auto moved_shape_back =
+        std::vector(moved_lens.begin(), moved_lens.end()); // [D0,...,C]
+    auto avg_moved = m.insert_instruction(
+        ins, make_op("reshape", {{"dims", moved_shape_back}}), avg);
+
+    // 6) Transpose back to original layout (inverse perm)
+    auto invp   = invert_permutation(perm);
+    auto avg_ch = m.insert_instruction(ins, make_op("transpose", {{"permutation", invp}}), avg_moved);
+
+    // 7) Build denominator: den = k + alpha * avg_ch
+    auto k_lit     = m.add_literal(k);
+    auto a_lit     = m.add_literal(alpha);
+    auto k_mb      = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), k_lit);
+    auto a_mb      = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), a_lit);
+    auto alpha_avg = m.insert_instruction(ins, make_op("mul"), a_mb, avg_ch);
+    auto den       = m.insert_instruction(ins, make_op("add"), k_mb, alpha_avg);
+
+    // 8) y = x / den^beta
+    auto b_lit = m.add_literal(beta);
+    auto b_mb  = m.insert_instruction(ins,
make_op("multibroadcast", {{"out_lens", lens}}), b_lit); + auto denpow = m.insert_instruction(ins, make_op("pow"), den, b_mb); + auto y = m.insert_instruction(ins, make_op("div"), ins->inputs().front(), denpow); + + // Replace lrn with the new subgraph + m.replace_instruction(ins, y); +} + static void replace_dilations_with_gather_pooling(module& m, instruction_ref ins) { // TODO remove this when MIOpen supports dilated pooling @@ -147,6 +228,12 @@ void rewrite_pooling::apply(module& m) const continue; if(ins->inputs().empty()) continue; + if(ins->name() == "lrn") + { + lower_lrn_to_pooling(m, ins); + continue; + } + auto&& s = ins->inputs().front()->get_shape(); auto&& op = any_cast(ins->get_operator()); bool same_kernel_as_shape = std::equal( From 65d5b7026aa2e5169b13c00631225c9df91debf5 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Wed, 10 Sep 2025 19:41:38 +0000 Subject: [PATCH 02/64] Adding lower_lrn_to_pooling function --- src/rewrite_pooling.cpp | 26 ++++++++------------------ 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index 09d8632b4bf..349a35cacf8 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -58,36 +58,32 @@ static void replace_with_reduce(module& m, instruction_ref ins) static void lower_lrn_to_pooling(module& m, instruction_ref ins) { - // Extract params from lrn operator auto v = ins->get_operator().to_value(); - // ONNX-style names: alpha, beta, bias(k), size; optional axis (default 1) + float alpha = v.at("alpha").to(); float beta = v.at("beta").to(); float k = v.at("bias").to(); int size = v.at("size").to(); - int axis = 1; // LRN default axis + int axis = 1; auto x = ins->inputs().at(0); const auto& xshape = x->get_shape(); auto lens = xshape.lens(); // e.g., NCHW const int64_t rank = static_cast(lens.size()); int64_t caxis = axis < 0 ? 
axis + rank : axis; - if(rank < 2 || caxis < 0 || caxis >= rank) return; // conservative guard - if(size <= 0 || (size % 2) == 0) return; // LRN requires odd > 0 + if(rank < 2 || caxis < 0 || caxis >= rank) return; + if(size <= 0 || (size % 2) == 0) return; const int half = size / 2; - // 1) x^2 auto x2 = m.insert_instruction(ins, make_op("mul"), x, x); - // 2) Move channel axis to the last dim std::vector perm(rank); std::iota(perm.begin(), perm.end(), 0); std::swap(perm[static_cast(caxis)], perm.back()); auto moved = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), x2); - auto moved_lens = moved->get_shape().lens(); // [D0, D1, ..., C] + auto moved_lens = moved->get_shape().lens(); - // 3) Reshape to (B,1,1,C) so we can pool over "width"=C std::size_t B = 1; for(std::size_t i = 0; i + 1 < moved_lens.size(); ++i) B *= moved_lens[i]; const int64_t C = static_cast(moved_lens.back()); @@ -96,8 +92,6 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) make_op("reshape", {{"dims", std::vector{static_cast(B), 1, 1, C}}}), moved); - // 4) Average pool with symmetric padding on width (include-pad semantics) - // kernel: (1, size), stride: (1,1), padding: (0, half) auto avg = m.insert_instruction( ins, make_op("pooling", @@ -108,17 +102,15 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) {"count_include_pad", true}}), pooled_in); - // 5) Reshape back to original "moved" shape auto moved_shape_back = - std::vector(moved_lens.begin(), moved_lens.end()); // [D0,...,C] + std::vector(moved_lens.begin(), moved_lens.end()); auto avg_moved = m.insert_instruction( ins, make_op("reshape", {{"dims", moved_shape_back}}), avg); - // 6) Transpose back to original layout (inverse perm) + auto invp = invert_permutation(perm); auto avg_ch = m.insert_instruction(ins, make_op("transpose", {{"permutation", invp}}), avg_moved); - // 7) Build denominator: den = k + alpha * avg_ch auto k_lit = m.add_literal(k); auto a_lit = m.add_literal(alpha); auto k_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), k_lit); @@ -126,13 +118,11 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) auto alpha_avg = m.insert_instruction(ins, make_op("mul"), a_mb, avg_ch); auto den = m.insert_instruction(ins, make_op("add"), k_mb, alpha_avg); - // 8) y = x / den^beta auto b_lit = m.add_literal(beta); auto b_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), b_lit); auto denpow = m.insert_instruction(ins, make_op("pow"), den, b_mb); auto y = m.insert_instruction(ins, make_op("div"), ins->inputs().front(), denpow); - // Replace lrn with the new subgraph m.replace_instruction(ins, y); } @@ -233,7 +223,7 @@ void rewrite_pooling::apply(module& m) const lower_lrn_to_pooling(m, ins); continue; } - + auto&& s = ins->inputs().front()->get_shape(); auto&& op = any_cast(ins->get_operator()); bool same_kernel_as_shape = std::equal( From 83651ec6aa30d33a901ad835c6acc203900c1a49 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Wed, 10 Sep 2025 19:48:22 +0000 Subject: [PATCH 03/64] Adding invert permutation header --- src/rewrite_pooling.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index 349a35cacf8..b3ae2f1d2ee 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -30,7 +30,7 @@ #include #include #include - +#include #include namespace migraphx { From 75d63708e4198a82bf595684ce334e5ffeac1be5 Mon Sep 17 00:00:00 2001 From: 
aarushjain29 Date: Fri, 19 Sep 2025 20:20:25 +0000 Subject: [PATCH 04/64] changes in apply --- src/rewrite_pooling.cpp | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index b3ae2f1d2ee..0f74f1a84d8 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -29,6 +29,7 @@ #include #include #include + #include #include #include @@ -126,6 +127,8 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) m.replace_instruction(ins, y); } + + static void replace_dilations_with_gather_pooling(module& m, instruction_ref ins) { // TODO remove this when MIOpen supports dilated pooling @@ -214,7 +217,7 @@ void rewrite_pooling::apply(module& m) const { for(auto ins : iterator_for(m)) { - if(ins->name() != "pooling") + /*if(ins->name() != "pooling") continue; if(ins->inputs().empty()) continue; @@ -223,6 +226,20 @@ void rewrite_pooling::apply(module& m) const lower_lrn_to_pooling(m, ins); continue; } + */ + + if(ins->inputs().empty()) + continue; + + if(ins->name() == "lrn") + { + lower_lrn_to_pooling(m, ins); + continue; + } + + if(ins->name() != "pooling") + continue; + auto&& s = ins->inputs().front()->get_shape(); auto&& op = any_cast(ins->get_operator()); From 6d94d79c040b09c6a0bc1886c850037e4cb22ff8 Mon Sep 17 00:00:00 2001 From: Aarushi Jain <142941703+aarushjain29@users.noreply.github.com> Date: Fri, 19 Sep 2025 15:20:59 -0500 Subject: [PATCH 05/64] Update src/rewrite_pooling.cpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/rewrite_pooling.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index 0f74f1a84d8..5bbef851fce 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -221,10 +221,10 @@ void rewrite_pooling::apply(module& m) const continue; if(ins->inputs().empty()) continue; - if(ins->name() == "lrn") - { - lower_lrn_to_pooling(m, ins); - continue; + if(ins->name() == "lrn") + { + lower_lrn_to_pooling(m, ins); + continue; } */ From 0b717ab69c3149d2e9270ca8ce8d4a50bd764c88 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Mon, 22 Sep 2025 22:31:55 +0000 Subject: [PATCH 06/64] test case added --- test/rewrite_pooling_test.cpp | 38 +++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/test/rewrite_pooling_test.cpp b/test/rewrite_pooling_test.cpp index 1ccafed179e..bfd43a97c4d 100644 --- a/test/rewrite_pooling_test.cpp +++ b/test/rewrite_pooling_test.cpp @@ -309,6 +309,44 @@ TEST_CASE(rewrite_pooling_dialtions_test5) test_rewrite(migraphx::op::pooling_mode::max); } +TEST_CASE(lower_lrn_to_pooling_test) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 64, 55, 55}}; + + migraphx::module m; + auto x = m.add_parameter("x", s); + auto lrn = m.add_instruction( + migraphx::make_op("lrn", {{"alpha", 0.0001}, {"beta", 0.75}, {"bias", 1.0}, {"size", 5}}), + x); + m.add_return({lrn}); + + int lrn_count_before = 0; + int pooling_count_before = 0; + for(auto ins : migraphx::iterator_for(m)) { + if(ins->name() == "lrn") lrn_count_before++; + if(ins->name() == "pooling") pooling_count_before++; + } + + opt_pooling(m); + + int lrn_count_after = 0; + int pooling_count_after = 0; + int mul_count = 0; + int div_count = 0; + for(auto ins : migraphx::iterator_for(m)) { + if(ins->name() == "lrn") lrn_count_after++; + if(ins->name() == "pooling") pooling_count_after++; + if(ins->name() == "mul") mul_count++; + if(ins->name() == "div") div_count++; + } 
+ + EXPECT(lrn_count_before == 1); + EXPECT(lrn_count_after == 0); + EXPECT(pooling_count_after > pooling_count_before); + EXPECT(mul_count > 0); + EXPECT(div_count > 0); +} + TEST_CASE(rewrite_avgpool_rank3_dil_test) { // 1D case 1, input is 3D From bec9dac44b3b8aefc4dd81aeb116a04e48a0b700 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Tue, 23 Sep 2025 01:46:33 +0000 Subject: [PATCH 07/64] test case added --- test/rewrite_pooling_test.cpp | 62 +++++++++++++++++++---------------- 1 file changed, 34 insertions(+), 28 deletions(-) diff --git a/test/rewrite_pooling_test.cpp b/test/rewrite_pooling_test.cpp index bfd43a97c4d..b09ed4a1000 100644 --- a/test/rewrite_pooling_test.cpp +++ b/test/rewrite_pooling_test.cpp @@ -34,6 +34,8 @@ #include +#include + static void opt_pooling(migraphx::module& m) { migraphx::rewrite_pooling rp; @@ -313,38 +315,42 @@ TEST_CASE(lower_lrn_to_pooling_test) { migraphx::shape s{migraphx::shape::float_type, {1, 64, 55, 55}}; - migraphx::module m; - auto x = m.add_parameter("x", s); - auto lrn = m.add_instruction( - migraphx::make_op("lrn", {{"alpha", 0.0001}, {"beta", 0.75}, {"bias", 1.0}, {"size", 5}}), - x); - m.add_return({lrn}); + std::vector data(s.elements()); + std::iota(data.begin(), data.end(), 1.0f); + + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto input = mm->add_parameter("x", s); + auto lrn = mm->add_instruction( + migraphx::make_op("lrn", {{"alpha", 0.0001}, {"beta", 0.75}, {"bias", 1.0}, {"size", 5}}), + input); + mm->add_return({lrn}); + p1.compile(migraphx::make_target("ref")); + } - int lrn_count_before = 0; - int pooling_count_before = 0; - for(auto ins : migraphx::iterator_for(m)) { - if(ins->name() == "lrn") lrn_count_before++; - if(ins->name() == "pooling") pooling_count_before++; + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto input = mm->add_parameter("x", s); + auto lrn = mm->add_instruction( + migraphx::make_op("lrn", {{"alpha", 0.0001}, {"beta", 0.75}, {"bias", 1.0}, {"size", 5}}), + input); + mm->add_return({lrn}); + + opt_pooling(*mm); + p2.compile(migraphx::make_target("ref")); } - opt_pooling(m); + migraphx::parameter_map params; + params["x"] = migraphx::argument(s, data.data()); + + auto result1 = p1.eval(params).back(); + auto result2 = p2.eval(params).back(); - int lrn_count_after = 0; - int pooling_count_after = 0; - int mul_count = 0; - int div_count = 0; - for(auto ins : migraphx::iterator_for(m)) { - if(ins->name() == "lrn") lrn_count_after++; - if(ins->name() == "pooling") pooling_count_after++; - if(ins->name() == "mul") mul_count++; - if(ins->name() == "div") div_count++; - } - - EXPECT(lrn_count_before == 1); - EXPECT(lrn_count_after == 0); - EXPECT(pooling_count_after > pooling_count_before); - EXPECT(mul_count > 0); - EXPECT(div_count > 0); + visit_all(result1, result2)([&](auto r1, auto r2) { + EXPECT(migraphx::verify::verify_rms_range(r1, r2)); + }); } TEST_CASE(rewrite_avgpool_rank3_dil_test) From ea3901940e5156e0e38241dcc1445cfec4cadc65 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Tue, 23 Sep 2025 02:13:56 +0000 Subject: [PATCH 08/64] more test case --- test/rewrite_pooling_test.cpp | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/test/rewrite_pooling_test.cpp b/test/rewrite_pooling_test.cpp index b09ed4a1000..ae54759817b 100644 --- a/test/rewrite_pooling_test.cpp +++ b/test/rewrite_pooling_test.cpp @@ -311,7 +311,25 @@ TEST_CASE(rewrite_pooling_dialtions_test5) test_rewrite(migraphx::op::pooling_mode::max); } 
-TEST_CASE(lower_lrn_to_pooling_test) +TEST_CASE(lower_lrn_to_pooling_test1) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 64, 55, 55}}; + + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto lrn = m1.add_instruction( + migraphx::make_op("lrn", {{"alpha", 0.0001}, {"beta", 0.75}, {"bias", 1.0}, {"size", 5}}), + x); + m1.add_return({lrn}); + } + + migraphx::module m2 = m1; + + opt_pooling(m2); +} + +TEST_CASE(lower_lrn_to_pooling_test2) { migraphx::shape s{migraphx::shape::float_type, {1, 64, 55, 55}}; From 69efb32e4ce05c9fd626093cf0d52905805a53a0 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Tue, 23 Sep 2025 02:16:56 +0000 Subject: [PATCH 09/64] remove comment --- src/rewrite_pooling.cpp | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index 5bbef851fce..6788c708b35 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -217,17 +217,6 @@ void rewrite_pooling::apply(module& m) const { for(auto ins : iterator_for(m)) { - /*if(ins->name() != "pooling") - continue; - if(ins->inputs().empty()) - continue; - if(ins->name() == "lrn") - { - lower_lrn_to_pooling(m, ins); - continue; - } - */ - if(ins->inputs().empty()) continue; From fb8b708f9a6f1502e5926bcaca633d22907901b2 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Tue, 23 Sep 2025 02:18:52 +0000 Subject: [PATCH 10/64] remove spaces --- src/rewrite_pooling.cpp | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index 6788c708b35..a8a151790e3 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -217,18 +217,15 @@ void rewrite_pooling::apply(module& m) const { for(auto ins : iterator_for(m)) { - if(ins->inputs().empty()) + if(ins->inputs().empty()) continue; - - if(ins->name() == "lrn") - { - lower_lrn_to_pooling(m, ins); - continue; - } - - if(ins->name() != "pooling") - continue; - + if(ins->name() == "lrn") + { + lower_lrn_to_pooling(m, ins); + continue; + } + if(ins->name() != "pooling") + continue; auto&& s = ins->inputs().front()->get_shape(); auto&& op = any_cast(ins->get_operator()); From 6dae73756ed1c2f9784b7fb8544d07d1ddb7b1bf Mon Sep 17 00:00:00 2001 From: Aarushi Jain <142941703+aarushjain29@users.noreply.github.com> Date: Mon, 22 Sep 2025 21:20:00 -0500 Subject: [PATCH 11/64] Update test/rewrite_pooling_test.cpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- test/rewrite_pooling_test.cpp | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/test/rewrite_pooling_test.cpp b/test/rewrite_pooling_test.cpp index ae54759817b..89bad24e843 100644 --- a/test/rewrite_pooling_test.cpp +++ b/test/rewrite_pooling_test.cpp @@ -311,21 +311,20 @@ TEST_CASE(rewrite_pooling_dialtions_test5) test_rewrite(migraphx::op::pooling_mode::max); } -TEST_CASE(lower_lrn_to_pooling_test1) -{ - migraphx::shape s{migraphx::shape::float_type, {1, 64, 55, 55}}; - - migraphx::module m1; - { - auto x = m1.add_parameter("x", s); - auto lrn = m1.add_instruction( - migraphx::make_op("lrn", {{"alpha", 0.0001}, {"beta", 0.75}, {"bias", 1.0}, {"size", 5}}), - x); - m1.add_return({lrn}); - } - - migraphx::module m2 = m1; - +TEST_CASE(lower_lrn_to_pooling_test1) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 64, 55, 55}}; + + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto lrn = m1.add_instruction( + migraphx::make_op("lrn", {{"alpha", 0.0001}, {"beta", 0.75}, 
{"bias", 1.0}, {"size", 5}}), + x); + m1.add_return({lrn}); + } + + migraphx::module m2 = m1; opt_pooling(m2); } From 09ef691ae919e8748fe683fdd13d3a219d751488 Mon Sep 17 00:00:00 2001 From: Aarushi Jain <142941703+aarushjain29@users.noreply.github.com> Date: Mon, 22 Sep 2025 21:20:12 -0500 Subject: [PATCH 12/64] Update test/rewrite_pooling_test.cpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- test/rewrite_pooling_test.cpp | 78 +++++++++++++++++------------------ 1 file changed, 39 insertions(+), 39 deletions(-) diff --git a/test/rewrite_pooling_test.cpp b/test/rewrite_pooling_test.cpp index 89bad24e843..a3b7ffa1dbb 100644 --- a/test/rewrite_pooling_test.cpp +++ b/test/rewrite_pooling_test.cpp @@ -328,46 +328,46 @@ TEST_CASE(lower_lrn_to_pooling_test1) opt_pooling(m2); } -TEST_CASE(lower_lrn_to_pooling_test2) -{ - migraphx::shape s{migraphx::shape::float_type, {1, 64, 55, 55}}; - - std::vector data(s.elements()); - std::iota(data.begin(), data.end(), 1.0f); - - migraphx::program p1; - { - auto* mm = p1.get_main_module(); - auto input = mm->add_parameter("x", s); - auto lrn = mm->add_instruction( - migraphx::make_op("lrn", {{"alpha", 0.0001}, {"beta", 0.75}, {"bias", 1.0}, {"size", 5}}), - input); - mm->add_return({lrn}); +TEST_CASE(lower_lrn_to_pooling_test2) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 64, 55, 55}}; + + std::vector data(s.elements()); + std::iota(data.begin(), data.end(), 1.0f); + + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto input = mm->add_parameter("x", s); + auto lrn = mm->add_instruction( + migraphx::make_op("lrn", {{"alpha", 0.0001}, {"beta", 0.75}, {"bias", 1.0}, {"size", 5}}), + input); + mm->add_return({lrn}); p1.compile(migraphx::make_target("ref")); - } - - migraphx::program p2; - { - auto* mm = p2.get_main_module(); - auto input = mm->add_parameter("x", s); - auto lrn = mm->add_instruction( - migraphx::make_op("lrn", {{"alpha", 0.0001}, {"beta", 0.75}, {"bias", 1.0}, {"size", 5}}), - input); - mm->add_return({lrn}); - - opt_pooling(*mm); - p2.compile(migraphx::make_target("ref")); - } - - migraphx::parameter_map params; - params["x"] = migraphx::argument(s, data.data()); - - auto result1 = p1.eval(params).back(); - auto result2 = p2.eval(params).back(); - - visit_all(result1, result2)([&](auto r1, auto r2) { - EXPECT(migraphx::verify::verify_rms_range(r1, r2)); - }); + } + + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto input = mm->add_parameter("x", s); + auto lrn = mm->add_instruction( + migraphx::make_op("lrn", {{"alpha", 0.0001}, {"beta", 0.75}, {"bias", 1.0}, {"size", 5}}), + input); + mm->add_return({lrn}); + + opt_pooling(*mm); + p2.compile(migraphx::make_target("ref")); + } + + migraphx::parameter_map params; + params["x"] = migraphx::argument(s, data.data()); + + auto result1 = p1.eval(params).back(); + auto result2 = p2.eval(params).back(); + + visit_all(result1, result2)([&](auto r1, auto r2) { + EXPECT(migraphx::verify::verify_rms_range(r1, r2)); + }); } TEST_CASE(rewrite_avgpool_rank3_dil_test) From 6ac18ee42eef95953a150d09ccbd20e1630ff3b7 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Wed, 24 Sep 2025 02:16:10 +0000 Subject: [PATCH 13/64] tidy errors --- src/rewrite_pooling.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index a8a151790e3..edb11483938 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -72,8 +72,8 @@ static void lower_lrn_to_pooling(module& m, 
instruction_ref ins) auto lens = xshape.lens(); // e.g., NCHW const int64_t rank = static_cast(lens.size()); int64_t caxis = axis < 0 ? axis + rank : axis; - if(rank < 2 || caxis < 0 || caxis >= rank) return; - if(size <= 0 || (size % 2) == 0) return; + if(rank < 2 or caxis < 0 or caxis >= rank) return; + if(size <= 0 or (size % 2) == 0) return; const int half = size / 2; @@ -85,12 +85,12 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) auto moved = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), x2); auto moved_lens = moved->get_shape().lens(); - std::size_t B = 1; - for(std::size_t i = 0; i + 1 < moved_lens.size(); ++i) B *= moved_lens[i]; - const int64_t C = static_cast(moved_lens.back()); + std::size_t b = 1; + for(std::size_t i = 0; i + 1 < moved_lens.size(); ++i) b *= moved_lens[i]; + const int64_t c = static_cast(moved_lens.back()); auto pooled_in = m.insert_instruction( ins, - make_op("reshape", {{"dims", std::vector{static_cast(B), 1, 1, C}}}), + make_op("reshape", {{"dims", std::vector{static_cast(b), 1, 1, c}}}), moved); auto avg = m.insert_instruction( From 4780e5ef41f648fa98b01dc6ef606ef765d36092 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Wed, 24 Sep 2025 02:19:11 +0000 Subject: [PATCH 14/64] cpp errors --- src/rewrite_pooling.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index edb11483938..c4a4edc5b7d 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -71,8 +71,8 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) const auto& xshape = x->get_shape(); auto lens = xshape.lens(); // e.g., NCHW const int64_t rank = static_cast(lens.size()); - int64_t caxis = axis < 0 ? axis + rank : axis; - if(rank < 2 or caxis < 0 or caxis >= rank) return; + int64_t caxis = axis < 0 ? 
axis + rank : axis; + if(rank < 2 or caxis >= rank) return; if(size <= 0 or (size % 2) == 0) return; const int half = size / 2; From 3858f86c58b338eb51fdfe9473037e6f0665c46f Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Wed, 24 Sep 2025 02:29:15 +0000 Subject: [PATCH 15/64] tidy errors --- src/rewrite_pooling.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index c4a4edc5b7d..8c688ddd793 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -65,7 +65,7 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) float beta = v.at("beta").to(); float k = v.at("bias").to(); int size = v.at("size").to(); - int axis = 1; + const unsigned int axis = 1; auto x = ins->inputs().at(0); const auto& xshape = x->get_shape(); @@ -75,8 +75,6 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) if(rank < 2 or caxis >= rank) return; if(size <= 0 or (size % 2) == 0) return; - const int half = size / 2; - auto x2 = m.insert_instruction(ins, make_op("mul"), x, x); std::vector perm(rank); @@ -99,7 +97,7 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) {{"mode", op::pooling_mode::average}, {"lengths", std::vector{1, size}}, {"stride", std::vector{1, 1}}, - {"padding", std::vector{0, half}}, + {"padding", std::vector{0, size/2}}, {"count_include_pad", true}}), pooled_in); From 2c1b76f9c26f9389b1b793542efbb6341eda8ff8 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Wed, 24 Sep 2025 02:45:35 +0000 Subject: [PATCH 16/64] tidy errors --- src/rewrite_pooling.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index 8c688ddd793..07de6f11583 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -83,8 +83,7 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) auto moved = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), x2); auto moved_lens = moved->get_shape().lens(); - std::size_t b = 1; - for(std::size_t i = 0; i + 1 < moved_lens.size(); ++i) b *= moved_lens[i]; + auto b = std::accumulate(moved_lens.begin(), moved_lens.end() - 1, 1, std::multiplies()); const int64_t c = static_cast(moved_lens.back()); auto pooled_in = m.insert_instruction( ins, From 07a273076df502ac70bd10847247ccad94bccdf5 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Wed, 24 Sep 2025 14:41:35 +0000 Subject: [PATCH 17/64] test case for comparing two models --- test/rewrite_pooling_test.cpp | 64 +++++++++++++++++++++++++++++++++-- 1 file changed, 61 insertions(+), 3 deletions(-) diff --git a/test/rewrite_pooling_test.cpp b/test/rewrite_pooling_test.cpp index a3b7ffa1dbb..7946d6a45a7 100644 --- a/test/rewrite_pooling_test.cpp +++ b/test/rewrite_pooling_test.cpp @@ -311,6 +311,64 @@ TEST_CASE(rewrite_pooling_dialtions_test5) test_rewrite(migraphx::op::pooling_mode::max); } +TEST_CASE(lower_lrn_to_pooling) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 64, 55, 55}}; + + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto lrn = m1.add_instruction( + migraphx::make_op("lrn", {{"alpha", 0.0001}, {"beta", 0.75}, {"bias", 1.0}, {"size", 5}}), + x); + m1.add_return({lrn}); + } + opt_pooling(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + + auto x_squared = m2.add_instruction(migraphx::make_op("mul"), x, x); + auto transpose1 = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 3, 2, 1}}}), x_squared); + auto reshape1 = 
m2.add_instruction( + migraphx::make_op("reshape", {{"dims", {3025, 1, 1, 64}}}), transpose1); + auto pooling = m2.add_instruction( + migraphx::make_op("pooling", { + {"mode", migraphx::op::pooling_mode::average}, + {"lengths", std::vector{1, 5}}, + {"stride", std::vector{1, 1}}, + {"padding", std::vector{0, 2}}, + {"count_include_pad", true} + }), reshape1); + auto reshape2 = m2.add_instruction( + migraphx::make_op("reshape", {{"dims", {1, 55, 55, 64}}}), pooling); + auto transpose2 = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 3, 2, 1}}}), reshape2); + + auto beta_lit = m2.add_literal(migraphx::literal{migraphx::shape::float_type, {0.75}}); + auto alpha_lit = m2.add_literal(migraphx::literal{migraphx::shape::float_type, {0.0001}}); + auto bias_lit = m2.add_literal(migraphx::literal{migraphx::shape::float_type, {1.0}}); + + auto bias_mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 55, 55}}}), bias_lit); + auto alpha_mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 55, 55}}}), alpha_lit); + auto beta_mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 55, 55}}}), beta_lit); + + auto alpha_avg = m2.add_instruction(migraphx::make_op("mul"), alpha_mb, transpose2); + auto denominator = m2.add_instruction(migraphx::make_op("add"), bias_mb, alpha_avg); + auto powered = m2.add_instruction(migraphx::make_op("pow"), denominator, beta_mb); + auto result = m2.add_instruction(migraphx::make_op("div"), x, powered); + + m2.add_return({result}); + } + + EXPECT(m1.sort() == m2.sort()); +} +/* TEST_CASE(lower_lrn_to_pooling_test1) { migraphx::shape s{migraphx::shape::float_type, {1, 64, 55, 55}}; @@ -335,7 +393,7 @@ TEST_CASE(lower_lrn_to_pooling_test2) std::vector data(s.elements()); std::iota(data.begin(), data.end(), 1.0f); - migraphx::program p1; + migraphx::module p1; { auto* mm = p1.get_main_module(); auto input = mm->add_parameter("x", s); @@ -346,7 +404,7 @@ TEST_CASE(lower_lrn_to_pooling_test2) p1.compile(migraphx::make_target("ref")); } - migraphx::program p2; + migraphx::module p2; { auto* mm = p2.get_main_module(); auto input = mm->add_parameter("x", s); @@ -369,7 +427,7 @@ TEST_CASE(lower_lrn_to_pooling_test2) EXPECT(migraphx::verify::verify_rms_range(r1, r2)); }); } - +*/ TEST_CASE(rewrite_avgpool_rank3_dil_test) { // 1D case 1, input is 3D From c2dc92b3d17e36b90400f8212a5f60d591893e98 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Wed, 24 Sep 2025 14:56:10 +0000 Subject: [PATCH 18/64] test case for comparing two models --- test/rewrite_pooling_test.cpp | 58 ----------------------------------- 1 file changed, 58 deletions(-) diff --git a/test/rewrite_pooling_test.cpp b/test/rewrite_pooling_test.cpp index 7946d6a45a7..1025dcb7db1 100644 --- a/test/rewrite_pooling_test.cpp +++ b/test/rewrite_pooling_test.cpp @@ -368,66 +368,8 @@ TEST_CASE(lower_lrn_to_pooling) EXPECT(m1.sort() == m2.sort()); } -/* -TEST_CASE(lower_lrn_to_pooling_test1) -{ - migraphx::shape s{migraphx::shape::float_type, {1, 64, 55, 55}}; - migraphx::module m1; - { - auto x = m1.add_parameter("x", s); - auto lrn = m1.add_instruction( - migraphx::make_op("lrn", {{"alpha", 0.0001}, {"beta", 0.75}, {"bias", 1.0}, {"size", 5}}), - x); - m1.add_return({lrn}); - } - - migraphx::module m2 = m1; - opt_pooling(m2); -} -TEST_CASE(lower_lrn_to_pooling_test2) -{ - migraphx::shape s{migraphx::shape::float_type, {1, 64, 55, 55}}; - - std::vector data(s.elements()); - std::iota(data.begin(), data.end(), 
1.0f); - - migraphx::module p1; - { - auto* mm = p1.get_main_module(); - auto input = mm->add_parameter("x", s); - auto lrn = mm->add_instruction( - migraphx::make_op("lrn", {{"alpha", 0.0001}, {"beta", 0.75}, {"bias", 1.0}, {"size", 5}}), - input); - mm->add_return({lrn}); - p1.compile(migraphx::make_target("ref")); - } - - migraphx::module p2; - { - auto* mm = p2.get_main_module(); - auto input = mm->add_parameter("x", s); - auto lrn = mm->add_instruction( - migraphx::make_op("lrn", {{"alpha", 0.0001}, {"beta", 0.75}, {"bias", 1.0}, {"size", 5}}), - input); - mm->add_return({lrn}); - - opt_pooling(*mm); - p2.compile(migraphx::make_target("ref")); - } - - migraphx::parameter_map params; - params["x"] = migraphx::argument(s, data.data()); - - auto result1 = p1.eval(params).back(); - auto result2 = p2.eval(params).back(); - - visit_all(result1, result2)([&](auto r1, auto r2) { - EXPECT(migraphx::verify::verify_rms_range(r1, r2)); - }); -} -*/ TEST_CASE(rewrite_avgpool_rank3_dil_test) { // 1D case 1, input is 3D From ea7159d2fe8a99e32d7739b761d68c91d6f8ed5b Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Wed, 24 Sep 2025 15:03:24 +0000 Subject: [PATCH 19/64] accepting both even and odd sizes --- src/rewrite_pooling.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index 07de6f11583..9c3525dc840 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -73,7 +73,7 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) const int64_t rank = static_cast(lens.size()); int64_t caxis = axis < 0 ? axis + rank : axis; if(rank < 2 or caxis >= rank) return; - if(size <= 0 or (size % 2) == 0) return; + if(size <= 0) return; auto x2 = m.insert_instruction(ins, make_op("mul"), x, x); From cc383cdf238d416f2f98b5eac93ebceb4a9c8b17 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Wed, 24 Sep 2025 15:22:30 +0000 Subject: [PATCH 20/64] remove tidy errors --- src/rewrite_pooling.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index 9c3525dc840..39895761ec9 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -69,10 +69,9 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) auto x = ins->inputs().at(0); const auto& xshape = x->get_shape(); - auto lens = xshape.lens(); // e.g., NCHW + auto lens = xshape.lens(); const int64_t rank = static_cast(lens.size()); - int64_t caxis = axis < 0 ? 
axis + rank : axis; - if(rank < 2 or caxis >= rank) return; + if(rank < 2 or axis >= rank) return; if(size <= 0) return; auto x2 = m.insert_instruction(ins, make_op("mul"), x, x); From 6d17986b717c74bf5c1edc581c10a6ee15aab841 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Wed, 24 Sep 2025 15:23:41 +0000 Subject: [PATCH 21/64] remove tidy errors --- src/rewrite_pooling.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index 39895761ec9..f572c192346 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -78,7 +78,7 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) std::vector perm(rank); std::iota(perm.begin(), perm.end(), 0); - std::swap(perm[static_cast(caxis)], perm.back()); + std::swap(perm[static_cast(axis)], perm.back()); auto moved = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), x2); auto moved_lens = moved->get_shape().lens(); From 506e2e6302bc11015cd3612d17c665c07d57e921 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Wed, 24 Sep 2025 15:38:24 +0000 Subject: [PATCH 22/64] tidy error --- src/rewrite_pooling.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index f572c192346..ad119608a80 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -71,7 +71,7 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) const auto& xshape = x->get_shape(); auto lens = xshape.lens(); const int64_t rank = static_cast(lens.size()); - if(rank < 2 or axis >= rank) return; + if(rank < 2) return; if(size <= 0) return; auto x2 = m.insert_instruction(ins, make_op("mul"), x, x); From 6e5597d491bea74ee1db2eb4fb12e36471b5ccf4 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Wed, 24 Sep 2025 15:52:50 +0000 Subject: [PATCH 23/64] formatting --- src/rewrite_pooling.cpp | 67 +++++++++++++++++++++-------------------- 1 file changed, 34 insertions(+), 33 deletions(-) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index ad119608a80..5d288a3ed27 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -61,57 +61,58 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) { auto v = ins->get_operator().to_value(); - float alpha = v.at("alpha").to(); - float beta = v.at("beta").to(); - float k = v.at("bias").to(); - int size = v.at("size").to(); + float alpha = v.at("alpha").to(); + float beta = v.at("beta").to(); + float k = v.at("bias").to(); + int size = v.at("size").to(); const unsigned int axis = 1; - auto x = ins->inputs().at(0); + auto x = ins->inputs().at(0); const auto& xshape = x->get_shape(); - auto lens = xshape.lens(); + auto lens = xshape.lens(); const int64_t rank = static_cast(lens.size()); - if(rank < 2) return; - if(size <= 0) return; + if(rank < 2) + return; + if(size <= 0) + return; auto x2 = m.insert_instruction(ins, make_op("mul"), x, x); std::vector perm(rank); std::iota(perm.begin(), perm.end(), 0); std::swap(perm[static_cast(axis)], perm.back()); - auto moved = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), x2); + auto moved = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), x2); auto moved_lens = moved->get_shape().lens(); - - auto b = std::accumulate(moved_lens.begin(), moved_lens.end() - 1, 1, std::multiplies()); + auto b = + std::accumulate(moved_lens.begin(), moved_lens.end() - 1, 1, std::multiplies()); const int64_t c = static_cast(moved_lens.back()); - auto pooled_in = 
m.insert_instruction( - ins, - make_op("reshape", {{"dims", std::vector{static_cast(b), 1, 1, c}}}), - moved); - - auto avg = m.insert_instruction( + auto pooled_in = m.insert_instruction( ins, - make_op("pooling", - {{"mode", op::pooling_mode::average}, - {"lengths", std::vector{1, size}}, - {"stride", std::vector{1, 1}}, - {"padding", std::vector{0, size/2}}, - {"count_include_pad", true}}), - pooled_in); + make_op("reshape", {{"dims", std::vector{static_cast(b), 1, 1, c}}}), + moved); + + auto avg = m.insert_instruction(ins, + make_op("pooling", + {{"mode", op::pooling_mode::average}, + {"lengths", std::vector{1, size}}, + {"stride", std::vector{1, 1}}, + {"padding", std::vector{0, size / 2}}, + {"count_include_pad", true}}), + pooled_in); - auto moved_shape_back = - std::vector(moved_lens.begin(), moved_lens.end()); - auto avg_moved = m.insert_instruction( - ins, make_op("reshape", {{"dims", moved_shape_back}}), avg); + auto moved_shape_back = std::vector(moved_lens.begin(), moved_lens.end()); + auto avg_moved = + m.insert_instruction(ins, make_op("reshape", {{"dims", moved_shape_back}}), avg); auto invp = invert_permutation(perm); - auto avg_ch = m.insert_instruction(ins, make_op("transpose", {{"permutation", invp}}), avg_moved); + auto avg_ch = + m.insert_instruction(ins, make_op("transpose", {{"permutation", invp}}), avg_moved); - auto k_lit = m.add_literal(k); - auto a_lit = m.add_literal(alpha); - auto k_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), k_lit); - auto a_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), a_lit); + auto k_lit = m.add_literal(k); + auto a_lit = m.add_literal(alpha); + auto k_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), k_lit); + auto a_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), a_lit); auto alpha_avg = m.insert_instruction(ins, make_op("mul"), a_mb, avg_ch); auto den = m.insert_instruction(ins, make_op("add"), k_mb, alpha_avg); From 393e613ee4b40c06b83b7098cc8c31bdccc24ce5 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Wed, 24 Sep 2025 15:59:56 +0000 Subject: [PATCH 24/64] formatting --- test/rewrite_pooling_test.cpp | 65 +++++++++++++++++------------------ 1 file changed, 32 insertions(+), 33 deletions(-) diff --git a/test/rewrite_pooling_test.cpp b/test/rewrite_pooling_test.cpp index 1025dcb7db1..6e269d4f78a 100644 --- a/test/rewrite_pooling_test.cpp +++ b/test/rewrite_pooling_test.cpp @@ -314,62 +314,61 @@ TEST_CASE(rewrite_pooling_dialtions_test5) TEST_CASE(lower_lrn_to_pooling) { migraphx::shape s{migraphx::shape::float_type, {1, 64, 55, 55}}; - migraphx::module m1; { - auto x = m1.add_parameter("x", s); + auto x = m1.add_parameter("x", s); auto lrn = m1.add_instruction( - migraphx::make_op("lrn", {{"alpha", 0.0001}, {"beta", 0.75}, {"bias", 1.0}, {"size", 5}}), + migraphx::make_op("lrn", + {{"alpha", 0.0001}, {"beta", 0.75}, {"bias", 1.0}, {"size", 5}}), x); m1.add_return({lrn}); } opt_pooling(m1); - + migraphx::module m2; - { + { auto x = m2.add_parameter("x", s); - auto x_squared = m2.add_instruction(migraphx::make_op("mul"), x, x); + auto x_squared = m2.add_instruction(migraphx::make_op("mul"), x, x); auto transpose1 = m2.add_instruction( migraphx::make_op("transpose", {{"permutation", {0, 3, 2, 1}}}), x_squared); auto reshape1 = m2.add_instruction( migraphx::make_op("reshape", {{"dims", {3025, 1, 1, 64}}}), transpose1); - auto pooling = m2.add_instruction( - migraphx::make_op("pooling", { - {"mode", 
migraphx::op::pooling_mode::average}, - {"lengths", std::vector{1, 5}}, - {"stride", std::vector{1, 1}}, - {"padding", std::vector{0, 2}}, - {"count_include_pad", true} - }), reshape1); - auto reshape2 = m2.add_instruction( - migraphx::make_op("reshape", {{"dims", {1, 55, 55, 64}}}), pooling); + auto pooling = + m2.add_instruction(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::average}, + {"lengths", std::vector{1, 5}}, + {"stride", std::vector{1, 1}}, + {"padding", std::vector{0, 2}}, + {"count_include_pad", true}}), + reshape1); + auto reshape2 = + m2.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 55, 55, 64}}}), pooling); auto transpose2 = m2.add_instruction( migraphx::make_op("transpose", {{"permutation", {0, 3, 2, 1}}}), reshape2); - - auto beta_lit = m2.add_literal(migraphx::literal{migraphx::shape::float_type, {0.75}}); - auto alpha_lit = m2.add_literal(migraphx::literal{migraphx::shape::float_type, {0.0001}}); - auto bias_lit = m2.add_literal(migraphx::literal{migraphx::shape::float_type, {1.0}}); - - auto bias_mb = m2.add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 55, 55}}}), bias_lit); - auto alpha_mb = m2.add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 55, 55}}}), alpha_lit); - auto beta_mb = m2.add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 55, 55}}}), beta_lit); - - auto alpha_avg = m2.add_instruction(migraphx::make_op("mul"), alpha_mb, transpose2); + + auto beta_lit = m2.add_literal(migraphx::literal{migraphx::shape::float_type, {0.75}}); + auto alpha_lit = m2.add_literal(migraphx::literal{migraphx::shape::float_type, {0.0001}}); + auto bias_lit = m2.add_literal(migraphx::literal{migraphx::shape::float_type, {1.0}}); + + auto bias_mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 55, 55}}}), bias_lit); + auto alpha_mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 55, 55}}}), alpha_lit); + auto beta_mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 55, 55}}}), beta_lit); + + auto alpha_avg = m2.add_instruction(migraphx::make_op("mul"), alpha_mb, transpose2); auto denominator = m2.add_instruction(migraphx::make_op("add"), bias_mb, alpha_avg); - auto powered = m2.add_instruction(migraphx::make_op("pow"), denominator, beta_mb); - auto result = m2.add_instruction(migraphx::make_op("div"), x, powered); + auto powered = m2.add_instruction(migraphx::make_op("pow"), denominator, beta_mb); + auto result = m2.add_instruction(migraphx::make_op("div"), x, powered); m2.add_return({result}); } - + EXPECT(m1.sort() == m2.sort()); } - TEST_CASE(rewrite_avgpool_rank3_dil_test) { // 1D case 1, input is 3D From 38a39c2ae497a0c0df052b1dde57e8f6e9a586e6 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Wed, 24 Sep 2025 17:08:43 +0000 Subject: [PATCH 25/64] license --- src/rewrite_pooling.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index 5d288a3ed27..2904103e91c 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal From 2b6cc85a083a0e39983d457db6a026636ab9885b Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Wed, 24 Sep 2025 17:13:17 +0000 Subject: [PATCH 26/64] formatting --- test/rewrite_pooling_test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/rewrite_pooling_test.cpp b/test/rewrite_pooling_test.cpp index 6e269d4f78a..9f21b3ab794 100644 --- a/test/rewrite_pooling_test.cpp +++ b/test/rewrite_pooling_test.cpp @@ -361,7 +361,7 @@ TEST_CASE(lower_lrn_to_pooling) auto alpha_avg = m2.add_instruction(migraphx::make_op("mul"), alpha_mb, transpose2); auto denominator = m2.add_instruction(migraphx::make_op("add"), bias_mb, alpha_avg); auto powered = m2.add_instruction(migraphx::make_op("pow"), denominator, beta_mb); - auto result = m2.add_instruction(migraphx::make_op("div"), x, powered); + auto result = m2.add_instruction(migraphx::make_op("div"), x, powered); m2.add_return({result}); } From d4d5452f87cc5492b6e8b24cbfc71b3532acc730 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Wed, 24 Sep 2025 17:19:32 +0000 Subject: [PATCH 27/64] reverting back to even size --- src/rewrite_pooling.cpp | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index 2904103e91c..938d27a9176 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -73,7 +73,7 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) const int64_t rank = static_cast(lens.size()); if(rank < 2) return; - if(size <= 0) + if(size <= 0 or (size % 2) != 0) return; auto x2 = m.insert_instruction(ins, make_op("mul"), x, x); @@ -88,8 +88,8 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) const int64_t c = static_cast(moved_lens.back()); auto pooled_in = m.insert_instruction( ins, - make_op("reshape", {{"dims", std::vector{static_cast(b), 1, 1, c}}}), - moved); + make_op("reshape", {{"dims", std::vector{static_cast(b), 1, 1, c}}}), + moved); auto avg = m.insert_instruction(ins, make_op("pooling", @@ -103,8 +103,6 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) auto moved_shape_back = std::vector(moved_lens.begin(), moved_lens.end()); auto avg_moved = m.insert_instruction(ins, make_op("reshape", {{"dims", moved_shape_back}}), avg); - - auto invp = invert_permutation(perm); auto avg_ch = m.insert_instruction(ins, make_op("transpose", {{"permutation", invp}}), avg_moved); From 5a783c9eebb8d2f991892615cf92e5919792d99c Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Fri, 26 Sep 2025 17:06:48 +0000 Subject: [PATCH 28/64] logic for both evenn and odd sizes --- src/rewrite_pooling.cpp | 110 ++++++++++++++++++++++++++- test/rewrite_pooling_test.cpp | 139 +++++++++++++++++++++++++++++++++- 2 files changed, 247 insertions(+), 2 deletions(-) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index 938d27a9176..382497dcfad 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -57,7 +57,7 @@ static void replace_with_reduce(module& m, instruction_ref ins) } } -static void lower_lrn_to_pooling(module& m, instruction_ref ins) +/*static void lower_lrn_to_pooling(module& m, instruction_ref ins) { auto v = ins->get_operator().to_value(); @@ -122,6 +122,114 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) m.replace_instruction(ins, y); } +*/ + +static void lower_lrn_to_pooling(module& m, instruction_ref 
ins)
+{
+    auto v = ins->get_operator().to_value();
+
+    float alpha = v.at("alpha").to<float>();
+    float beta  = v.at("beta").to<float>();
+    float k     = v.at("bias").to<float>();
+    int size    = v.at("size").to<int>();
+    const unsigned int axis = 1;
+
+    auto x             = ins->inputs().at(0);
+    const auto& xshape = x->get_shape();
+    auto lens          = xshape.lens();
+    const int64_t rank = static_cast<int64_t>(lens.size());
+
+    // Early validation
+    if(rank < 2) return;
+    if(size <= 0) return;
+
+    auto x2 = m.insert_instruction(ins, make_op("mul"), x, x);
+
+    instruction_ref avg;
+
+    if(size % 2 == 0) {
+        // Even size: use direct channel-wise averaging with slice/reduce
+        std::vector<instruction_ref> channel_windows;
+
+        for(int c = 0; c < static_cast<int>(lens[1]); ++c) {
+            int start = std::max(0, c - size/2);
+            int end = std::min(static_cast<int>(lens[1]), c + size/2 + 1);
+
+            // Slice the channel range [start, end)
+            auto slice_start = m.insert_instruction(ins,
+                make_op("slice", {{"axes", {1}}, {"starts", {start}}, {"ends", {end}}}), x2);
+
+            // Reduce mean along channel axis to get single value per spatial location
+            auto local_mean = m.insert_instruction(ins,
+                make_op("reduce_mean", {{"axes", {1}}}), slice_start);
+
+            channel_windows.push_back(local_mean);
+        }
+
+        // Concatenate all channel results back together
+        avg = m.insert_instruction(ins, make_op("concat", {{"axis", 1}}), channel_windows);
+    } else {
+        // Odd size: use transpose/reshape/pooling approach with explicit padding
+
+        // Create permutation for transpose: move channel dimension to last
+        std::vector<int64_t> perm(rank);
+        std::iota(perm.begin(), perm.end(), 0);
+        std::swap(perm[axis], perm.back());
+
+        auto transpose1 = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), x2);
+
+        // Calculate moved lens for reshape
+        auto moved_lens = lens;
+        std::swap(moved_lens[axis], moved_lens.back());
+
+        // Flatten spatial dimensions for 1D pooling
+        auto batch_size = std::accumulate(moved_lens.begin(), moved_lens.end() - 1, 1, std::multiplies());
+        std::vector<int64_t> reshape_dims = {static_cast<int64_t>(batch_size), 1, 1, static_cast<int64_t>(moved_lens.back())};
+        auto reshape1 = m.insert_instruction(ins, make_op("reshape", {{"dims", reshape_dims}}), transpose1);
+
+        // Use explicit pad operation for asymmetric padding
+        int64_t left_pad = (size - 1) / 2;  // For size=5: left_pad = 2
+        int64_t right_pad = size / 2;       // For size=5: right_pad = 2
+
+        // Create padding vector: {batch, spatial1, spatial2, channel, batch, spatial1, spatial2, channel}
+        std::vector<int64_t> pad_values = {0, 0, 0, left_pad, 0, 0, 0, right_pad};
+
+        auto padded = m.insert_instruction(ins,
+            make_op("pad", {{"pads", pad_values}, {"value", 0.0f}}), reshape1);
+
+        // Now use pooling with zero padding since we've already padded
+        auto pooled = m.insert_instruction(ins,
+            make_op("pooling", {
+                {"mode", op::pooling_mode::average},
+                {"lengths", std::vector{1, size}},
+                {"stride", std::vector{1, 1}},
+                {"padding", std::vector{0, 0}},
+                {"count_include_pad", true}
+            }), padded);
+
+        // Reshape back to moved dimensions
+        auto reshape2 = m.insert_instruction(ins, make_op("reshape", {{"dims", moved_lens}}), pooled);
+
+        // Transpose back to original layout
+        avg = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), reshape2);
+    }
+
+    // Complete LRN formula: output = input / (bias + alpha * avg)^beta
+    auto k_lit = m.add_literal(k);
+    auto a_lit = m.add_literal(alpha);
+    auto b_lit = m.add_literal(beta);
+
+    auto k_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), k_lit);
+    auto a_mb = m.insert_instruction(ins,
make_op("multibroadcast", {{"out_lens", lens}}), a_lit); + auto b_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), b_lit); + + auto alpha_avg = m.insert_instruction(ins, make_op("mul"), a_mb, avg); + auto den = m.insert_instruction(ins, make_op("add"), k_mb, alpha_avg); + auto denpow = m.insert_instruction(ins, make_op("pow"), den, b_mb); + auto y = m.insert_instruction(ins, make_op("div"), x, denpow); + + m.replace_instruction(ins, y); +} static void replace_dilations_with_gather_pooling(module& m, instruction_ref ins) diff --git a/test/rewrite_pooling_test.cpp b/test/rewrite_pooling_test.cpp index 9f21b3ab794..f429e25a604 100644 --- a/test/rewrite_pooling_test.cpp +++ b/test/rewrite_pooling_test.cpp @@ -311,7 +311,7 @@ TEST_CASE(rewrite_pooling_dialtions_test5) test_rewrite(migraphx::op::pooling_mode::max); } -TEST_CASE(lower_lrn_to_pooling) +/*TEST_CASE(lower_lrn_to_pooling) { migraphx::shape s{migraphx::shape::float_type, {1, 64, 55, 55}}; migraphx::module m1; @@ -366,9 +366,146 @@ TEST_CASE(lower_lrn_to_pooling) m2.add_return({result}); } + EXPECT(m1.sort() == m2.sort()); +}*/ + + +TEST_CASE(lower_lrn_to_pooling_odd_size) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 64, 55, 55}}; + + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto lrn = m1.add_instruction( + migraphx::make_op("lrn", {{"alpha", 0.0001}, {"beta", 0.75}, {"bias", 1.0}, {"size", 5}}), + x); + m1.add_return({lrn}); + } + opt_pooling(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + + auto x_squared = m2.add_instruction(migraphx::make_op("mul"), x, x); + auto transpose1 = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 3, 2, 1}}}), x_squared); + auto reshape1 = m2.add_instruction( + migraphx::make_op("reshape", {{"dims", {3025, 1, 1, 64}}}), transpose1); + + // Explicit padding for odd size (symmetric padding: 2 on each side) + auto pad_op = m2.add_instruction( + migraphx::make_op("pad", {{"mode", 0}, {"pads", {0, 0, 0, 2, 0, 0, 0, 2}}, {"value", 0.0f}}), + reshape1); + + // Pooling with zero padding since explicit padding is applied + auto pooling = m2.add_instruction( + migraphx::make_op("pooling", { + {"mode", migraphx::op::pooling_mode::average}, + {"lengths", std::vector{1, 5}}, + {"stride", std::vector{1, 1}}, + {"padding", std::vector{0, 0}}, + {"count_include_pad", true} + }), pad_op); + + auto reshape2 = m2.add_instruction( + migraphx::make_op("reshape", {{"dims", {1, 55, 55, 64}}}), pooling); + + // Fixed transpose permutation to match actual implementation + auto transpose2 = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 3, 2, 1}}}), reshape2); + + // LRN formula completion + auto bias_lit = m2.add_literal(1.0f); + auto alpha_lit = m2.add_literal(0.0001f); + auto beta_lit = m2.add_literal(0.75f); + + auto bias_mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 55, 55}}}), bias_lit); + auto alpha_mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 55, 55}}}), alpha_lit); + auto beta_mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 55, 55}}}), beta_lit); + + auto alpha_avg = m2.add_instruction(migraphx::make_op("mul"), alpha_mb, transpose2); + auto denominator = m2.add_instruction(migraphx::make_op("add"), bias_mb, alpha_avg); + auto powered = m2.add_instruction(migraphx::make_op("pow"), denominator, beta_mb); + auto result = m2.add_instruction(migraphx::make_op("div"), x, 
powered); + + m2.add_return({result}); + } + EXPECT(m1.sort() == m2.sort()); } + +TEST_CASE(lower_lrn_to_pooling_even_size) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 64, 55, 55}}; + + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto lrn = m1.add_instruction( + migraphx::make_op("lrn", {{"alpha", 0.0001}, {"beta", 0.75}, {"bias", 1.0}, {"size", 4}}), + x); + m1.add_return({lrn}); + } + opt_pooling(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + + // Square the input + auto x2 = m2.add_instruction(migraphx::make_op("mul"), x, x); + + // Create channel windows manually (matching your implementation) + std::vector channel_windows; + + for(int c = 0; c < 64; ++c) { + int start = std::max(0, c - 2); // size/2 = 4/2 = 2 + int end = std::min(64, c + 2 + 1); // c + size/2 + 1 + + auto slice_start = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {start}}, {"ends", {end}}}), x2); + + auto local_mean = m2.add_instruction( + migraphx::make_op("reduce_mean", {{"axes", {1}}}), slice_start); + + + channel_windows.push_back(local_mean); + } + + // Concatenate all results + auto avg = m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), channel_windows); + + // Complete LRN formula + auto k_lit = m2.add_literal(1.0f); // bias + auto a_lit = m2.add_literal(0.0001f); // alpha + auto b_lit = m2.add_literal(0.75f); // beta + + auto k_mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 55, 55}}}), k_lit); + auto a_mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 55, 55}}}), a_lit); + auto b_mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 55, 55}}}), b_lit); + + auto alpha_avg = m2.add_instruction(migraphx::make_op("mul"), a_mb, avg); + auto den = m2.add_instruction(migraphx::make_op("add"), k_mb, alpha_avg); + auto denpow = m2.add_instruction(migraphx::make_op("pow"), den, b_mb); + auto result = m2.add_instruction(migraphx::make_op("div"), x, denpow); + + m2.add_return({result}); + } + + EXPECT(m1.sort() == m2.sort()); +} + + + TEST_CASE(rewrite_avgpool_rank3_dil_test) { // 1D case 1, input is 3D From 02a7ce74fff2a7ea19ddf66b3fa6137c7305b47c Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Tue, 30 Sep 2025 17:39:57 +0000 Subject: [PATCH 29/64] calculate padding --- src/rewrite_pooling.cpp | 216 +++++++++++++++++++++------------------- 1 file changed, 112 insertions(+), 104 deletions(-) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index 382497dcfad..c9caa78f405 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -124,111 +124,119 @@ static void replace_with_reduce(module& m, instruction_ref ins) */ -static void lower_lrn_to_pooling(module& m, instruction_ref ins) -{ - auto v = ins->get_operator().to_value(); - - float alpha = v.at("alpha").to(); - float beta = v.at("beta").to(); - float k = v.at("bias").to(); - int size = v.at("size").to(); - const unsigned int axis = 1; - - auto x = ins->inputs().at(0); - const auto& xshape = x->get_shape(); - auto lens = xshape.lens(); - const int64_t rank = static_cast(lens.size()); - - // Early validation - if(rank < 2) return; - if(size <= 0) return; - - auto x2 = m.insert_instruction(ins, make_op("mul"), x, x); - - instruction_ref avg; - - if(size % 2 == 0) { - // Even size: use direct channel-wise averaging with slice/reduce - std::vector channel_windows; - - for(int c = 0; c < static_cast(lens[1]); ++c) { - int start = std::max(0, c - 
size/2); - int end = std::min(static_cast(lens[1]), c + size/2 + 1); - - // Slice the channel range [start, end) - auto slice_start = m.insert_instruction(ins, - make_op("slice", {{"axes", {1}}, {"starts", {start}}, {"ends", {end}}}), x2); - - // Reduce mean along channel axis to get single value per spatial location - auto local_mean = m.insert_instruction(ins, - make_op("reduce_mean", {{"axes", {1}}}), slice_start); - - channel_windows.push_back(local_mean); - } - - // Concatenate all channel results back together - avg = m.insert_instruction(ins, make_op("concat", {{"axis", 1}}), channel_windows); - } else { - // Odd size: use transpose/reshape/pooling approach with explicit padding - // Create permutation for transpose: move channel dimension to last - std::vector perm(rank); - std::iota(perm.begin(), perm.end(), 0); - std::swap(perm[axis], perm.back()); - - auto transpose1 = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), x2); - - // Calculate moved lens for reshape - auto moved_lens = lens; - std::swap(moved_lens[axis], moved_lens.back()); - - // Flatten spatial dimensions for 1D pooling - auto batch_size = std::accumulate(moved_lens.begin(), moved_lens.end() - 1, 1, std::multiplies()); - std::vector reshape_dims = {static_cast(batch_size), 1, 1, static_cast(moved_lens.back())}; - auto reshape1 = m.insert_instruction(ins, make_op("reshape", {{"dims", reshape_dims}}), transpose1); - - // Use explicit pad operation for asymmetric padding - int64_t left_pad = (size - 1) / 2; // For size=5: left_pad = 2 - int64_t right_pad = size / 2; // For size=5: right_pad = 2 - - // Create padding vector: {batch, spatial1, spatial2, channel, batch, spatial1, spatial2, channel} - std::vector pad_values = {0, 0, 0, left_pad, 0, 0, 0, right_pad}; - - auto padded = m.insert_instruction(ins, - make_op("pad", {{"pads", pad_values}, {"value", 0.0f}}), reshape1); - - // Now use pooling with zero padding since we've already padded - auto pooled = m.insert_instruction(ins, - make_op("pooling", { - {"mode", op::pooling_mode::average}, - {"lengths", std::vector{1, size}}, - {"stride", std::vector{1, 1}}, - {"padding", std::vector{0, 0}}, - {"count_include_pad", true} - }), padded); - - // Reshape back to moved dimensions - auto reshape2 = m.insert_instruction(ins, make_op("reshape", {{"dims", moved_lens}}), pooled); - - // Transpose back to original layout - avg = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), reshape2); - } - - // Complete LRN formula: output = input / (bias + alpha * avg)^beta - auto k_lit = m.add_literal(k); - auto a_lit = m.add_literal(alpha); - auto b_lit = m.add_literal(beta); - - auto k_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), k_lit); - auto a_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), a_lit); - auto b_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), b_lit); - - auto alpha_avg = m.insert_instruction(ins, make_op("mul"), a_mb, avg); - auto den = m.insert_instruction(ins, make_op("add"), k_mb, alpha_avg); - auto denpow = m.insert_instruction(ins, make_op("pow"), den, b_mb); - auto y = m.insert_instruction(ins, make_op("div"), x, denpow); - - m.replace_instruction(ins, y); +//#include // For calculate_padding function + +static void lower_lrn_to_pooling(module& m, instruction_ref ins) +{ + auto v = ins->get_operator().to_value(); + + float alpha = v.at("alpha").to(); + float beta = v.at("beta").to(); + float k = v.at("bias").to(); + int 
size = v.at("size").to(); + + auto x = ins->inputs().at(0); + const auto& xshape = x->get_shape(); + auto lens = xshape.lens(); + + // Early validation + if(lens.size() < 2) { + return; + } + if(size <= 0) { + return; + } + + // Support both even and odd sizes now + // Previously only even sizes were supported + + // Step 1: Square the input + auto x2 = m.insert_instruction(ins, make_op("mul"), x, x); + + // Step 2: Transpose NCHW -> NHWC for channel-wise pooling + std::vector perm = {0, 2, 3, 1}; // NCHW -> NHWC + auto transpose1 = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), x2); + + auto transposed_shape = transpose1->get_shape(); + auto transposed_lens = transposed_shape.lens(); + + // Step 3: Calculate padding using calculate_padding function + int64_t channel_dim = transposed_lens[3]; // Channel dimension in NHWC + std::vector calculated_pads; + calculated_pads.resize(2, 0); // Pre-size for 1D padding + + calculate_padding(0, calculated_pads, channel_dim, 1, 1, size, true); + + // Step 4: Apply direct pooling with calculated padding + instruction_ref avg; + try { + avg = m.insert_instruction(ins, + make_op("pooling", { + {"mode", op::pooling_mode::average}, + {"lengths", std::vector{1, size}}, // 2D: [height, width] + {"stride", std::vector{1, 1}}, // 2D: [height, width] + {"padding", std::vector{0, calculated_pads[0], 0, calculated_pads[1]}}, // 4 elements + {"dilations", std::vector{1, 1}}, // 2D: [height, width] + {"count_include_pad", true} + }), transpose1); + + auto avg_shape = avg->get_shape(); + auto avg_lens = avg_shape.lens(); + + // Validate dimensions are preserved + if(avg_lens[3] != transposed_lens[3]) { + return; + } + + } catch(const std::exception& e) { + return; + } + + // Step 5: Transpose back NHWC -> NCHW + std::vector inv_perm = {0, 3, 1, 2}; // NHWC -> NCHW + auto transpose2 = m.insert_instruction(ins, make_op("transpose", {{"permutation", inv_perm}}), avg); + + auto final_shape = transpose2->get_shape(); + auto final_lens = final_shape.lens(); + + // Check if final shape matches input shape + bool shape_matches = true; + if(final_lens.size() != lens.size()) { + shape_matches = false; + } else { + for(size_t i = 0; i < lens.size(); ++i) { + if(final_lens[i] != lens[i]) { + shape_matches = false; + break; + } + } + } + + if(!shape_matches) { + return; + } + + // Step 6: Complete LRN formula + auto k_lit = m.add_literal(k); + auto a_lit = m.add_literal(alpha); + auto b_lit = m.add_literal(beta); + + auto k_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), k_lit); + auto a_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), a_lit); + auto b_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), b_lit); + + try { + auto alpha_avg = m.insert_instruction(ins, make_op("mul"), a_mb, transpose2); + auto den = m.insert_instruction(ins, make_op("add"), k_mb, alpha_avg); + auto denpow = m.insert_instruction(ins, make_op("pow"), den, b_mb); + auto y = m.insert_instruction(ins, make_op("div"), x, denpow); + + m.replace_instruction(ins, y); + + } catch(const std::exception& e) { + return; + } } From 84d2105ff0dc9041951ac0a6126035938cc2bcc1 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Tue, 30 Sep 2025 17:57:32 +0000 Subject: [PATCH 30/64] formatting --- src/rewrite_pooling.cpp | 100 ++++------------------------------------ 1 file changed, 9 insertions(+), 91 deletions(-) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index c9caa78f405..556950730e8 
100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -57,76 +57,6 @@ static void replace_with_reduce(module& m, instruction_ref ins) } } -/*static void lower_lrn_to_pooling(module& m, instruction_ref ins) -{ - auto v = ins->get_operator().to_value(); - - float alpha = v.at("alpha").to(); - float beta = v.at("beta").to(); - float k = v.at("bias").to(); - int size = v.at("size").to(); - const unsigned int axis = 1; - - auto x = ins->inputs().at(0); - const auto& xshape = x->get_shape(); - auto lens = xshape.lens(); - const int64_t rank = static_cast(lens.size()); - if(rank < 2) - return; - if(size <= 0 or (size % 2) != 0) - return; - - auto x2 = m.insert_instruction(ins, make_op("mul"), x, x); - - std::vector perm(rank); - std::iota(perm.begin(), perm.end(), 0); - std::swap(perm[static_cast(axis)], perm.back()); - auto moved = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), x2); - auto moved_lens = moved->get_shape().lens(); - auto b = - std::accumulate(moved_lens.begin(), moved_lens.end() - 1, 1, std::multiplies()); - const int64_t c = static_cast(moved_lens.back()); - auto pooled_in = m.insert_instruction( - ins, - make_op("reshape", {{"dims", std::vector{static_cast(b), 1, 1, c}}}), - moved); - - auto avg = m.insert_instruction(ins, - make_op("pooling", - {{"mode", op::pooling_mode::average}, - {"lengths", std::vector{1, size}}, - {"stride", std::vector{1, 1}}, - {"padding", std::vector{0, size / 2}}, - {"count_include_pad", true}}), - pooled_in); - - auto moved_shape_back = std::vector(moved_lens.begin(), moved_lens.end()); - auto avg_moved = - m.insert_instruction(ins, make_op("reshape", {{"dims", moved_shape_back}}), avg); - auto invp = invert_permutation(perm); - auto avg_ch = - m.insert_instruction(ins, make_op("transpose", {{"permutation", invp}}), avg_moved); - - auto k_lit = m.add_literal(k); - auto a_lit = m.add_literal(alpha); - auto k_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), k_lit); - auto a_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), a_lit); - auto alpha_avg = m.insert_instruction(ins, make_op("mul"), a_mb, avg_ch); - auto den = m.insert_instruction(ins, make_op("add"), k_mb, alpha_avg); - - auto b_lit = m.add_literal(beta); - auto b_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), b_lit); - auto denpow = m.insert_instruction(ins, make_op("pow"), den, b_mb); - auto y = m.insert_instruction(ins, make_op("div"), ins->inputs().front(), denpow); - - m.replace_instruction(ins, y); -} - -*/ - - -//#include // For calculate_padding function - static void lower_lrn_to_pooling(module& m, instruction_ref ins) { auto v = ins->get_operator().to_value(); @@ -140,7 +70,6 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) const auto& xshape = x->get_shape(); auto lens = xshape.lens(); - // Early validation if(lens.size() < 2) { return; } @@ -148,43 +77,35 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) return; } - // Support both even and odd sizes now - // Previously only even sizes were supported - - // Step 1: Square the input auto x2 = m.insert_instruction(ins, make_op("mul"), x, x); - // Step 2: Transpose NCHW -> NHWC for channel-wise pooling - std::vector perm = {0, 2, 3, 1}; // NCHW -> NHWC + std::vector perm = {0, 2, 3, 1}; auto transpose1 = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), x2); auto transposed_shape = transpose1->get_shape(); auto transposed_lens = 
transposed_shape.lens(); - // Step 3: Calculate padding using calculate_padding function - int64_t channel_dim = transposed_lens[3]; // Channel dimension in NHWC + int64_t channel_dim = transposed_lens[3]; std::vector calculated_pads; - calculated_pads.resize(2, 0); // Pre-size for 1D padding + calculated_pads.resize(2, 0); calculate_padding(0, calculated_pads, channel_dim, 1, 1, size, true); - // Step 4: Apply direct pooling with calculated padding instruction_ref avg; try { avg = m.insert_instruction(ins, make_op("pooling", { {"mode", op::pooling_mode::average}, - {"lengths", std::vector{1, size}}, // 2D: [height, width] - {"stride", std::vector{1, 1}}, // 2D: [height, width] - {"padding", std::vector{0, calculated_pads[0], 0, calculated_pads[1]}}, // 4 elements - {"dilations", std::vector{1, 1}}, // 2D: [height, width] - {"count_include_pad", true} - }), transpose1); + {"lengths", std::vector{1, size}}, + {"stride", std::vector{1, 1}}, + {"padding", std::vector{0, calculated_pads[0], 0, calculated_pads[1]}}, + {"dilations", std::vector{1, 1}}, + {"count_include_pad", true} + }), transpose1); auto avg_shape = avg->get_shape(); auto avg_lens = avg_shape.lens(); - // Validate dimensions are preserved if(avg_lens[3] != transposed_lens[3]) { return; } @@ -193,14 +114,12 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) return; } - // Step 5: Transpose back NHWC -> NCHW std::vector inv_perm = {0, 3, 1, 2}; // NHWC -> NCHW auto transpose2 = m.insert_instruction(ins, make_op("transpose", {{"permutation", inv_perm}}), avg); auto final_shape = transpose2->get_shape(); auto final_lens = final_shape.lens(); - // Check if final shape matches input shape bool shape_matches = true; if(final_lens.size() != lens.size()) { shape_matches = false; @@ -217,7 +136,6 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) return; } - // Step 6: Complete LRN formula auto k_lit = m.add_literal(k); auto a_lit = m.add_literal(alpha); auto b_lit = m.add_literal(beta); From f4df6244ccef0b90059347c59a61a13f91333cc7 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Tue, 30 Sep 2025 18:02:47 +0000 Subject: [PATCH 31/64] formatting --- src/rewrite_pooling.cpp | 100 ---------------------------------------- 1 file changed, 100 deletions(-) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index 556950730e8..daef8ae1e6d 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -57,106 +57,6 @@ static void replace_with_reduce(module& m, instruction_ref ins) } } -static void lower_lrn_to_pooling(module& m, instruction_ref ins) -{ - auto v = ins->get_operator().to_value(); - - float alpha = v.at("alpha").to(); - float beta = v.at("beta").to(); - float k = v.at("bias").to(); - int size = v.at("size").to(); - - auto x = ins->inputs().at(0); - const auto& xshape = x->get_shape(); - auto lens = xshape.lens(); - - if(lens.size() < 2) { - return; - } - if(size <= 0) { - return; - } - - auto x2 = m.insert_instruction(ins, make_op("mul"), x, x); - - std::vector perm = {0, 2, 3, 1}; - auto transpose1 = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), x2); - - auto transposed_shape = transpose1->get_shape(); - auto transposed_lens = transposed_shape.lens(); - - int64_t channel_dim = transposed_lens[3]; - std::vector calculated_pads; - calculated_pads.resize(2, 0); - - calculate_padding(0, calculated_pads, channel_dim, 1, 1, size, true); - - instruction_ref avg; - try { - avg = m.insert_instruction(ins, - make_op("pooling", { - {"mode", 
op::pooling_mode::average}, - {"lengths", std::vector{1, size}}, - {"stride", std::vector{1, 1}}, - {"padding", std::vector{0, calculated_pads[0], 0, calculated_pads[1]}}, - {"dilations", std::vector{1, 1}}, - {"count_include_pad", true} - }), transpose1); - - auto avg_shape = avg->get_shape(); - auto avg_lens = avg_shape.lens(); - - if(avg_lens[3] != transposed_lens[3]) { - return; - } - - } catch(const std::exception& e) { - return; - } - - std::vector inv_perm = {0, 3, 1, 2}; // NHWC -> NCHW - auto transpose2 = m.insert_instruction(ins, make_op("transpose", {{"permutation", inv_perm}}), avg); - - auto final_shape = transpose2->get_shape(); - auto final_lens = final_shape.lens(); - - bool shape_matches = true; - if(final_lens.size() != lens.size()) { - shape_matches = false; - } else { - for(size_t i = 0; i < lens.size(); ++i) { - if(final_lens[i] != lens[i]) { - shape_matches = false; - break; - } - } - } - - if(!shape_matches) { - return; - } - - auto k_lit = m.add_literal(k); - auto a_lit = m.add_literal(alpha); - auto b_lit = m.add_literal(beta); - - auto k_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), k_lit); - auto a_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), a_lit); - auto b_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), b_lit); - - try { - auto alpha_avg = m.insert_instruction(ins, make_op("mul"), a_mb, transpose2); - auto den = m.insert_instruction(ins, make_op("add"), k_mb, alpha_avg); - auto denpow = m.insert_instruction(ins, make_op("pow"), den, b_mb); - auto y = m.insert_instruction(ins, make_op("div"), x, denpow); - - m.replace_instruction(ins, y); - - } catch(const std::exception& e) { - return; - } -} - static void replace_dilations_with_gather_pooling(module& m, instruction_ref ins) { From 89fc424587b1b90960661e6391c14dec7935596a Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Tue, 30 Sep 2025 18:03:26 +0000 Subject: [PATCH 32/64] formatting --- src/rewrite_pooling.cpp | 100 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index daef8ae1e6d..6e27dd23f71 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -57,6 +57,106 @@ static void replace_with_reduce(module& m, instruction_ref ins) } } +static void lower_lrn_to_pooling(module& m, instruction_ref ins) +{ + auto v = ins->get_operator().to_value(); + + float alpha = v.at("alpha").to(); + float beta = v.at("beta").to(); + float k = v.at("bias").to(); + int size = v.at("size").to(); + + auto x = ins->inputs().at(0); + const auto& xshape = x->get_shape(); + auto lens = xshape.lens(); + + if(lens.size() < 2) { + return; + } + if(size <= 0) { + return; + } + + auto x2 = m.insert_instruction(ins, make_op("mul"), x, x); + + std::vector perm = {0, 2, 3, 1}; + auto transpose1 = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), x2); + + auto transposed_shape = transpose1->get_shape(); + auto transposed_lens = transposed_shape.lens(); + + int64_t channel_dim = transposed_lens[3]; + std::vector calculated_pads; + calculated_pads.resize(2, 0); + + calculate_padding(0, calculated_pads, channel_dim, 1, 1, size, true); + + instruction_ref avg; + try { + avg = m.insert_instruction(ins, + make_op("pooling", { + {"mode", op::pooling_mode::average}, + {"lengths", std::vector{1, size}}, + {"stride", std::vector{1, 1}}, + {"padding", std::vector{0, calculated_pads[0], 0, calculated_pads[1]}}, + {"dilations", 
std::vector{1, 1}}, + {"count_include_pad", true} + }), transpose1); + + auto avg_shape = avg->get_shape(); + auto avg_lens = avg_shape.lens(); + + if(avg_lens[3] != transposed_lens[3]) { + return; + } + + } catch(const std::exception& e) { + return; + } + + std::vector inv_perm = {0, 3, 1, 2}; + auto transpose2 = m.insert_instruction(ins, make_op("transpose", {{"permutation", inv_perm}}), avg); + + auto final_shape = transpose2->get_shape(); + auto final_lens = final_shape.lens(); + + + bool shape_matches = true; + if(final_lens.size() != lens.size()) { + shape_matches = false; + } else { + for(size_t i = 0; i < lens.size(); ++i) { + if(final_lens[i] != lens[i]) { + shape_matches = false; + break; + } + } + } + + if(!shape_matches) { + return; + } + + auto k_lit = m.add_literal(k); + auto a_lit = m.add_literal(alpha); + auto b_lit = m.add_literal(beta); + + auto k_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), k_lit); + auto a_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), a_lit); + auto b_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), b_lit); + + try { + auto alpha_avg = m.insert_instruction(ins, make_op("mul"), a_mb, transpose2); + auto den = m.insert_instruction(ins, make_op("add"), k_mb, alpha_avg); + auto denpow = m.insert_instruction(ins, make_op("pow"), den, b_mb); + auto y = m.insert_instruction(ins, make_op("div"), x, denpow); + + m.replace_instruction(ins, y); + + } catch(const std::exception& e) { + return; + } +} static void replace_dilations_with_gather_pooling(module& m, instruction_ref ins) { From e489de0a25e2dff79faf8324dc755ddf175e5c2c Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Tue, 30 Sep 2025 18:36:07 +0000 Subject: [PATCH 33/64] test case added --- test/rewrite_pooling_test.cpp | 61 +++++++++++++++++++++++++++++++++-- 1 file changed, 58 insertions(+), 3 deletions(-) diff --git a/test/rewrite_pooling_test.cpp b/test/rewrite_pooling_test.cpp index f429e25a604..85a793d5b3b 100644 --- a/test/rewrite_pooling_test.cpp +++ b/test/rewrite_pooling_test.cpp @@ -369,7 +369,7 @@ TEST_CASE(rewrite_pooling_dialtions_test5) EXPECT(m1.sort() == m2.sort()); }*/ - +/* TEST_CASE(lower_lrn_to_pooling_odd_size) { migraphx::shape s{migraphx::shape::float_type, {1, 64, 55, 55}}; @@ -503,8 +503,63 @@ TEST_CASE(lower_lrn_to_pooling_even_size) EXPECT(m1.sort() == m2.sort()); } - - +*/ + +TEST_CASE(test_lower_lrn_to_pooling_transformation) +{ + migraphx::module m1; + migraphx::shape input_shape{migraphx::shape::float_type, {1, 64, 55, 55}}; + auto input1 = m1.add_parameter("x", input_shape); + auto lrn1 = m1.add_instruction( + migraphx::make_op("lrn", {{"alpha", 0.0001f}, {"beta", 0.75f}, {"bias", 1.0f}, {"size", 4}}), + input1); + m1.add_return({lrn1}); + + migraphx::module m2; + auto input2 = m2.add_parameter("x", input_shape); + + auto x2 = m2.add_instruction(migraphx::make_op("mul"), input2, input2); + + auto transpose1 = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", std::vector{0, 2, 3, 1}}}), x2); + + std::vector expected_pads = {1, 2}; // This would be the result of calculate_padding + + auto avg = m2.add_instruction( + migraphx::make_op("pooling", { + {"mode", migraphx::op::pooling_mode::average}, + {"lengths", std::vector{1, 4}}, + {"stride", std::vector{1, 1}}, + {"padding", std::vector{0, expected_pads[0], 0, expected_pads[1]}}, + {"dilations", std::vector{1, 1}}, + {"count_include_pad", true} + }), transpose1); + + auto transpose2 = m2.add_instruction( + 
migraphx::make_op("transpose", {{"permutation", std::vector{0, 3, 1, 2}}}), avg); + + auto k_lit = m2.add_literal(1.0f); + auto a_lit = m2.add_literal(0.0001f); + auto b_lit = m2.add_literal(0.75f); + + auto k_mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_shape.lens()}}), k_lit); + auto a_mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_shape.lens()}}), a_lit); + auto b_mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_shape.lens()}}), b_lit); + + auto alpha_avg = m2.add_instruction(migraphx::make_op("mul"), a_mb, transpose2); + auto den = m2.add_instruction(migraphx::make_op("add"), k_mb, alpha_avg); + auto denpow = m2.add_instruction(migraphx::make_op("pow"), den, b_mb); + auto y = m2.add_instruction(migraphx::make_op("div"), input2, denpow); + + m2.add_return({y}); + + opt_pooling(m1); + + EXPECT(m1 == m2); +} TEST_CASE(rewrite_avgpool_rank3_dil_test) { From b0335f293b21ce23cbb8703955e1b24577b4d957 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Tue, 30 Sep 2025 18:54:33 +0000 Subject: [PATCH 34/64] verify test case --- test/verify/test_relu_lrn.cpp | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/test/verify/test_relu_lrn.cpp b/test/verify/test_relu_lrn.cpp index feec2ab21c6..de655f06ccc 100644 --- a/test/verify/test_relu_lrn.cpp +++ b/test/verify/test_relu_lrn.cpp @@ -42,3 +42,19 @@ struct test_relu_lrn : verify_program return p; } }; + + +struct test_lrn_to_pooling : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 32, 28, 28}}); + mm->add_instruction( + migraphx::make_op("lrn", + {{"alpha", 0.0001}, {"beta", 0.75}, {"bias", 1.0}, {"size", 5}}), + x); + return p; + } +}; From 77ca91d26982d300c398b4fe16ae2b75ea3805d0 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Tue, 30 Sep 2025 18:58:17 +0000 Subject: [PATCH 35/64] formatting --- test/rewrite_pooling_test.cpp | 197 +--------------------------------- 1 file changed, 2 insertions(+), 195 deletions(-) diff --git a/test/rewrite_pooling_test.cpp b/test/rewrite_pooling_test.cpp index 85a793d5b3b..7aa662355a9 100644 --- a/test/rewrite_pooling_test.cpp +++ b/test/rewrite_pooling_test.cpp @@ -311,201 +311,8 @@ TEST_CASE(rewrite_pooling_dialtions_test5) test_rewrite(migraphx::op::pooling_mode::max); } -/*TEST_CASE(lower_lrn_to_pooling) -{ - migraphx::shape s{migraphx::shape::float_type, {1, 64, 55, 55}}; - migraphx::module m1; - { - auto x = m1.add_parameter("x", s); - auto lrn = m1.add_instruction( - migraphx::make_op("lrn", - {{"alpha", 0.0001}, {"beta", 0.75}, {"bias", 1.0}, {"size", 5}}), - x); - m1.add_return({lrn}); - } - opt_pooling(m1); - - migraphx::module m2; - { - auto x = m2.add_parameter("x", s); - - auto x_squared = m2.add_instruction(migraphx::make_op("mul"), x, x); - auto transpose1 = m2.add_instruction( - migraphx::make_op("transpose", {{"permutation", {0, 3, 2, 1}}}), x_squared); - auto reshape1 = m2.add_instruction( - migraphx::make_op("reshape", {{"dims", {3025, 1, 1, 64}}}), transpose1); - auto pooling = - m2.add_instruction(migraphx::make_op("pooling", - {{"mode", migraphx::op::pooling_mode::average}, - {"lengths", std::vector{1, 5}}, - {"stride", std::vector{1, 1}}, - {"padding", std::vector{0, 2}}, - {"count_include_pad", true}}), - reshape1); - auto reshape2 = - m2.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 55, 55, 64}}}), 
pooling); - auto transpose2 = m2.add_instruction( - migraphx::make_op("transpose", {{"permutation", {0, 3, 2, 1}}}), reshape2); - - auto beta_lit = m2.add_literal(migraphx::literal{migraphx::shape::float_type, {0.75}}); - auto alpha_lit = m2.add_literal(migraphx::literal{migraphx::shape::float_type, {0.0001}}); - auto bias_lit = m2.add_literal(migraphx::literal{migraphx::shape::float_type, {1.0}}); - - auto bias_mb = m2.add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 55, 55}}}), bias_lit); - auto alpha_mb = m2.add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 55, 55}}}), alpha_lit); - auto beta_mb = m2.add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 55, 55}}}), beta_lit); - - auto alpha_avg = m2.add_instruction(migraphx::make_op("mul"), alpha_mb, transpose2); - auto denominator = m2.add_instruction(migraphx::make_op("add"), bias_mb, alpha_avg); - auto powered = m2.add_instruction(migraphx::make_op("pow"), denominator, beta_mb); - auto result = m2.add_instruction(migraphx::make_op("div"), x, powered); - - m2.add_return({result}); - } - - EXPECT(m1.sort() == m2.sort()); -}*/ - -/* -TEST_CASE(lower_lrn_to_pooling_odd_size) -{ - migraphx::shape s{migraphx::shape::float_type, {1, 64, 55, 55}}; - - migraphx::module m1; - { - auto x = m1.add_parameter("x", s); - auto lrn = m1.add_instruction( - migraphx::make_op("lrn", {{"alpha", 0.0001}, {"beta", 0.75}, {"bias", 1.0}, {"size", 5}}), - x); - m1.add_return({lrn}); - } - opt_pooling(m1); - - migraphx::module m2; - { - auto x = m2.add_parameter("x", s); - - auto x_squared = m2.add_instruction(migraphx::make_op("mul"), x, x); - auto transpose1 = m2.add_instruction( - migraphx::make_op("transpose", {{"permutation", {0, 3, 2, 1}}}), x_squared); - auto reshape1 = m2.add_instruction( - migraphx::make_op("reshape", {{"dims", {3025, 1, 1, 64}}}), transpose1); - - // Explicit padding for odd size (symmetric padding: 2 on each side) - auto pad_op = m2.add_instruction( - migraphx::make_op("pad", {{"mode", 0}, {"pads", {0, 0, 0, 2, 0, 0, 0, 2}}, {"value", 0.0f}}), - reshape1); - - // Pooling with zero padding since explicit padding is applied - auto pooling = m2.add_instruction( - migraphx::make_op("pooling", { - {"mode", migraphx::op::pooling_mode::average}, - {"lengths", std::vector{1, 5}}, - {"stride", std::vector{1, 1}}, - {"padding", std::vector{0, 0}}, - {"count_include_pad", true} - }), pad_op); - - auto reshape2 = m2.add_instruction( - migraphx::make_op("reshape", {{"dims", {1, 55, 55, 64}}}), pooling); - - // Fixed transpose permutation to match actual implementation - auto transpose2 = m2.add_instruction( - migraphx::make_op("transpose", {{"permutation", {0, 3, 2, 1}}}), reshape2); - - // LRN formula completion - auto bias_lit = m2.add_literal(1.0f); - auto alpha_lit = m2.add_literal(0.0001f); - auto beta_lit = m2.add_literal(0.75f); - - auto bias_mb = m2.add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 55, 55}}}), bias_lit); - auto alpha_mb = m2.add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 55, 55}}}), alpha_lit); - auto beta_mb = m2.add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 55, 55}}}), beta_lit); - - auto alpha_avg = m2.add_instruction(migraphx::make_op("mul"), alpha_mb, transpose2); - auto denominator = m2.add_instruction(migraphx::make_op("add"), bias_mb, alpha_avg); - auto powered = m2.add_instruction(migraphx::make_op("pow"), denominator, beta_mb); - auto result = 
m2.add_instruction(migraphx::make_op("div"), x, powered); - - m2.add_return({result}); - } - - EXPECT(m1.sort() == m2.sort()); -} - - -TEST_CASE(lower_lrn_to_pooling_even_size) -{ - migraphx::shape s{migraphx::shape::float_type, {1, 64, 55, 55}}; - - migraphx::module m1; - { - auto x = m1.add_parameter("x", s); - auto lrn = m1.add_instruction( - migraphx::make_op("lrn", {{"alpha", 0.0001}, {"beta", 0.75}, {"bias", 1.0}, {"size", 4}}), - x); - m1.add_return({lrn}); - } - opt_pooling(m1); - - migraphx::module m2; - { - auto x = m2.add_parameter("x", s); - - // Square the input - auto x2 = m2.add_instruction(migraphx::make_op("mul"), x, x); - - // Create channel windows manually (matching your implementation) - std::vector channel_windows; - - for(int c = 0; c < 64; ++c) { - int start = std::max(0, c - 2); // size/2 = 4/2 = 2 - int end = std::min(64, c + 2 + 1); // c + size/2 + 1 - - auto slice_start = m2.add_instruction( - migraphx::make_op("slice", {{"axes", {1}}, {"starts", {start}}, {"ends", {end}}}), x2); - - auto local_mean = m2.add_instruction( - migraphx::make_op("reduce_mean", {{"axes", {1}}}), slice_start); - - - channel_windows.push_back(local_mean); - } - - // Concatenate all results - auto avg = m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), channel_windows); - - // Complete LRN formula - auto k_lit = m2.add_literal(1.0f); // bias - auto a_lit = m2.add_literal(0.0001f); // alpha - auto b_lit = m2.add_literal(0.75f); // beta - - auto k_mb = m2.add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 55, 55}}}), k_lit); - auto a_mb = m2.add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 55, 55}}}), a_lit); - auto b_mb = m2.add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 55, 55}}}), b_lit); - - auto alpha_avg = m2.add_instruction(migraphx::make_op("mul"), a_mb, avg); - auto den = m2.add_instruction(migraphx::make_op("add"), k_mb, alpha_avg); - auto denpow = m2.add_instruction(migraphx::make_op("pow"), den, b_mb); - auto result = m2.add_instruction(migraphx::make_op("div"), x, denpow); - - m2.add_return({result}); - } - - EXPECT(m1.sort() == m2.sort()); -} -*/ -TEST_CASE(test_lower_lrn_to_pooling_transformation) +TEST_CASE(test_lower_lrn_to_pooling) { migraphx::module m1; migraphx::shape input_shape{migraphx::shape::float_type, {1, 64, 55, 55}}; @@ -523,7 +330,7 @@ TEST_CASE(test_lower_lrn_to_pooling_transformation) auto transpose1 = m2.add_instruction( migraphx::make_op("transpose", {{"permutation", std::vector{0, 2, 3, 1}}}), x2); - std::vector expected_pads = {1, 2}; // This would be the result of calculate_padding + std::vector expected_pads = {1, 2}; auto avg = m2.add_instruction( migraphx::make_op("pooling", { From e8c2547347ae526a65efd7ee82fc1cfd58750666 Mon Sep 17 00:00:00 2001 From: Aarushi Jain <142941703+aarushjain29@users.noreply.github.com> Date: Tue, 30 Sep 2025 14:46:30 -0500 Subject: [PATCH 36/64] Update test/rewrite_pooling_test.cpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- test/rewrite_pooling_test.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/test/rewrite_pooling_test.cpp b/test/rewrite_pooling_test.cpp index 7aa662355a9..71192b3b0f4 100644 --- a/test/rewrite_pooling_test.cpp +++ b/test/rewrite_pooling_test.cpp @@ -330,12 +330,16 @@ TEST_CASE(test_lower_lrn_to_pooling) auto transpose1 = m2.add_instruction( migraphx::make_op("transpose", {{"permutation", std::vector{0, 2, 3, 1}}}), x2); - std::vector expected_pads 
= {1, 2}; + // Calculate padding based on LRN size parameter (size = 4) + int64_t lrn_size = 4; + int64_t pad_left = (lrn_size - 1) / 2; // 1 + int64_t pad_right = lrn_size - 1 - pad_left; // 2 + std::vector expected_pads = {pad_left, pad_right}; auto avg = m2.add_instruction( migraphx::make_op("pooling", { {"mode", migraphx::op::pooling_mode::average}, - {"lengths", std::vector{1, 4}}, + {"lengths", std::vector{1, lrn_size}}, {"stride", std::vector{1, 1}}, {"padding", std::vector{0, expected_pads[0], 0, expected_pads[1]}}, {"dilations", std::vector{1, 1}}, From 32821aa0ce93ae1dad4073c7f61833806cea4370 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Tue, 30 Sep 2025 20:14:44 +0000 Subject: [PATCH 37/64] licensing --- test/verify/test_relu_lrn.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/verify/test_relu_lrn.cpp b/test/verify/test_relu_lrn.cpp index de655f06ccc..6a297082c5a 100644 --- a/test/verify/test_relu_lrn.cpp +++ b/test/verify/test_relu_lrn.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal From 50b429f47c5c352a93b4df017e166b1e3b5b6716 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Thu, 2 Oct 2025 13:50:35 +0000 Subject: [PATCH 38/64] combine line 89 and 90 --- src/rewrite_pooling.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index 6e27dd23f71..17f67ce7187 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -86,9 +86,7 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) auto transposed_lens = transposed_shape.lens(); int64_t channel_dim = transposed_lens[3]; - std::vector calculated_pads; - calculated_pads.resize(2, 0); - + std::vector calculated_pads(2); calculate_padding(0, calculated_pads, channel_dim, 1, 1, size, true); instruction_ref avg; From 5e51bc1582b58e26a394a1ba60309998523a8624 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Thu, 2 Oct 2025 13:56:59 +0000 Subject: [PATCH 39/64] compiler warning unused param --- src/rewrite_pooling.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index 17f67ce7187..dd88dc50231 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -90,8 +90,9 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) calculate_padding(0, calculated_pads, channel_dim, 1, 1, size, true); instruction_ref avg; - try { - avg = m.insert_instruction(ins, + + try{ + avg = m.insert_instruction(ins, make_op("pooling", { {"mode", op::pooling_mode::average}, {"lengths", std::vector{1, size}}, @@ -108,9 +109,9 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) return; } - } catch(const std::exception& e) { - return; - } + } catch(const std::exception&) { + return; + } std::vector inv_perm = {0, 3, 1, 2}; auto transpose2 = m.insert_instruction(ins, make_op("transpose", {{"permutation", inv_perm}}), avg); @@ -151,7 +152,7 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) m.replace_instruction(ins, y); - } catch(const std::exception& e) { + } catch(const std::exception&) { return; } } From e0355c57b799d85909cd9d568973eea71dac3ef3 Mon Sep 17 00:00:00 2001 From: aarushjain29 
Date: Thu, 2 Oct 2025 14:23:05 +0000 Subject: [PATCH 40/64] remove transposed lens --- src/rewrite_pooling.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index dd88dc50231..7f40da2a161 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -85,7 +85,7 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) auto transposed_shape = transpose1->get_shape(); auto transposed_lens = transposed_shape.lens(); - int64_t channel_dim = transposed_lens[3]; + int64_t channel_dim = lens[1]; std::vector calculated_pads(2); calculate_padding(0, calculated_pads, channel_dim, 1, 1, size, true); From b8fc4eea06bf02bf67bef694d99fb3ff7278ae92 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Thu, 2 Oct 2025 14:32:14 +0000 Subject: [PATCH 41/64] Adding the check for size and combining all the checks in if --- src/rewrite_pooling.cpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index 7f40da2a161..70ede5fdd64 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -70,13 +70,19 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) const auto& xshape = x->get_shape(); auto lens = xshape.lens(); - if(lens.size() < 2) { +/* if(lens.size() != 4) { return; } if(size <= 0) { return; } - + if(size > lens[1]) { + return; + } +*/ + if(lens.size() != 4 || size <= 0 || size > lens[1]) { + return; + } auto x2 = m.insert_instruction(ins, make_op("mul"), x, x); std::vector perm = {0, 2, 3, 1}; From 1b073b9144c031875409d4253c3661ea2ea9f89b Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Thu, 2 Oct 2025 17:02:11 +0000 Subject: [PATCH 42/64] changing the test to simplify_algebra like test --- test/rewrite_pooling_test.cpp | 108 ++++++++++++++++++---------------- 1 file changed, 56 insertions(+), 52 deletions(-) diff --git a/test/rewrite_pooling_test.cpp b/test/rewrite_pooling_test.cpp index 71192b3b0f4..1f69a26cbe7 100644 --- a/test/rewrite_pooling_test.cpp +++ b/test/rewrite_pooling_test.cpp @@ -312,66 +312,70 @@ TEST_CASE(rewrite_pooling_dialtions_test5) } -TEST_CASE(test_lower_lrn_to_pooling) +TEST_CASE(test_lower_lrn_to_pooling) { migraphx::module m1; - migraphx::shape input_shape{migraphx::shape::float_type, {1, 64, 55, 55}}; - auto input1 = m1.add_parameter("x", input_shape); - auto lrn1 = m1.add_instruction( - migraphx::make_op("lrn", {{"alpha", 0.0001f}, {"beta", 0.75f}, {"bias", 1.0f}, {"size", 4}}), - input1); - m1.add_return({lrn1}); + { + migraphx::shape input_shape{migraphx::shape::float_type, {1, 64, 55, 55}}; + auto input1 = m1.add_parameter("x", input_shape); + auto lrn1 = m1.add_instruction( + migraphx::make_op("lrn", {{"alpha", 0.0001f}, {"beta", 0.75f}, {"bias", 1.0f}, {"size", 4}}), + input1); + m1.add_return({lrn1}); + } + opt_pooling(m1); migraphx::module m2; - auto input2 = m2.add_parameter("x", input_shape); - - auto x2 = m2.add_instruction(migraphx::make_op("mul"), input2, input2); - - auto transpose1 = m2.add_instruction( - migraphx::make_op("transpose", {{"permutation", std::vector{0, 2, 3, 1}}}), x2); - - // Calculate padding based on LRN size parameter (size = 4) - int64_t lrn_size = 4; - int64_t pad_left = (lrn_size - 1) / 2; // 1 - int64_t pad_right = lrn_size - 1 - pad_left; // 2 - std::vector expected_pads = {pad_left, pad_right}; - - auto avg = m2.add_instruction( - migraphx::make_op("pooling", { - {"mode", migraphx::op::pooling_mode::average}, - {"lengths", std::vector{1, lrn_size}}, - 
{"stride", std::vector{1, 1}}, - {"padding", std::vector{0, expected_pads[0], 0, expected_pads[1]}}, - {"dilations", std::vector{1, 1}}, - {"count_include_pad", true} - }), transpose1); - - auto transpose2 = m2.add_instruction( - migraphx::make_op("transpose", {{"permutation", std::vector{0, 3, 1, 2}}}), avg); - - auto k_lit = m2.add_literal(1.0f); - auto a_lit = m2.add_literal(0.0001f); - auto b_lit = m2.add_literal(0.75f); + { + migraphx::shape input_shape{migraphx::shape::float_type, {1, 64, 55, 55}}; + auto input2 = m2.add_parameter("x", input_shape); + + auto x2 = m2.add_instruction(migraphx::make_op("mul"), input2, input2); + + auto transpose1 = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", std::vector{0, 2, 3, 1}}}), x2); + + int64_t lrn_size = 4; + int64_t pad_left = (lrn_size - 1) / 2; + int64_t pad_right = lrn_size - 1 - pad_left; + std::vector expected_pads = {pad_left, pad_right}; + + auto avg = m2.add_instruction( + migraphx::make_op("pooling", { + {"mode", migraphx::op::pooling_mode::average}, + {"lengths", std::vector{1, lrn_size}}, + {"stride", std::vector{1, 1}}, + {"padding", std::vector{0, expected_pads[0], 0, expected_pads[1]}}, + {"dilations", std::vector{1, 1}}, + {"count_include_pad", true} + }), transpose1); + + auto transpose2 = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", std::vector{0, 3, 1, 2}}}), avg); + + auto k_lit = m2.add_literal(1.0f); + auto a_lit = m2.add_literal(0.0001f); + auto b_lit = m2.add_literal(0.75f); + + auto k_mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_shape.lens()}}), k_lit); + auto a_mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_shape.lens()}}), a_lit); + auto b_mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_shape.lens()}}), b_lit); + + auto alpha_avg = m2.add_instruction(migraphx::make_op("mul"), a_mb, transpose2); + auto den = m2.add_instruction(migraphx::make_op("add"), k_mb, alpha_avg); + auto denpow = m2.add_instruction(migraphx::make_op("pow"), den, b_mb); + auto y = m2.add_instruction(migraphx::make_op("div"), input2, denpow); + + m2.add_return({y}); + } - auto k_mb = m2.add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", input_shape.lens()}}), k_lit); - auto a_mb = m2.add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", input_shape.lens()}}), a_lit); - auto b_mb = m2.add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", input_shape.lens()}}), b_lit); - - auto alpha_avg = m2.add_instruction(migraphx::make_op("mul"), a_mb, transpose2); - auto den = m2.add_instruction(migraphx::make_op("add"), k_mb, alpha_avg); - auto denpow = m2.add_instruction(migraphx::make_op("pow"), den, b_mb); - auto y = m2.add_instruction(migraphx::make_op("div"), input2, denpow); - - m2.add_return({y}); - - opt_pooling(m1); - EXPECT(m1 == m2); } + TEST_CASE(rewrite_avgpool_rank3_dil_test) { // 1D case 1, input is 3D From 72934da4cc7d140e06510c13220affff7c757e06 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Fri, 3 Oct 2025 21:49:06 +0000 Subject: [PATCH 43/64] tidy error --- src/rewrite_pooling.cpp | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index 70ede5fdd64..3ee69959675 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -70,17 +70,7 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) const auto& xshape = x->get_shape(); auto lens = 
xshape.lens(); -/* if(lens.size() != 4) { - return; - } - if(size <= 0) { - return; - } - if(size > lens[1]) { - return; - } -*/ - if(lens.size() != 4 || size <= 0 || size > lens[1]) { + if(lens.size() != 4 or size <= 0 or size > lens[1]) { return; } auto x2 = m.insert_instruction(ins, make_op("mul"), x, x); From ea31d5e658a7539c2f1c544637321939d5ce4450 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Fri, 3 Oct 2025 22:02:33 +0000 Subject: [PATCH 44/64] tidy error --- src/rewrite_pooling.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index 3ee69959675..caa05c4d56c 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -79,7 +79,7 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) auto transpose1 = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), x2); auto transposed_shape = transpose1->get_shape(); - auto transposed_lens = transposed_shape.lens(); + const auto& transposed_lens = transposed_shape.lens(); int64_t channel_dim = lens[1]; std::vector calculated_pads(2); From 3d7b4800365f53ede525a8bd2957e27f87d2f473 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Fri, 3 Oct 2025 22:06:36 +0000 Subject: [PATCH 45/64] tidy error --- src/rewrite_pooling.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index caa05c4d56c..faf07cd726b 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -99,7 +99,7 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) }), transpose1); auto avg_shape = avg->get_shape(); - auto avg_lens = avg_shape.lens(); + const auto& avg_lens = avg_shape.lens(); if(avg_lens[3] != transposed_lens[3]) { return; @@ -113,7 +113,7 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) auto transpose2 = m.insert_instruction(ins, make_op("transpose", {{"permutation", inv_perm}}), avg); auto final_shape = transpose2->get_shape(); - auto final_lens = final_shape.lens(); + const auto& final_lens = final_shape.lens(); bool shape_matches = true; From 080ac818b62e6005c91f98586c52fef768f9a91b Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Fri, 3 Oct 2025 22:41:25 +0000 Subject: [PATCH 46/64] tidy error --- src/rewrite_pooling.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index faf07cd726b..5cff48278cd 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -128,7 +128,7 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) } } - if(!shape_matches) { + if(not shape_matches) { return; } From 91eb3c7c9763f6c2a0b64c300125fb3998aa1a1b Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Sun, 5 Oct 2025 16:50:14 +0000 Subject: [PATCH 47/64] remove try catch and add all conditions --- src/rewrite_pooling.cpp | 36 ++++++++++++++++-------------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index 5cff48278cd..a3850ceac75 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -68,7 +68,7 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) auto x = ins->inputs().at(0); const auto& xshape = x->get_shape(); - auto lens = xshape.lens(); + const auto& lens = xshape.lens(); if(lens.size() != 4 or size <= 0 or size > lens[1]) { return; @@ -81,14 +81,19 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) auto transposed_shape = 
transpose1->get_shape(); const auto& transposed_lens = transposed_shape.lens(); + if(transposed_lens.size() != 4) { + return; + } + int64_t channel_dim = lens[1]; std::vector calculated_pads(2); calculate_padding(0, calculated_pads, channel_dim, 1, 1, size, true); - - instruction_ref avg; - - try{ - avg = m.insert_instruction(ins, + + if(calculated_pads[0] < 0 or calculated_pads[1] < 0) { + return; + } + + auto avg = m.insert_instruction(ins, make_op("pooling", { {"mode", op::pooling_mode::average}, {"lengths", std::vector{1, size}}, @@ -100,22 +105,17 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) auto avg_shape = avg->get_shape(); const auto& avg_lens = avg_shape.lens(); - - if(avg_lens[3] != transposed_lens[3]) { - return; - } - - } catch(const std::exception&) { - return; - } - + + if(avg_lens.size() != 4 or avg_lens[3] != transposed_lens[3]) { + return; + } + std::vector inv_perm = {0, 3, 1, 2}; auto transpose2 = m.insert_instruction(ins, make_op("transpose", {{"permutation", inv_perm}}), avg); auto final_shape = transpose2->get_shape(); const auto& final_lens = final_shape.lens(); - bool shape_matches = true; if(final_lens.size() != lens.size()) { shape_matches = false; @@ -140,7 +140,6 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) auto a_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), a_lit); auto b_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), b_lit); - try { auto alpha_avg = m.insert_instruction(ins, make_op("mul"), a_mb, transpose2); auto den = m.insert_instruction(ins, make_op("add"), k_mb, alpha_avg); auto denpow = m.insert_instruction(ins, make_op("pow"), den, b_mb); @@ -148,9 +147,6 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) m.replace_instruction(ins, y); - } catch(const std::exception&) { - return; - } } static void replace_dilations_with_gather_pooling(module& m, instruction_ref ins) From a351f311fc5133e78d84d08b5975d1446ab06636 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Sun, 5 Oct 2025 16:50:35 +0000 Subject: [PATCH 48/64] formatting --- src/rewrite_pooling.cpp | 82 +++++++++++++++++++++++------------------ 1 file changed, 46 insertions(+), 36 deletions(-) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index a3850ceac75..c5b5bc56d76 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -70,9 +70,11 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) const auto& xshape = x->get_shape(); const auto& lens = xshape.lens(); - if(lens.size() != 4 or size <= 0 or size > lens[1]) { - return; + if(lens.size() != 4 or size <= 0 or size > lens[1]) + { + return; } + auto x2 = m.insert_instruction(ins, make_op("mul"), x, x); std::vector perm = {0, 2, 3, 1}; @@ -81,35 +83,38 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) auto transposed_shape = transpose1->get_shape(); const auto& transposed_lens = transposed_shape.lens(); - if(transposed_lens.size() != 4) { - return; - } + if(transposed_lens.size() != 4) + { + return; + } int64_t channel_dim = lens[1]; std::vector calculated_pads(2); calculate_padding(0, calculated_pads, channel_dim, 1, 1, size, true); - - if(calculated_pads[0] < 0 or calculated_pads[1] < 0) { - return; + + if(calculated_pads[0] < 0 or calculated_pads[1] < 0) + { + return; } - + auto avg = m.insert_instruction(ins, - make_op("pooling", { - {"mode", op::pooling_mode::average}, - {"lengths", std::vector{1, size}}, - {"stride", std::vector{1, 1}}, - {"padding", 
std::vector{0, calculated_pads[0], 0, calculated_pads[1]}}, - {"dilations", std::vector{1, 1}}, - {"count_include_pad", true} - }), transpose1); - - auto avg_shape = avg->get_shape(); - const auto& avg_lens = avg_shape.lens(); - - if(avg_lens.size() != 4 or avg_lens[3] != transposed_lens[3]) { - return; - } - + make_op("pooling", { + {"mode", op::pooling_mode::average}, + {"lengths", std::vector{1, size}}, + {"stride", std::vector{1, 1}}, + {"padding", std::vector{0, calculated_pads[0], 0, calculated_pads[1]}}, + {"dilations", std::vector{1, 1}}, + {"count_include_pad", true} + }), transpose1); + + auto avg_shape = avg->get_shape(); + const auto& avg_lens = avg_shape.lens(); + + if(avg_lens.size() != 4 or avg_lens[3] != transposed_lens[3]) + { + return; + } + std::vector inv_perm = {0, 3, 1, 2}; auto transpose2 = m.insert_instruction(ins, make_op("transpose", {{"permutation", inv_perm}}), avg); @@ -117,18 +122,24 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) const auto& final_lens = final_shape.lens(); bool shape_matches = true; - if(final_lens.size() != lens.size()) { + if(final_lens.size() != lens.size()) + { shape_matches = false; - } else { - for(size_t i = 0; i < lens.size(); ++i) { - if(final_lens[i] != lens[i]) { + } + else + { + for(size_t i = 0; i < lens.size(); ++i) + { + if(final_lens[i] != lens[i]) + { shape_matches = false; break; } } } - if(not shape_matches) { + if(not shape_matches) + { return; } @@ -140,13 +151,12 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) auto a_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), a_lit); auto b_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), b_lit); - auto alpha_avg = m.insert_instruction(ins, make_op("mul"), a_mb, transpose2); - auto den = m.insert_instruction(ins, make_op("add"), k_mb, alpha_avg); - auto denpow = m.insert_instruction(ins, make_op("pow"), den, b_mb); - auto y = m.insert_instruction(ins, make_op("div"), x, denpow); - - m.replace_instruction(ins, y); + auto alpha_avg = m.insert_instruction(ins, make_op("mul"), a_mb, transpose2); + auto den = m.insert_instruction(ins, make_op("add"), k_mb, alpha_avg); + auto denpow = m.insert_instruction(ins, make_op("pow"), den, b_mb); + auto y = m.insert_instruction(ins, make_op("div"), x, denpow); + m.replace_instruction(ins, y); } static void replace_dilations_with_gather_pooling(module& m, instruction_ref ins) From 310c8e7173975af78b0597e5180de90d71424895 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Sun, 5 Oct 2025 17:26:19 +0000 Subject: [PATCH 49/64] new tests added --- test/verify/test_lrn.cpp | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 test/verify/test_lrn.cpp diff --git a/test/verify/test_lrn.cpp b/test/verify/test_lrn.cpp new file mode 100644 index 00000000000..96177e16f5b --- /dev/null +++ b/test/verify/test_lrn.cpp @@ -0,0 +1,25 @@ +#include "verify_program.hpp" +#include +#include +#include + +template +struct test_lrn : verify_program> +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, ChannelSize, 28, 28}}); + mm->add_instruction( + migraphx::make_op("lrn", + {{"alpha", 0.0001}, {"beta", 0.75}, {"bias", 1.0}, {"size", LrnSize}}), + x); + return p; + } +}; + +template struct test_lrn<32, 6>; +template struct test_lrn<32, 5>; +template struct test_lrn<31, 8>; +template struct test_lrn<31, 5>; 
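The verify tests above cover both odd and even window sizes. The lowering they exercise rests on the identity behind ONNX LRN: the denominator is (bias + alpha/size * sum of x^2 over the channel window)^beta, and an average pool over `size` channels with count_include_pad = true computes sum(x^2)/size, so scaling the pooled value by alpha reproduces the alpha/size * sum term, with the zero padding standing in for out-of-range channels. A minimal scalar reference of that formula, independent of the MIGraphX API (the function name and index helper below are illustrative only, not part of the patch series):

#include <cmath>
#include <cstddef>
#include <vector>

// Scalar LRN reference for one output element; x is laid out NCHW, lens = {N, C, H, W}.
float lrn_reference(const std::vector<float>& x,
                    const std::vector<std::size_t>& lens,
                    std::size_t n, std::size_t c, std::size_t h, std::size_t w,
                    float alpha, float beta, float bias, int size)
{
    auto idx = [&](std::size_t nn, std::size_t cc, std::size_t hh, std::size_t ww) {
        return ((nn * lens[1] + cc) * lens[2] + hh) * lens[3] + ww;
    };
    // Window of `size` channels around c, split as (size - 1) / 2 channels before and
    // the rest after, matching the asymmetric pads the pass gets from calculate_padding.
    const int c0 = static_cast<int>(c) - (size - 1) / 2;
    float sum_sq = 0.0f;
    for(int i = 0; i < size; ++i)
    {
        const int cc = c0 + i;
        if(cc < 0 or cc >= static_cast<int>(lens[1]))
            continue; // out-of-range channels contribute 0, like count_include_pad
        const float v = x[idx(n, static_cast<std::size_t>(cc), h, w)];
        sum_sq += v * v;
    }
    // alpha * (sum_sq / size) == (alpha / size) * sum_sq, the ONNX LRN denominator term.
    const float den = bias + alpha * (sum_sq / static_cast<float>(size));
    return x[idx(n, c, h, w)] / std::pow(den, beta);
}

For size = 4 this window split (one channel before, two after) corresponds to the {1, 2} padding that the unit test in rewrite_pooling_test.cpp expects from calculate_padding.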
From 8745a745a42c6127a24ef3efebcc62e750bb1f5b Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Sun, 5 Oct 2025 17:27:22 +0000 Subject: [PATCH 50/64] removing test case from test_relu_lrn --- test/verify/test_relu_lrn.cpp | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/test/verify/test_relu_lrn.cpp b/test/verify/test_relu_lrn.cpp index 6a297082c5a..d201bd77317 100644 --- a/test/verify/test_relu_lrn.cpp +++ b/test/verify/test_relu_lrn.cpp @@ -42,19 +42,3 @@ struct test_relu_lrn : verify_program return p; } }; - - -struct test_lrn_to_pooling : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto* mm = p.get_main_module(); - auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 32, 28, 28}}); - mm->add_instruction( - migraphx::make_op("lrn", - {{"alpha", 0.0001}, {"beta", 0.75}, {"bias", 1.0}, {"size", 5}}), - x); - return p; - } -}; From 808a56e2b8d7c96f9e9d36317a8b62dd99c84dc7 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Sun, 5 Oct 2025 17:57:32 +0000 Subject: [PATCH 51/64] MIGRAPHX_REWRITE_LRN flag --- src/include/migraphx/rewrite_pooling.hpp | 1 + src/rewrite_pooling.cpp | 10 +++++----- src/targets/gpu/target.cpp | 3 ++- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/include/migraphx/rewrite_pooling.hpp b/src/include/migraphx/rewrite_pooling.hpp index ebef9834786..46105c10304 100644 --- a/src/include/migraphx/rewrite_pooling.hpp +++ b/src/include/migraphx/rewrite_pooling.hpp @@ -38,6 +38,7 @@ struct module; */ struct MIGRAPHX_EXPORT rewrite_pooling { + bool rewrite_lrn = false; std::string name() const { return "rewrite_pooling"; } void apply(module& m) const; }; diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index c5b5bc56d76..12480068951 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -244,15 +244,15 @@ static void replace_dilations_with_gather_pooling(module& m, instruction_ref ins } void rewrite_pooling::apply(module& m) const -{ +{ for(auto ins : iterator_for(m)) { if(ins->inputs().empty()) continue; - if(ins->name() == "lrn") - { - lower_lrn_to_pooling(m, ins); - continue; + if(rewrite_lrn and ins->name() == "lrn") + { + lower_lrn_to_pooling(m, ins); + continue; } if(ins->name() != "pooling") continue; diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index 85b3698c8ab..5844c934259 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -84,6 +84,7 @@ namespace gpu { MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_REWRITE_DOT) +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_REWRITE_LRN) #ifndef _WIN32 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK) #endif @@ -203,7 +204,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti insert_pad{{"convolution"}}, dead_code_elimination{}, inline_module{}, - rewrite_pooling{}, + rewrite_pooling{.rewrite_lrn = enabled(MIGRAPHX_REWRITE_LRN{})}, dead_code_elimination{}, rewrite_gelu{options.fast_math}, optimize_module{}, From 355bec1ac7f1039db82fa32e02128f0b165e44f0 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Sun, 5 Oct 2025 17:59:13 +0000 Subject: [PATCH 52/64] license --- test/verify/test_lrn.cpp | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/test/verify/test_lrn.cpp b/test/verify/test_lrn.cpp index 96177e16f5b..38188343fe6 100644 --- a/test/verify/test_lrn.cpp +++ b/test/verify/test_lrn.cpp @@ -1,3 +1,27 @@ +/* + * The 
MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + #include "verify_program.hpp" #include #include From 6d8929f822535f2518c07e2209f5575c334510be Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Sun, 5 Oct 2025 18:02:55 +0000 Subject: [PATCH 53/64] formatting --- src/rewrite_pooling.cpp | 52 +++++++++++++++++++++-------------------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index 12480068951..f33f84e6d11 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -66,9 +66,9 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) float k = v.at("bias").to(); int size = v.at("size").to(); - auto x = ins->inputs().at(0); + auto x = ins->inputs().at(0); const auto& xshape = x->get_shape(); - const auto& lens = xshape.lens(); + const auto& lens = xshape.lens(); if(lens.size() != 4 or size <= 0 or size > lens[1]) { @@ -80,7 +80,7 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) std::vector perm = {0, 2, 3, 1}; auto transpose1 = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), x2); - auto transposed_shape = transpose1->get_shape(); + auto transposed_shape = transpose1->get_shape(); const auto& transposed_lens = transposed_shape.lens(); if(transposed_lens.size() != 4) @@ -97,17 +97,18 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) return; } - auto avg = m.insert_instruction(ins, - make_op("pooling", { - {"mode", op::pooling_mode::average}, - {"lengths", std::vector{1, size}}, - {"stride", std::vector{1, 1}}, - {"padding", std::vector{0, calculated_pads[0], 0, calculated_pads[1]}}, - {"dilations", std::vector{1, 1}}, - {"count_include_pad", true} - }), transpose1); - - auto avg_shape = avg->get_shape(); + auto avg = m.insert_instruction( + ins, + make_op("pooling", + {{"mode", op::pooling_mode::average}, + {"lengths", std::vector{1, size}}, + {"stride", std::vector{1, 1}}, + {"padding", std::vector{0, calculated_pads[0], 0, calculated_pads[1]}}, + {"dilations", std::vector{1, 1}}, + {"count_include_pad", true}}), + transpose1); + + auto avg_shape = avg->get_shape(); const auto& avg_lens = avg_shape.lens(); if(avg_lens.size() != 4 or avg_lens[3] != transposed_lens[3]) @@ -116,9 +117,10 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) } std::vector inv_perm = {0, 3, 1, 2}; - auto transpose2 = 
m.insert_instruction(ins, make_op("transpose", {{"permutation", inv_perm}}), avg); + auto transpose2 = + m.insert_instruction(ins, make_op("transpose", {{"permutation", inv_perm}}), avg); - auto final_shape = transpose2->get_shape(); + auto final_shape = transpose2->get_shape(); const auto& final_lens = final_shape.lens(); bool shape_matches = true; @@ -152,9 +154,9 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) auto b_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), b_lit); auto alpha_avg = m.insert_instruction(ins, make_op("mul"), a_mb, transpose2); - auto den = m.insert_instruction(ins, make_op("add"), k_mb, alpha_avg); - auto denpow = m.insert_instruction(ins, make_op("pow"), den, b_mb); - auto y = m.insert_instruction(ins, make_op("div"), x, denpow); + auto den = m.insert_instruction(ins, make_op("add"), k_mb, alpha_avg); + auto denpow = m.insert_instruction(ins, make_op("pow"), den, b_mb); + auto y = m.insert_instruction(ins, make_op("div"), x, denpow); m.replace_instruction(ins, y); } @@ -244,15 +246,15 @@ static void replace_dilations_with_gather_pooling(module& m, instruction_ref ins } void rewrite_pooling::apply(module& m) const -{ +{ for(auto ins : iterator_for(m)) { if(ins->inputs().empty()) - continue; - if(rewrite_lrn and ins->name() == "lrn") - { - lower_lrn_to_pooling(m, ins); - continue; + continue; + if(rewrite_lrn and ins->name() == "lrn") + { + lower_lrn_to_pooling(m, ins); + continue; } if(ins->name() != "pooling") continue; From c69157a79f94a336267336f66d0413f0a75a6c98 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Sun, 5 Oct 2025 18:04:16 +0000 Subject: [PATCH 54/64] formatting --- test/verify/test_lrn.cpp | 45 ++++++++++++++++++++-------------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/test/verify/test_lrn.cpp b/test/verify/test_lrn.cpp index 38188343fe6..2eae1a202e5 100644 --- a/test/verify/test_lrn.cpp +++ b/test/verify/test_lrn.cpp @@ -22,28 +22,29 @@ * THE SOFTWARE. 
*/ -#include "verify_program.hpp" -#include -#include -#include - -template -struct test_lrn : verify_program> -{ - migraphx::program create_program() const - { - migraphx::program p; - auto* mm = p.get_main_module(); - auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, ChannelSize, 28, 28}}); - mm->add_instruction( - migraphx::make_op("lrn", - {{"alpha", 0.0001}, {"beta", 0.75}, {"bias", 1.0}, {"size", LrnSize}}), - x); - return p; - } -}; +#include "verify_program.hpp" +#include +#include +#include -template struct test_lrn<32, 6>; -template struct test_lrn<32, 5>; +template +struct test_lrn : verify_program> +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter( + "x", migraphx::shape{migraphx::shape::float_type, {1, ChannelSize, 28, 28}}); + mm->add_instruction( + migraphx::make_op( + "lrn", {{"alpha", 0.0001}, {"beta", 0.75}, {"bias", 1.0}, {"size", LrnSize}}), + x); + return p; + } +}; + +template struct test_lrn<32, 6>; +template struct test_lrn<32, 5>; template struct test_lrn<31, 8>; template struct test_lrn<31, 5>; From 1dadd05b30417978c45cd9f9054539739b27c56e Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Sun, 5 Oct 2025 18:05:15 +0000 Subject: [PATCH 55/64] formatting --- test/rewrite_pooling_test.cpp | 140 +++++++++++++++++----------------- 1 file changed, 71 insertions(+), 69 deletions(-) diff --git a/test/rewrite_pooling_test.cpp b/test/rewrite_pooling_test.cpp index 1f69a26cbe7..b04821bfa82 100644 --- a/test/rewrite_pooling_test.cpp +++ b/test/rewrite_pooling_test.cpp @@ -51,7 +51,7 @@ TEST_CASE(rewrite_pooling_test) migraphx::module m; auto input = m.add_parameter("x", s); auto ret = m.add_instruction(migraphx::make_op("pooling", - {{"mode", mode}, + {{"mode", mode}, {"padding", {0, 0, 0}}, {"stride", {1, 1, 1}}, {"lengths", {3, 4, 5}}, @@ -89,7 +89,7 @@ TEST_CASE(rewrite_pooling_dialtions_test) migraphx::module m; auto input = m.add_parameter("x", s); auto ret = m.add_instruction(migraphx::make_op("pooling", - {{"mode", mode}, + {{"mode", mode}, {"padding", {0, 0}}, {"stride", {1, 1}}, {"lengths", {2, 2}}, @@ -137,7 +137,7 @@ TEST_CASE(rewrite_pooling_dialtions_test2) migraphx::module m; auto input = m.add_parameter("x", s); auto ret = m.add_instruction(migraphx::make_op("pooling", - {{"mode", mode}, + {{"mode", mode}, {"padding", {0, 0, 0}}, {"stride", {1, 1, 1}}, {"lengths", {2, 2, 2}}, @@ -213,7 +213,7 @@ TEST_CASE(rewrite_pooling_dialtions_test4) migraphx::module m; auto input = m.add_parameter("x", s); auto ret = m.add_instruction(migraphx::make_op("pooling", - {{"mode", mode}, + {{"mode", mode}, {"padding", {1, 0}}, {"stride", {1, 3}}, {"lengths", {3, 1}}, @@ -268,7 +268,7 @@ TEST_CASE(rewrite_pooling_dialtions_test5) migraphx::module m; auto input = m.add_parameter("x", s); auto ret = m.add_instruction(migraphx::make_op("pooling", - {{"mode", mode}, + {{"mode", mode}, {"padding", {0, 0}}, {"stride", {2, 3}}, {"lengths", {2, 1}}, @@ -311,71 +311,73 @@ TEST_CASE(rewrite_pooling_dialtions_test5) test_rewrite(migraphx::op::pooling_mode::max); } +TEST_CASE(test_lower_lrn_to_pooling) +{ + migraphx::module m1; + { + migraphx::shape input_shape{migraphx::shape::float_type, {1, 64, 55, 55}}; + auto input1 = m1.add_parameter("x", input_shape); + auto lrn1 = m1.add_instruction( + migraphx::make_op("lrn", + {{"alpha", 0.0001f}, {"beta", 0.75f}, {"bias", 1.0f}, {"size", 4}}), + input1); + m1.add_return({lrn1}); + } + opt_pooling(m1); + + migraphx::module m2; + { + 
migraphx::shape input_shape{migraphx::shape::float_type, {1, 64, 55, 55}}; + auto input2 = m2.add_parameter("x", input_shape); + + auto x2 = m2.add_instruction(migraphx::make_op("mul"), input2, input2); + + auto transpose1 = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", std::vector{0, 2, 3, 1}}}), + x2); + + int64_t lrn_size = 4; + int64_t pad_left = (lrn_size - 1) / 2; + int64_t pad_right = lrn_size - 1 - pad_left; + std::vector expected_pads = {pad_left, pad_right}; + + auto avg = m2.add_instruction( + migraphx::make_op( + "pooling", + {{"mode", migraphx::op::pooling_mode::average}, + {"lengths", std::vector{1, lrn_size}}, + {"stride", std::vector{1, 1}}, + {"padding", std::vector{0, expected_pads[0], 0, expected_pads[1]}}, + {"dilations", std::vector{1, 1}}, + {"count_include_pad", true}}), + transpose1); + + auto transpose2 = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", std::vector{0, 3, 1, 2}}}), + avg); + + auto k_lit = m2.add_literal(1.0f); + auto a_lit = m2.add_literal(0.0001f); + auto b_lit = m2.add_literal(0.75f); + + auto k_mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_shape.lens()}}), k_lit); + auto a_mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_shape.lens()}}), a_lit); + auto b_mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_shape.lens()}}), b_lit); + + auto alpha_avg = m2.add_instruction(migraphx::make_op("mul"), a_mb, transpose2); + auto den = m2.add_instruction(migraphx::make_op("add"), k_mb, alpha_avg); + auto denpow = m2.add_instruction(migraphx::make_op("pow"), den, b_mb); + auto y = m2.add_instruction(migraphx::make_op("div"), input2, denpow); + + m2.add_return({y}); + } -TEST_CASE(test_lower_lrn_to_pooling) -{ - migraphx::module m1; - { - migraphx::shape input_shape{migraphx::shape::float_type, {1, 64, 55, 55}}; - auto input1 = m1.add_parameter("x", input_shape); - auto lrn1 = m1.add_instruction( - migraphx::make_op("lrn", {{"alpha", 0.0001f}, {"beta", 0.75f}, {"bias", 1.0f}, {"size", 4}}), - input1); - m1.add_return({lrn1}); - } - opt_pooling(m1); - - migraphx::module m2; - { - migraphx::shape input_shape{migraphx::shape::float_type, {1, 64, 55, 55}}; - auto input2 = m2.add_parameter("x", input_shape); - - auto x2 = m2.add_instruction(migraphx::make_op("mul"), input2, input2); - - auto transpose1 = m2.add_instruction( - migraphx::make_op("transpose", {{"permutation", std::vector{0, 2, 3, 1}}}), x2); - - int64_t lrn_size = 4; - int64_t pad_left = (lrn_size - 1) / 2; - int64_t pad_right = lrn_size - 1 - pad_left; - std::vector expected_pads = {pad_left, pad_right}; - - auto avg = m2.add_instruction( - migraphx::make_op("pooling", { - {"mode", migraphx::op::pooling_mode::average}, - {"lengths", std::vector{1, lrn_size}}, - {"stride", std::vector{1, 1}}, - {"padding", std::vector{0, expected_pads[0], 0, expected_pads[1]}}, - {"dilations", std::vector{1, 1}}, - {"count_include_pad", true} - }), transpose1); - - auto transpose2 = m2.add_instruction( - migraphx::make_op("transpose", {{"permutation", std::vector{0, 3, 1, 2}}}), avg); - - auto k_lit = m2.add_literal(1.0f); - auto a_lit = m2.add_literal(0.0001f); - auto b_lit = m2.add_literal(0.75f); - - auto k_mb = m2.add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", input_shape.lens()}}), k_lit); - auto a_mb = m2.add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", input_shape.lens()}}), a_lit); - auto b_mb = m2.add_instruction( - 
migraphx::make_op("multibroadcast", {{"out_lens", input_shape.lens()}}), b_lit); - - auto alpha_avg = m2.add_instruction(migraphx::make_op("mul"), a_mb, transpose2); - auto den = m2.add_instruction(migraphx::make_op("add"), k_mb, alpha_avg); - auto denpow = m2.add_instruction(migraphx::make_op("pow"), den, b_mb); - auto y = m2.add_instruction(migraphx::make_op("div"), input2, denpow); - - m2.add_return({y}); - } - - EXPECT(m1 == m2); + EXPECT(m1 == m2); } - TEST_CASE(rewrite_avgpool_rank3_dil_test) { // 1D case 1, input is 3D @@ -684,7 +686,7 @@ TEST_CASE(rewrite_avepooling_na3_test) auto input = m.add_parameter("x", s); auto ret = m.add_instruction(migraphx::make_op("pooling", - {{"mode", migraphx::op::pooling_mode::max}, + {{"mode", migraphx::op::pooling_mode::max}, {"padding", {0, 0, 0}}, {"stride", {1, 1, 1}}, {"lengths", {3, 3, 5}}, @@ -713,7 +715,7 @@ TEST_CASE(literal_rewrite_pooling_test) auto* mm = p.get_main_module(); auto input = mm->add_literal(migraphx::literal(s, data)); auto ret = mm->add_instruction(migraphx::make_op("pooling", - {{"mode", mode}, + {{"mode", mode}, {"padding", {0, 0, 0}}, {"stride", {1, 1, 1}}, {"lengths", {3, 4, 5}}, From f5318f235e3cfa7b710e5a7895d666d04a5793a6 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Sun, 5 Oct 2025 23:13:09 +0000 Subject: [PATCH 56/64] license --- src/include/migraphx/rewrite_pooling.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/include/migraphx/rewrite_pooling.hpp b/src/include/migraphx/rewrite_pooling.hpp index 46105c10304..dd69bb7ca93 100644 --- a/src/include/migraphx/rewrite_pooling.hpp +++ b/src/include/migraphx/rewrite_pooling.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal From a546fd418e284df25ff8e1b440da8321b9c68aad Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Sun, 5 Oct 2025 23:13:49 +0000 Subject: [PATCH 57/64] enable flag in test case --- test/rewrite_pooling_test.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/rewrite_pooling_test.cpp b/test/rewrite_pooling_test.cpp index b04821bfa82..0122da3f992 100644 --- a/test/rewrite_pooling_test.cpp +++ b/test/rewrite_pooling_test.cpp @@ -313,6 +313,8 @@ TEST_CASE(rewrite_pooling_dialtions_test5) TEST_CASE(test_lower_lrn_to_pooling) { + if(not migraphx::enabled(MIGRAPHX_REWRITE_LRN{})) + return; migraphx::module m1; { migraphx::shape input_shape{migraphx::shape::float_type, {1, 64, 55, 55}}; From afabc81f8aca502566ffeca69298c6b60de8c5d5 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Sun, 5 Oct 2025 23:26:53 +0000 Subject: [PATCH 58/64] test case accepting flag --- test/rewrite_pooling_test.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/test/rewrite_pooling_test.cpp b/test/rewrite_pooling_test.cpp index 0122da3f992..9019a88e6fa 100644 --- a/test/rewrite_pooling_test.cpp +++ b/test/rewrite_pooling_test.cpp @@ -36,6 +36,7 @@ #include +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_REWRITE_LRN); static void opt_pooling(migraphx::module& m) { migraphx::rewrite_pooling rp; @@ -325,7 +326,12 @@ TEST_CASE(test_lower_lrn_to_pooling) input1); m1.add_return({lrn1}); } - opt_pooling(m1); + + // Apply the pass directly when the flag enabled + migraphx::rewrite_pooling rp{.rewrite_lrn = true}; + migraphx::dead_code_elimination dce; + rp.apply(m1); + dce.apply(m1); migraphx::module m2; { From 8b17b1d9e9ab134926fb572ade8d4e40107fa714 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Mon, 6 Oct 2025 00:58:59 +0000 Subject: [PATCH 59/64] formatting --- test/rewrite_pooling_test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/rewrite_pooling_test.cpp b/test/rewrite_pooling_test.cpp index 9019a88e6fa..9474b5665bc 100644 --- a/test/rewrite_pooling_test.cpp +++ b/test/rewrite_pooling_test.cpp @@ -326,7 +326,7 @@ TEST_CASE(test_lower_lrn_to_pooling) input1); m1.add_return({lrn1}); } - + // Apply the pass directly when the flag enabled migraphx::rewrite_pooling rp{.rewrite_lrn = true}; migraphx::dead_code_elimination dce; From 2e244c5d6626606a44ba5284c3ca4ecd5d202a0c Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Mon, 6 Oct 2025 01:12:29 +0000 Subject: [PATCH 60/64] formatting --- test/rewrite_pooling_test.cpp | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/test/rewrite_pooling_test.cpp b/test/rewrite_pooling_test.cpp index 9474b5665bc..770f1080533 100644 --- a/test/rewrite_pooling_test.cpp +++ b/test/rewrite_pooling_test.cpp @@ -52,7 +52,7 @@ TEST_CASE(rewrite_pooling_test) migraphx::module m; auto input = m.add_parameter("x", s); auto ret = m.add_instruction(migraphx::make_op("pooling", - {{"mode", mode}, + {{"mode", mode}, {"padding", {0, 0, 0}}, {"stride", {1, 1, 1}}, {"lengths", {3, 4, 5}}, @@ -90,7 +90,7 @@ TEST_CASE(rewrite_pooling_dialtions_test) migraphx::module m; auto input = m.add_parameter("x", s); auto ret = m.add_instruction(migraphx::make_op("pooling", - {{"mode", mode}, + {{"mode", mode}, {"padding", {0, 0}}, {"stride", {1, 1}}, {"lengths", {2, 2}}, @@ -138,7 +138,7 @@ TEST_CASE(rewrite_pooling_dialtions_test2) migraphx::module m; auto input = 
m.add_parameter("x", s); auto ret = m.add_instruction(migraphx::make_op("pooling", - {{"mode", mode}, + {{"mode", mode}, {"padding", {0, 0, 0}}, {"stride", {1, 1, 1}}, {"lengths", {2, 2, 2}}, @@ -214,7 +214,7 @@ TEST_CASE(rewrite_pooling_dialtions_test4) migraphx::module m; auto input = m.add_parameter("x", s); auto ret = m.add_instruction(migraphx::make_op("pooling", - {{"mode", mode}, + {{"mode", mode}, {"padding", {1, 0}}, {"stride", {1, 3}}, {"lengths", {3, 1}}, @@ -269,7 +269,7 @@ TEST_CASE(rewrite_pooling_dialtions_test5) migraphx::module m; auto input = m.add_parameter("x", s); auto ret = m.add_instruction(migraphx::make_op("pooling", - {{"mode", mode}, + {{"mode", mode}, {"padding", {0, 0}}, {"stride", {2, 3}}, {"lengths", {2, 1}}, @@ -322,11 +322,10 @@ TEST_CASE(test_lower_lrn_to_pooling) auto input1 = m1.add_parameter("x", input_shape); auto lrn1 = m1.add_instruction( migraphx::make_op("lrn", - {{"alpha", 0.0001f}, {"beta", 0.75f}, {"bias", 1.0f}, {"size", 4}}), + {{"alpha", 0.0001f}, {"beta", 0.75f}, {"bias", 1.0f}, {"size", 4}}), input1); m1.add_return({lrn1}); } - // Apply the pass directly when the flag enabled migraphx::rewrite_pooling rp{.rewrite_lrn = true}; migraphx::dead_code_elimination dce; @@ -694,7 +693,7 @@ TEST_CASE(rewrite_avepooling_na3_test) auto input = m.add_parameter("x", s); auto ret = m.add_instruction(migraphx::make_op("pooling", - {{"mode", migraphx::op::pooling_mode::max}, + {{"mode", migraphx::op::pooling_mode::max}, {"padding", {0, 0, 0}}, {"stride", {1, 1, 1}}, {"lengths", {3, 3, 5}}, @@ -723,7 +722,7 @@ TEST_CASE(literal_rewrite_pooling_test) auto* mm = p.get_main_module(); auto input = mm->add_literal(migraphx::literal(s, data)); auto ret = mm->add_instruction(migraphx::make_op("pooling", - {{"mode", mode}, + {{"mode", mode}, {"padding", {0, 0, 0}}, {"stride", {1, 1, 1}}, {"lengths", {3, 4, 5}}, From 878b1355931a9095533dc60c58da1850e70f4df6 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Mon, 6 Oct 2025 13:58:02 +0000 Subject: [PATCH 61/64] simplify code --- src/rewrite_pooling.cpp | 31 ++----------------------------- 1 file changed, 2 insertions(+), 29 deletions(-) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index f33f84e6d11..b5a9c2a3997 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -83,20 +83,10 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) auto transposed_shape = transpose1->get_shape(); const auto& transposed_lens = transposed_shape.lens(); - if(transposed_lens.size() != 4) - { - return; - } - int64_t channel_dim = lens[1]; std::vector calculated_pads(2); calculate_padding(0, calculated_pads, channel_dim, 1, 1, size, true); - if(calculated_pads[0] < 0 or calculated_pads[1] < 0) - { - return; - } - auto avg = m.insert_instruction( ins, make_op("pooling", @@ -123,28 +113,11 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) auto final_shape = transpose2->get_shape(); const auto& final_lens = final_shape.lens(); - bool shape_matches = true; - if(final_lens.size() != lens.size()) - { - shape_matches = false; - } - else - { - for(size_t i = 0; i < lens.size(); ++i) - { - if(final_lens[i] != lens[i]) - { - shape_matches = false; - break; - } - } - } - - if(not shape_matches) + if(final_lens != lens) { return; } - + auto k_lit = m.add_literal(k); auto a_lit = m.add_literal(alpha); auto b_lit = m.add_literal(beta); From 04a7705b2daa13518160467e086cdc0899b07de6 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Mon, 6 Oct 2025 14:07:29 +0000 Subject: [PATCH 62/64] 
simplify code --- src/rewrite_pooling.cpp | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index b5a9c2a3997..274a8d14bb7 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -113,11 +113,8 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) auto final_shape = transpose2->get_shape(); const auto& final_lens = final_shape.lens(); - if(final_lens != lens) - { - return; - } - + if(final_lens != lens) return; + auto k_lit = m.add_literal(k); auto a_lit = m.add_literal(alpha); auto b_lit = m.add_literal(beta); From 526e25ebe35cced4a682ce88429b743af4a7d157 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Mon, 6 Oct 2025 14:19:11 +0000 Subject: [PATCH 63/64] updated the doc --- docs/reference/MIGraphX-dev-env-vars.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/reference/MIGraphX-dev-env-vars.rst b/docs/reference/MIGraphX-dev-env-vars.rst index f9c9f309ab3..da52ce656b8 100644 --- a/docs/reference/MIGraphX-dev-env-vars.rst +++ b/docs/reference/MIGraphX-dev-env-vars.rst @@ -221,6 +221,13 @@ Model performance tunable variables change the compilation behavior of a model. | Default: No tuning is done for composable kernels. + * - | ``MIGRAPHX_REWRITE_LRN`` + | Turns on LRN-to-pooling lowering in the rewrite_pooling pass. + + - | ``1``: Turns on LRN-to-pooling lowering. + | ``0``: Returns to default behavior. + + | Default: LRN-to-pooling lowering is turned off. Matching ********** From c7cf3ee72fb6c4fbdbc69d2d7e8fabf909464779 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Mon, 6 Oct 2025 14:50:04 +0000 Subject: [PATCH 64/64] remove flag in test and formatting --- src/rewrite_pooling.cpp | 3 ++- test/rewrite_pooling_test.cpp | 2 -- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index 274a8d14bb7..6edf6689397 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -113,7 +113,8 @@ static void lower_lrn_to_pooling(module& m, instruction_ref ins) auto final_shape = transpose2->get_shape(); const auto& final_lens = final_shape.lens(); - if(final_lens != lens) return; + if(final_lens != lens) + return; auto k_lit = m.add_literal(k); auto a_lit = m.add_literal(alpha); diff --git a/test/rewrite_pooling_test.cpp b/test/rewrite_pooling_test.cpp index 770f1080533..5377dc99715 100644 --- a/test/rewrite_pooling_test.cpp +++ b/test/rewrite_pooling_test.cpp @@ -314,8 +314,6 @@ TEST_CASE(rewrite_pooling_dialtions_test5) TEST_CASE(test_lower_lrn_to_pooling) { - if(not migraphx::enabled(MIGRAPHX_REWRITE_LRN{})) - return; migraphx::module m1; { migraphx::shape input_shape{migraphx::shape::float_type, {1, 64, 55, 55}};
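
A note on the math this series relies on: the rewrite squares the input, average-pools the squares over a window of `size` along the channel axis with zero padding and count_include_pad=true, and then computes x / (bias + alpha * avg)^beta. Because an include-pad average over a window of length `size` is exactly sum(x^2)/size, the denominator equals bias + (alpha/size) * sum(x^2), which is the ONNX-style LRN denominator. The sketch below is a minimal standalone check of that identity, assuming ONNX LRN semantics; the tensor shape, data, helper lambda `at`, and tolerance are illustrative and do not come from MIGraphX.

// Minimal sketch: ONNX-style LRN versus the square/avg-pool/pow/div decomposition.
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    const int N = 1, C = 8, H = 2, W = 2;
    const float alpha = 0.0001f, beta = 0.75f, bias = 1.0f;
    const int size      = 5;
    const int pad_left  = (size - 1) / 2;      // same left/right split as the pass and the unit test
    const int pad_right = size - 1 - pad_left;

    std::vector<float> x(N * C * H * W);
    for(std::size_t i = 0; i < x.size(); ++i)
        x[i] = 0.1f * static_cast<float>(i) - 1.0f;

    auto at = [&](int n, int c, int h, int w) { return ((n * C + c) * H + h) * W + w; };

    double max_diff = 0.0;
    for(int n = 0; n < N; ++n)
        for(int c = 0; c < C; ++c)
            for(int h = 0; h < H; ++h)
                for(int w = 0; w < W; ++w)
                {
                    // Reference LRN denominator: sum of squares over the channel window
                    // [c - pad_left, c + pad_right], clipped to valid channels.
                    double sum_sq = 0.0;
                    for(int j = c - pad_left; j <= c + pad_right; ++j)
                        if(j >= 0 and j < C)
                            sum_sq += double(x[at(n, j, h, w)]) * x[at(n, j, h, w)];
                    double ref = x[at(n, c, h, w)] /
                                 std::pow(bias + (alpha / size) * sum_sq, double(beta));

                    // Decomposed form: an include-pad average over a window of length
                    // `size` divides by the full kernel size, so avg == sum_sq / size;
                    // then y = x / (bias + alpha * avg)^beta.
                    double avg = sum_sq / size;
                    double dec =
                        x[at(n, c, h, w)] / std::pow(bias + alpha * avg, double(beta));

                    max_diff = std::max(max_diff, std::abs(ref - dec));
                }

    std::printf("max |reference - decomposed| = %g\n", max_diff);
    assert(max_diff < 1e-6);
    return 0;
}

With count_include_pad=false the pooling divisor would shrink near the channel edges and the identity above would no longer hold, which is why the lowering sets it to true.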