From 3a688d339712100ef5a39abdaea921b2c0c542c1 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous <107195283+TedThemistokleous@users.noreply.github.com> Date: Tue, 30 Sep 2025 19:25:48 -0400 Subject: [PATCH 01/13] =?UTF-8?q?Update=20build=20and=20test=20script=20an?= =?UTF-8?q?d=20dockerfile=20to=20add=20in=20onnxrt=20pai=20laun=E2=80=A6?= =?UTF-8?q?=20(#4336)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Update build and test script and dockerfile to add in onnxrt pai launcher scripts Push pai launcher scripts to Onnxrutime to let us reuse scripts from prior run Related to change from - #4321 --- Dockerfile | 2 ++ test/onnx/.onnxrt-commit | 2 +- tools/build_and_test_onnxrt.sh | 6 ++++-- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/Dockerfile b/Dockerfile index 7aecb0c685e..ecbe26345f7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -126,6 +126,8 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh +ADD tools/pai_test_launcher.sh /onnxruntime/tools/ci_build/github/pai/pai_test_launcher.sh +ADD tools/pai_provider_test_launcher.sh /onnxruntime/tools/ci_build/github/pai/pai_provider_test_launcher.sh ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db diff --git a/test/onnx/.onnxrt-commit b/test/onnx/.onnxrt-commit index dee7f4e19f9..bdadff6093f 100644 --- a/test/onnx/.onnxrt-commit +++ b/test/onnx/.onnxrt-commit @@ -1 +1 @@ -b6172c08212eb2b162dc059f324440b51792dae1 +cced33b2f5e338a4ee43cdc8695dba70d706e993 diff --git a/tools/build_and_test_onnxrt.sh b/tools/build_and_test_onnxrt.sh index 6039ae75984..351535a70b6 100755 --- a/tools/build_and_test_onnxrt.sh +++ b/tools/build_and_test_onnxrt.sh @@ -26,8 +26,10 @@ set -e ulimit -c unlimited -cp tools/pai_test_launcher.sh /onnxruntime/tools/ci_build/github/pai/pai_test_launcher.sh -cp tools/pai_provider_test_launcher.sh /onnxruntime/tools/ci_build/github/pai/pai_provider_test_launcher.sh +# Copy these over in local runs but silence them in CI +cp tools/pai_test_launcher.sh /onnxruntime/tools/ci_build/github/pai/pai_test_launcher.sh 2>/dev/null || : +[ -f tools/pai_provider_test_launcher.sh ] && cp tools/pai_provider_test_launcher.sh /onnxruntime/tools/ci_build/github/pai/pai_provider_test_launcher.sh + cd /onnxruntime pip3 install -r requirements-dev.txt # Add newer cmake to the path From 12dc8308d124d1b5d73f881cc0f1192abb8ea974 Mon Sep 17 00:00:00 2001 From: Chris Austen Date: Wed, 1 Oct 2025 09:01:15 -0400 Subject: [PATCH 02/13] Bump SQlite3 to 3.50.4 (#4322) --- dev-requirements.txt | 2 +- rbuild.ini | 4 ++-- requirements.txt | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index f4f5fd75f31..aac3941855a 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -21,7 +21,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. ##################################################################################### -ROCmSoftwarePlatform/rocm-recipes +ROCm/rocm-recipes facebook/zstd@v1.5.7 -X subdir -DCMAKE_DIR=build/cmake ccache@v4.1 -DENABLE_TESTING=OFF pcre,pfultz2/pcre@8.45 -H sha256:d6f7182602a775a7d500a0cedca6449af0400c6493951513046d17615ed0bf11 diff --git a/rbuild.ini b/rbuild.ini index 3eb2fef6247..ef5eb2c5a8e 100644 --- a/rbuild.ini +++ b/rbuild.ini @@ -2,13 +2,13 @@ cxx = ${rocm_path}/llvm/bin/clang++ cc = ${rocm_path}/llvm/bin/clang deps = - ROCmSoftwarePlatform/rocm-recipes + ROCm/rocm-recipes -f requirements.txt [gh] ignore = danmar/cppcheck - ROCmSoftwarePlatform/rocMLIR + ROCm/rocMLIR deps = -f dev-requirements.txt oneapi-src/oneDNN@v1.7 diff --git a/requirements.txt b/requirements.txt index c02a5793d37..6f3c140d70e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -27,6 +27,6 @@ nlohmann/json@v3.8.0 -DCMAKE_POLICY_VERSION_MINIMUM=3.5 ROCm/half@rocm-5.6.0 pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off -DMSGPACK_BUILD_EXAMPLES=Off -DCMAKE_POLICY_VERSION_MINIMUM=3.5 -sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On +sqlite3@3.50.4 -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCm/composable_kernel@b7775add2d28251674d81e220cd4a857b90b997a -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCm/rocMLIR@c9ccbb29d3d418199d9a17b9b00ff0323d3dd69e -DBUILD_FAT_LIBROCKCOMPILER=On -DLLVM_INCLUDE_TESTS=Off From 3c4378fa70139a1759ba62d01eb744b812b3ad9e Mon Sep 17 00:00:00 2001 From: Lakhinder Walia <139581206+lakhinderwalia@users.noreply.github.com> Date: Wed, 1 Oct 2025 09:53:52 -0700 Subject: [PATCH 03/13] TopK exception bugfix (#4329) --- src/targets/gpu/kernels/include/migraphx/kernels/index.hpp | 7 ++++++- test/verify/test_topk.cpp | 3 +++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/index.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/index.hpp index b41cb3ef026..77da7283190 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/index.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/index.hpp @@ -230,7 +230,12 @@ struct index static constexpr void for_stride(index_int start, N n, Stride stride, F f) { MIGRAPHX_ASSERT(start < stride); - if constexpr(not is_integral{} and not is_integral{}) + + if constexpr(not is_integral{} and n < 1) + { + return; + } + else if constexpr(not is_integral{} and not is_integral{}) { if constexpr(max_stride_iterations(n, stride) == 1) { diff --git a/test/verify/test_topk.cpp b/test/verify/test_topk.cpp index ad5ac3f2d6a..e2d089953e5 100644 --- a/test/verify/test_topk.cpp +++ b/test/verify/test_topk.cpp @@ -72,3 +72,6 @@ template struct test_topk; template struct test_topk; template struct test_topk; template struct test_topk; + +template struct test_topk; +template struct test_topk; From 8b591436e17b25a125fb3138d364d1cb25ad8ae5 Mon Sep 17 00:00:00 2001 From: kahmed10 <15948690+kahmed10@users.noreply.github.com> Date: Wed, 1 Oct 2025 12:56:33 -0500 Subject: [PATCH 04/13] Use --input-dim for specifying dynamic shapes at driver runtime (#4342) --- src/driver/main.cpp | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/src/driver/main.cpp b/src/driver/main.cpp index 31cb7472c27..4f402e06b5d 100644 --- a/src/driver/main.cpp +++ b/src/driver/main.cpp @@ -64,6 +64,9 @@ #include namespace { + +using dims_map = std::unordered_map>; + std::vector get_unrecognized_migraphx_envs(const char* envp[], const std::map& used_env) @@ -213,7 +216,7 @@ struct loader static auto parse_param_dims(const std::vector& param_dims_info) { - std::unordered_map> map_input_dims; + dims_map map_input_dims; std::string name = ""; for(auto&& x : param_dims_info) { @@ -502,16 +505,24 @@ struct program_params return map_load_args; } - auto generate(const program& p, const target& t, bool offload, unsigned batch) + auto generate(const program& p, + const target& t, + bool offload, + unsigned batch, + dims_map map_input_dims = {}) { parameter_map m; auto param_shapes = p.get_parameter_shapes(); std::unordered_map static_param_shapes; - std::transform( - param_shapes.cbegin(), - param_shapes.cend(), - std::inserter(static_param_shapes, static_param_shapes.end()), - [&](const auto& x) { return std::make_pair(x.first, x.second.to_static(batch)); }); + for(auto&& param : param_shapes) + { + if(contains(map_input_dims, param.first)) + static_param_shapes[param.first] = {param.second.type(), + map_input_dims[param.first]}; + else + static_param_shapes[param.first] = param.second.to_static(batch); + } + for(auto&& s : fill0) m[s] = fill_argument(static_param_shapes.at(s), 0); for(auto&& s : fill1) @@ -591,7 +602,8 @@ struct compiler auto params(const program& p) { - return parameters.generate(p, ct.get_target(), co.offload_copy, l.batch); + return parameters.generate( + p, ct.get_target(), co.offload_copy, l.batch, loader::parse_param_dims(l.param_dims)); } auto host_params(const program& p) @@ -730,7 +742,8 @@ struct verify : command std::cout << p << std::endl; auto t = c.ct.get_target(); - auto m = c.parameters.generate(p, t, true, c.l.batch); + auto m = + c.parameters.generate(p, t, true, c.l.batch, loader::parse_param_dims(c.l.param_dims)); if(c.to_fp16) { From 14678120b792e430f1e66476d885cd679864b8a7 Mon Sep 17 00:00:00 2001 From: Chris Austen Date: Wed, 1 Oct 2025 13:57:36 -0400 Subject: [PATCH 05/13] Add video and render groups in docker for CI (#4340) --- Jenkinsfile | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/Jenkinsfile b/Jenkinsfile index ef9c2a49d5d..cfa98d43a4d 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -39,6 +39,7 @@ def rocmtestnode(Map conf) { export MIGRAPHX_GPU_DEBUG=${gpu_debug} export CXX=${compiler} export CXXFLAGS='-Werror' + rocminfo env rm -rf build mkdir build @@ -66,12 +67,18 @@ def rocmtestnode(Map conf) { checkout scm } + def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3').trim() + def render_id = sh(returnStdout: true, script: 'getent group render | cut -d: -f3').trim() + def docker_opts = "--device=/dev/kfd --device=/dev/dri --cap-add SYS_PTRACE -v=${env.WORKSPACE}/../:/workspaces:rw,z" + docker_opts = docker_opts + " --group-add=${video_id} --group-add=${render_id} " + echo "Docker flags: ${docker_opts}" + gitStatusWrapper(credentialsId: "${env.migraphx_ci_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'AMDMIGraphX') { withCredentials([usernamePassword(credentialsId: 'docker_test_cred', passwordVariable: 'DOCKERHUB_PASS', usernameVariable: 'DOCKERHUB_USER')]) { sh "echo $DOCKERHUB_PASS | docker login --username $DOCKERHUB_USER --password-stdin" pre() sh "docker pull ${DOCKER_IMAGE}:${env.IMAGE_TAG}" - withDockerContainer(image: "${DOCKER_IMAGE}:${env.IMAGE_TAG}", args: "--device=/dev/kfd --device=/dev/dri --group-add video --cap-add SYS_PTRACE -v=${env.WORKSPACE}/../:/workspaces:rw,z ${docker_args}") { + withDockerContainer(image: "${DOCKER_IMAGE}:${env.IMAGE_TAG}", args: docker_opts + docker_args) { timeout(time: 4, unit: 'HOURS') { body(cmake_build) } From 825ebca3866c6d9cb873b5af5cd16438975d16eb Mon Sep 17 00:00:00 2001 From: Charlie Lin Date: Wed, 1 Oct 2025 13:59:16 -0400 Subject: [PATCH 06/13] Fix MXFP4 bugs (#4324) * propagate_constant ignores `unpack_fp4` instructions now * `match_find_mx_quantizable_ops` from `simplify_qdq` updated to not require non-constant scales. The scales can be literals. * transpose, reshape, and broadcast instructions propagated on scale instructions when going to `quant_dot` and `quant_conv` * `raw_data` `operator<<` updated to use `fallback_visit` to also handle non-computable types --- src/include/migraphx/raw_data.hpp | 26 ++++---- src/propagate_constant.cpp | 2 +- src/simplify_qdq.cpp | 57 +++++++++++------- test/propagate_constant_test.cpp | 27 +++++++++ test/simplify_qdq_test.cpp | 98 +++++++++++++++++++++++++++++++ 5 files changed, 176 insertions(+), 34 deletions(-) diff --git a/src/include/migraphx/raw_data.hpp b/src/include/migraphx/raw_data.hpp index 63f512a4948..ebd72acaab6 100644 --- a/src/include/migraphx/raw_data.hpp +++ b/src/include/migraphx/raw_data.hpp @@ -53,15 +53,15 @@ struct raw_data : raw_data_base friend Stream& operator<<(Stream& os, const Derived& d) { if(not d.empty()) - d.visit([&](auto x) { os << x; }, - [&](auto&& xs) { - for(auto&& x : xs) - { - os << "{ "; - os << x; - os << " }, "; - } - }); + d.fallback_visit([&](auto x) { os << x; }, + [&](auto&& xs) { + for(auto&& x : xs) + { + os << "{ "; + os << x; + os << " }, "; + } + }); return os; } @@ -123,9 +123,13 @@ struct raw_data : raw_data_base } else { - auto&& buffer = static_cast(*this).data(); + auto* buffer = static_cast(*this).data(); shape view_shape = {shape::uint8_type, {s.bytes()}}; - v(make_view(view_shape, reinterpret_cast(buffer))); + using byte_type = + std::conditional_t>{}, + const byte*, + byte*>; + v(make_view(view_shape, reinterpret_cast(buffer))); } } diff --git a/src/propagate_constant.cpp b/src/propagate_constant.cpp index 50e065520a9..8b44064fa2f 100644 --- a/src/propagate_constant.cpp +++ b/src/propagate_constant.cpp @@ -40,7 +40,7 @@ static bool skip_propagate(instruction_ref ins) { if(contains({"contiguous", "dequantizelinear", "reshape"}, ins->name())) return skip_propagate(ins->inputs().front()); - if(ins->name() == "unpack_int4") + if(contains({"unpack_int4", "unpack_fp4"}, ins->name())) return true; auto&& s = ins->get_shape(); if(s.broadcasted() and s.element_space() < s.elements()) diff --git a/src/simplify_qdq.cpp b/src/simplify_qdq.cpp index f4569fe7a36..4169964d9d0 100644 --- a/src/simplify_qdq.cpp +++ b/src/simplify_qdq.cpp @@ -49,32 +49,36 @@ std::unordered_set get_quantizable_op_names() return s; } -// Helper function to insert quantized versions of any broadcasts and transpose ops that -// occur between dequantizelinear and the quantized op -auto propagate_quantized_ins(module& m, - const instruction_ref dqins, - const instruction_ref qop_arg, - bool is_fp16_model = false) +std::vector get_between_ins(const instruction_ref dqins, + const instruction_ref qop_arg) { auto prev_ins = qop_arg; std::vector ins_between; - // matcher skips continguous, multi/broadcasts and transposes, collect all those - // instructions while(prev_ins != dqins) { ins_between.push_back(prev_ins); prev_ins = prev_ins->inputs().front(); } - auto qinp = dqins->inputs().front(); + return ins_between; +} + +// Helper function to insert quantized versions of any broadcasts and transpose ops that +// occur between dequantizelinear and the quantized op +auto propagate_quantized_ins(module& m, + const instruction_ref dqins, + instruction_ref input_ins, + std::vector ins_between, + bool is_fp16_model = false) +{ for(auto ins : reverse_iterator_for(ins_between)) { if((*ins)->name() == "convert" and is_fp16_model) { continue; } - qinp = m.insert_instruction(dqins, (*ins)->get_operator(), {qinp}); + input_ins = m.insert_instruction(dqins, (*ins)->get_operator(), {input_ins}); } - return qinp; + return input_ins; } struct match_find_quantizable_ops @@ -140,8 +144,13 @@ struct match_find_quantizable_ops assert(dq1->get_shape().type() == migraphx::shape::float_type); is_fp16_model = true; } - qop_args.at(0) = propagate_quantized_ins(m, dq1, qop_args[0], is_fp16_model); - qop_args.at(1) = propagate_quantized_ins(m, dq2, qop_args[1], is_fp16_model); + + auto qop_between_arg0 = get_between_ins(dq1, qop_args[0]); + auto qop_between_arg1 = get_between_ins(dq2, qop_args[1]); + qop_args.at(0) = + propagate_quantized_ins(m, dq1, dq1->inputs().front(), qop_between_arg0, is_fp16_model); + qop_args.at(1) = + propagate_quantized_ins(m, dq2, dq2->inputs().front(), qop_between_arg1, is_fp16_model); auto arg1_lens = qop_args[0]->get_shape().lens(); auto arg2_lens = qop_args[1]->get_shape().lens(); @@ -280,15 +289,13 @@ struct match_find_quantizable_ops } }; -// Note: scales are not constant b/c of dynamic quantization. // Checks for block quantized scales by checking scales are not scalar or 1D. -inline auto dynamic_block_dq(const std::string& scale) +inline auto block_dq(const std::string& scale) { // clang-format off return match::name("dequantizelinear")( match::nargs(2), match::arg(1)(match::skip_broadcasts(match::none_of( - match::is_constant(), match::scalar_shape, match::ndim(1) ).bind(scale)))); @@ -305,8 +312,8 @@ struct match_find_mx_quantizable_ops { auto matcher() const { - auto dq1 = match::arg(0)(skip_post_dq_ops(dynamic_block_dq("scale1").bind("dq1"))); - auto dq2 = match::arg(1)(skip_post_dq_ops(dynamic_block_dq("scale2").bind("dq2"))); + auto dq1 = match::arg(0)(skip_post_dq_ops(block_dq("scale1").bind("dq1"))); + auto dq2 = match::arg(1)(skip_post_dq_ops(block_dq("scale2").bind("dq2"))); return match::name("dot")(dq1, dq2); } @@ -328,10 +335,16 @@ struct match_find_mx_quantizable_ops assert(dq1->get_shape().type() == migraphx::shape::float_type); is_fp16_model = true; } - qop_args.at(0) = propagate_quantized_ins(m, dq1, qop_args[0], is_fp16_model); - qop_args.at(1) = propagate_quantized_ins(m, dq2, qop_args[1], is_fp16_model); - qop_args.push_back(scale1); - qop_args.push_back(scale2); + auto qop_between_arg0 = get_between_ins(dq1, qop_args[0]); + qop_args.at(0) = + propagate_quantized_ins(m, dq1, dq1->inputs().front(), qop_between_arg0, is_fp16_model); + auto qop_between_arg1 = get_between_ins(dq2, qop_args[1]); + qop_args.at(1) = + propagate_quantized_ins(m, dq2, dq2->inputs().front(), qop_between_arg1, is_fp16_model); + qop_args.push_back( + propagate_quantized_ins(m, dq1, scale1, qop_between_arg0, is_fp16_model)); + qop_args.push_back( + propagate_quantized_ins(m, dq2, scale2, qop_between_arg1, is_fp16_model)); if(qop->name() == "convolution") { diff --git a/test/propagate_constant_test.cpp b/test/propagate_constant_test.cpp index da7511e3482..72011fd8d83 100644 --- a/test/propagate_constant_test.cpp +++ b/test/propagate_constant_test.cpp @@ -535,4 +535,31 @@ TEST_CASE(block_dequantize_int4) EXPECT(m1.sort() == m2.sort()); } +TEST_CASE(pack_unpack_fp4) +{ + migraphx::shape s1{migraphx::shape::float_type, {4}}; + migraphx::shape s2{migraphx::shape::fp4x2_type, {2}}; + migraphx::module m1; + { + const std::vector vec = {1.f, 0.f, 2.f, 0.f}; + auto l = m1.add_literal(migraphx::literal(s1, vec)); + auto pack = m1.add_instruction(migraphx::make_op("pack_fp4"), l); + auto unpack = m1.add_instruction(migraphx::make_op("unpack_fp4"), pack); + m1.add_return({unpack}); + } + + run_pass(m1); + + migraphx::module m2; + { + using migraphx::shape; + const std::vector vec = {0x2, 0x4}; + auto l = m2.add_literal(migraphx::literal(s2, vec.data())); + auto unpack = m2.add_instruction(migraphx::make_op("unpack_fp4"), l); + m2.add_return({unpack}); + } + + EXPECT(m1 == m2); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/simplify_qdq_test.cpp b/test/simplify_qdq_test.cpp index bfc3d61386b..8d548224762 100644 --- a/test/simplify_qdq_test.cpp +++ b/test/simplify_qdq_test.cpp @@ -1815,6 +1815,104 @@ TEST_CASE(fp4x2_quant_dot_even) EXPECT(m1 == m2); } +TEST_CASE(fp4x2_quant_dot_trans_b) +{ + migraphx::shape shape_packed_a{migraphx::shape::fp4x2_type, {1, 3, 6, 12}}; + migraphx::shape shape_packed_b{migraphx::shape::fp4x2_type, {1, 3, 8, 12}}; + migraphx::shape shape_scales_a{migraphx::shape::float_type, {1, 3, 6, 24}}; + migraphx::shape shape_scales_b{migraphx::shape::float_type, {1, 3, 8, 24}}; + + migraphx::module m1; + { + auto packed_a = m1.add_parameter("input", shape_packed_a); + auto packed_b = m1.add_parameter("weights", shape_packed_b); + auto scale_a = m1.add_parameter("scale_a", shape_scales_a); + auto scale_b = m1.add_parameter("scale_b", shape_scales_b); + + auto unpack_a = + m1.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_a); + auto unpack_b = + m1.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_b); + auto dq_a = m1.add_instruction(migraphx::make_op("dequantizelinear"), unpack_a, scale_a); + auto dq_b = m1.add_instruction(migraphx::make_op("dequantizelinear"), unpack_b, scale_b); + auto trans_b = m1.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), dq_b); + auto dot = m1.add_instruction(migraphx::make_op("dot"), dq_a, trans_b); + m1.add_return({dot}); + } + + migraphx::module m2; + { + auto packed_a = m2.add_parameter("input", shape_packed_a); + auto packed_b = m2.add_parameter("weights", shape_packed_b); + auto scale_a = m2.add_parameter("scale_a", shape_scales_a); + auto scale_b = m2.add_parameter("scale_b", shape_scales_b); + + auto unpack_a = + m2.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_a); + auto unpack_b = + m2.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_b); + auto trans_b = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), unpack_b); + auto trans_scale_b = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), scale_b); + auto quant_dot = m2.add_instruction( + migraphx::make_op("quant_dot"), unpack_a, trans_b, scale_a, trans_scale_b); + m2.add_return({quant_dot}); + } + + run_pass(m1); + EXPECT(m1 == m2); +} + +TEST_CASE(fp4x2_quant_dot_const_b) +{ + migraphx::shape shape_packed_a{migraphx::shape::fp4x2_type, {1, 3, 6, 12}}; + migraphx::shape shape_packed_b{migraphx::shape::fp4x2_type, {1, 3, 24, 4}}; + migraphx::shape shape_packed_b_gen{migraphx::shape::uint8_type, {1, 3, 24, 4}}; + migraphx::shape shape_scales_a{migraphx::shape::float_type, {1, 3, 6, 24}}; + migraphx::shape shape_scales_b{migraphx::shape::float_type, {1, 3, 24, 8}}; + unsigned long seed = 826; + migraphx::literal b_lit = generate_literal(shape_packed_b_gen, seed); + migraphx::literal scale_b_lit = generate_literal(shape_scales_b, seed); + migraphx::module m1; + { + auto packed_a = m1.add_parameter("input", shape_packed_a); + // avoiding visit fp4x2_type + auto packed_b = m1.add_literal(shape_packed_b, b_lit.data()); + auto scale_a = m1.add_parameter("scale_a", shape_scales_a); + auto scale_b = m1.add_literal(scale_b_lit); + + auto unpack_a = + m1.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_a); + auto unpack_b = + m1.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_b); + auto dq_a = m1.add_instruction(migraphx::make_op("dequantizelinear"), unpack_a, scale_a); + auto dq_b = m1.add_instruction(migraphx::make_op("dequantizelinear"), unpack_b, scale_b); + auto dot = m1.add_instruction(migraphx::make_op("dot"), dq_a, dq_b); + m1.add_return({dot}); + } + + migraphx::module m2; + { + auto packed_a = m2.add_parameter("input", shape_packed_a); + auto packed_b = m2.add_literal(shape_packed_b, b_lit.data()); + auto scale_a = m2.add_parameter("scale_a", shape_scales_a); + auto scale_b = m2.add_literal(scale_b_lit); + + auto unpack_a = + m2.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_a); + auto unpack_b = + m2.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_b); + auto quant_dot = m2.add_instruction( + migraphx::make_op("quant_dot"), unpack_a, unpack_b, scale_a, scale_b); + m2.add_return({quant_dot}); + } + + run_pass(m1); + EXPECT(m1 == m2); +} + // Test that unused qdq with pack_fp4, unpack_fp4 are removed TEST_CASE(fp4x2_even_remove_qdq) { From 53a7447fe8957c7dd254aec921a03c6b70028cd4 Mon Sep 17 00:00:00 2001 From: Charlie Lin Date: Wed, 1 Oct 2025 19:47:44 -0400 Subject: [PATCH 07/13] Brevitas MXFP4 quantization parse (#4301) --- src/onnx/onnx_parser.cpp | 1 + src/onnx/parse_dynamicscale.cpp | 139 ++++++++++++++++++ src/onnx/parse_quantizelinear.cpp | 42 +++++- test/onnx/dynamicscale_even_test.onnx | 21 +++ test/onnx/dynamicscale_odd_test.onnx | Bin 0 -> 280 bytes test/onnx/dynamicscale_small_test.onnx | 17 +++ test/onnx/gen_onnx.py | 81 ++++++++++ test/onnx/parse/dynamicscale_test.cpp | 85 +++++++++++ .../parse/quantizelinear_mx_type_test.cpp | 78 ++++++++++ test/onnx/quantizelinear_mxfp4_even_test.onnx | 26 ++++ test/onnx/quantizelinear_mxfp4_odd_test.onnx | 26 ++++ test/onnx/verify/dynamicscale_small_test.cpp | 62 ++++++++ 12 files changed, 577 insertions(+), 1 deletion(-) create mode 100644 src/onnx/parse_dynamicscale.cpp create mode 100644 test/onnx/dynamicscale_even_test.onnx create mode 100644 test/onnx/dynamicscale_odd_test.onnx create mode 100644 test/onnx/dynamicscale_small_test.onnx create mode 100644 test/onnx/parse/dynamicscale_test.cpp create mode 100644 test/onnx/parse/quantizelinear_mx_type_test.cpp create mode 100644 test/onnx/quantizelinear_mxfp4_even_test.onnx create mode 100644 test/onnx/quantizelinear_mxfp4_odd_test.onnx create mode 100644 test/onnx/verify/dynamicscale_small_test.cpp diff --git a/src/onnx/onnx_parser.cpp b/src/onnx/onnx_parser.cpp index b81abe4839c..11c5b0bb307 100644 --- a/src/onnx/onnx_parser.cpp +++ b/src/onnx/onnx_parser.cpp @@ -831,6 +831,7 @@ shape::type_t get_type(int dtype) case 18: return shape::fp8e4m3fnuz_type; case 21: return shape::uint8_type; case 22: return shape::int8_type; + case 23: return shape::fp4x2_type; case 14: case 15: case 16: return shape::bf16_type; diff --git a/src/onnx/parse_dynamicscale.cpp b/src/onnx/parse_dynamicscale.cpp new file mode 100644 index 00000000000..a68b27b173f --- /dev/null +++ b/src/onnx/parse_dynamicscale.cpp @@ -0,0 +1,139 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +/** + * Operator from Brevitas to calculate dynamic quantization scales. + */ +struct parse_dynamicscale : op_parser +{ + + std::vector operators() const { return {{"DynamicScale"}}; }; + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& /*parser*/, + onnx_parser::node_info info, + const std::vector& args) const + { + const instruction_ref input = args.front(); + instruction_ref tmp_in = input; + const auto input_lens = input->get_shape().lens(); + if(args.size() != 1) + { + MIGRAPHX_THROW("DynamicScale: must have only 1 input"); + } + int block_axis = info.attributes.at("group_dim").i(); + block_axis = tune_axis(input->get_shape().ndim(), block_axis, "DynamicScale"); + int block_size = info.attributes.at("group_size").i(); + if(block_size != 32) + { + MIGRAPHX_THROW("DynamicScale: only group_size of 32 is supported"); + } + migraphx::shape::type_t output_type = get_type(info.attributes.at("output_dtype").i()); + + // TODO expand this to handle other MX types + if(output_type != migraphx::shape::fp4x2_type) + { + MIGRAPHX_THROW("DynamicScale: only support MXFP4 type"); + } + + std::string scale_selection_method = info.attributes.at("scale_selection_method").s(); + if(scale_selection_method != "floor") + { + MIGRAPHX_THROW("DynamicScale: only support floor scale selection"); + } + + std::string zero_point_selection_method = "None"; + if(contains(info.attributes, "zero_point_selection_method")) + zero_point_selection_method = info.attributes.at("zero_point_selection_method").s(); + + if(zero_point_selection_method != "None") + { + MIGRAPHX_THROW("DynamicScale: zero_point not supported"); + } + + // make reduction axes for calculating block scales + // tmp_lens != input_lens if runt block is padded + auto tmp_lens = input_lens; + auto block_dim = tmp_lens.at(block_axis); + std::size_t block_padding = + std::ceil(double(block_dim) / double(block_size)) * block_size - block_dim; + // handle runt block by padding + if(block_padding != 0) + { + std::vector pads_vec(2 * tmp_lens.size(), 0); + pads_vec.at(block_axis + tmp_lens.size()) = block_padding; + tmp_in = info.add_instruction(make_op("pad", {{"pads", pads_vec}}), tmp_in); + tmp_lens = tmp_in->get_shape().lens(); + } + // reshape block dimension to {num_blocks, block_size} + std::size_t num_blocks = tmp_lens.at(block_axis) / std::size_t(block_size); + std::vector reduct_dims = tmp_lens; + reduct_dims.at(block_axis) = block_size; + reduct_dims.insert(reduct_dims.begin() + block_axis, num_blocks); + instruction_ref reshape_ins = + info.add_instruction(make_op("reshape", {{"dims", reduct_dims}}), tmp_in); + + // dynamic quantization for MX types: + // V_k = fp32 vector input of block size k + // B_k = pow(2, floor(log2(reduce_max(abs(V_k))))) # largest power of 2 less than V + // X_k = block scale k = B_k / (largest power of 2 in fp4e2m1) = B_k / 4 + auto abs_ins = info.add_instruction(make_op("abs"), reshape_ins); + auto reduce_max_ins = + info.add_instruction(make_op("reduce_max", {{"axes", {block_axis + 1}}}), abs_ins); + auto log2_ins = info.add_instruction(make_op("log2"), reduce_max_ins); + auto floor_ins = info.add_instruction(make_op("floor"), log2_ins); + auto lit_2_ins = info.add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {2.f}}); + auto broadcast_lit_2_ins = info.add_instruction( + make_op("multibroadcast", {{"out_lens", reduce_max_ins->get_shape().lens()}}), + lit_2_ins); + auto pow_ins = info.add_instruction(make_op("pow"), broadcast_lit_2_ins, floor_ins); + auto lit_4_ins = info.add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {4.f}}); + auto broadcast_lit_4_ins = info.add_instruction( + make_op("multibroadcast", {{"out_lens", reduce_max_ins->get_shape().lens()}}), + lit_4_ins); + auto block_scales_ins = info.add_instruction(make_op("div"), pow_ins, broadcast_lit_4_ins); + + // squeeze reduction axis for use in block quantized quantizelinear + block_scales_ins = info.add_instruction(make_op("squeeze", {{"axes", {block_axis + 1}}}), + block_scales_ins); + + return block_scales_ins; + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_quantizelinear.cpp b/src/onnx/parse_quantizelinear.cpp index 92a773ae63d..03a6395308f 100644 --- a/src/onnx/parse_quantizelinear.cpp +++ b/src/onnx/parse_quantizelinear.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -91,6 +91,46 @@ struct parse_quantizelinear : op_parser args = transform_quantize_dequantize_linear_inputs( info, opd.onnx_name, block_size, axis, args); + if(output_type == migraphx::shape::fp4x2_type) + { + // Parsing in pack_fp4 and unpack_fp4 for the FP4 case + auto q_ins = info.add_instruction( + make_op("quantizelinear", {{"out_type", migraphx::shape::float_type}}), args); + + // packing axis set to fastest dimension + auto quantized_shape = q_ins->get_shape(); + const auto& qs_strides = quantized_shape.strides(); + if(qs_strides.empty()) + { + MIGRAPHX_THROW("QuantizeLinear: MX type quantized_shape has no strides"); + } + int fast_axis = + std::min_element(qs_strides.cbegin(), qs_strides.cend()) - qs_strides.cbegin(); + bool odd_fast_axis = (quantized_shape.lens().at(fast_axis) % 2 == 1); + if(odd_fast_axis) + { + // pad fastest dimension by 1 if it is odd + std::vector padding(2 * quantized_shape.ndim(), 0); + padding.at(fast_axis * 2 + 1) = 1; + q_ins = info.add_instruction(make_op("pad", {{"pads", padding}}), q_ins); + } + auto pack_ins = info.add_instruction(make_op("pack_fp4", {{"axis", fast_axis}}), + q_ins); // output is fp4x2_type + auto unpack_ins = info.add_instruction(make_op("unpack_fp4", {{"axis", fast_axis}}), + pack_ins); // output is fp8e4m3fn_type + if(odd_fast_axis) + { + // slice off padded values + unpack_ins = info.add_instruction( + make_op("slice", + {{"axes", {fast_axis}}, + {"starts", {0}}, + {"ends", {quantized_shape.lens().at(fast_axis)}}}), + unpack_ins); + } + return unpack_ins; + } + if(parser.opset_version < 19) { auto common_type = common_shape({args[0]->get_shape(), args[1]->get_shape()}).type(); diff --git a/test/onnx/dynamicscale_even_test.onnx b/test/onnx/dynamicscale_even_test.onnx new file mode 100644 index 00000000000..37ade59662c --- /dev/null +++ b/test/onnx/dynamicscale_even_test.onnx @@ -0,0 +1,21 @@ + dynamicscale_even_test: + +inputoutput" DynamicScale* + group_dim* + +group_size * + output_dtype*" +scale_selection_method"floor*& +zero_point_selection_method"Nonedynamicscale_even_testZ +input + + +@ + +b +output + + +@ + +B \ No newline at end of file diff --git a/test/onnx/dynamicscale_odd_test.onnx b/test/onnx/dynamicscale_odd_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..5a4bfcc7a0d3e8bb2c50138716e5e1caead2b988 GIT binary patch literal 280 zcmZ{fu?~Vj5JWu`gBuDwqmpQ0C@lB{OJibZV`pT|mO{`)2z + +TEST_CASE(dynamicscale_even_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = + mm->add_parameter("input", migraphx::shape{migraphx::shape::float_type, {3, 64, 4, 4}}); + auto reduce_reshape = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 2, 32, 4, 4}}}), input); + auto abs_ins = mm->add_instruction(migraphx::make_op("abs"), reduce_reshape); + auto reduce_max_ins = + mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), abs_ins); + auto log2_ins = mm->add_instruction(migraphx::make_op("log2"), reduce_max_ins); + auto floor_ins = mm->add_instruction(migraphx::make_op("floor"), log2_ins); + auto lit_2_ins = mm->add_literal({migraphx::shape{migraphx::shape::float_type}, {2.f}}); + auto broadcast_lit_2 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", reduce_max_ins->get_shape().lens()}}), + lit_2_ins); + auto pow_ins = mm->add_instruction(migraphx::make_op("pow"), broadcast_lit_2, floor_ins); + auto lit_4_ins = mm->add_literal({migraphx::shape{migraphx::shape::float_type}, {4.f}}); + auto broadcast_lit_4 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", reduce_max_ins->get_shape().lens()}}), + lit_4_ins); + auto block_scales_ins = mm->add_instruction(migraphx::make_op("div"), pow_ins, broadcast_lit_4); + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), block_scales_ins); + + auto prog = optimize_onnx("dynamicscale_even_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(dynamicscale_odd_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = + mm->add_parameter("input", migraphx::shape{migraphx::shape::float_type, {71, 5, 5}}); + auto padded_input = + mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 0, 25, 0, 0}}}), input); + auto reduce_reshape = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 32, 5, 5}}}), padded_input); + auto abs_ins = mm->add_instruction(migraphx::make_op("abs"), reduce_reshape); + auto reduce_max_ins = + mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), abs_ins); + auto log2_ins = mm->add_instruction(migraphx::make_op("log2"), reduce_max_ins); + auto floor_ins = mm->add_instruction(migraphx::make_op("floor"), log2_ins); + auto lit_2_ins = mm->add_literal({migraphx::shape{migraphx::shape::float_type}, {2.f}}); + auto broadcast_lit_2 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", reduce_max_ins->get_shape().lens()}}), + lit_2_ins); + auto pow_ins = mm->add_instruction(migraphx::make_op("pow"), broadcast_lit_2, floor_ins); + auto lit_4_ins = mm->add_literal({migraphx::shape{migraphx::shape::float_type}, {4.f}}); + auto broadcast_lit_4 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", reduce_max_ins->get_shape().lens()}}), + lit_4_ins); + auto block_scales_ins = mm->add_instruction(migraphx::make_op("div"), pow_ins, broadcast_lit_4); + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), block_scales_ins); + + auto prog = optimize_onnx("dynamicscale_odd_test.onnx"); + EXPECT(p == prog); +} diff --git a/test/onnx/parse/quantizelinear_mx_type_test.cpp b/test/onnx/parse/quantizelinear_mx_type_test.cpp new file mode 100644 index 00000000000..4beb1395533 --- /dev/null +++ b/test/onnx/parse/quantizelinear_mx_type_test.cpp @@ -0,0 +1,78 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include + +// even fastest dimension +TEST_CASE(quantizelinear_mxfp4_even_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {3, 64, 4, 4}}); + auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {3, 2, 4, 4}}); + auto l1_reshape = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), l1); + l1_reshape = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {3, 2, 32, 4, 4}}}), l1_reshape); + l1_reshape = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 64, 4, 4}}}), l1_reshape); + auto q_ins = mm->add_instruction( + migraphx::make_op("quantizelinear", {{"out_type", migraphx::shape::float_type}}), + l0, + l1_reshape); + auto pack_ins = mm->add_instruction(migraphx::make_op("pack_fp4", {{"axis", 3}}), q_ins); + auto unpack_ins = mm->add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), pack_ins); + mm->add_return({unpack_ins}); + + auto prog = read_onnx("quantizelinear_mxfp4_even_test.onnx"); + EXPECT(p.sort() == prog.sort()); +} + +// odd fastest dimension +TEST_CASE(quantizelinear_mxfp4_odd_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {3, 64, 4, 7}}); + auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {3, 2, 4, 7}}); + auto l1_reshape = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), l1); + l1_reshape = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {3, 2, 32, 4, 7}}}), l1_reshape); + l1_reshape = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 64, 4, 7}}}), l1_reshape); + auto q_ins = mm->add_instruction( + migraphx::make_op("quantizelinear", {{"out_type", migraphx::shape::float_type}}), + l0, + l1_reshape); + auto pad_ins = + mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 0, 0, 0, 0, 0, 1}}}), q_ins); + auto pack_ins = mm->add_instruction(migraphx::make_op("pack_fp4", {{"axis", 3}}), pad_ins); + auto unpack_ins = mm->add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), pack_ins); + auto slice_ins = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {7}}}), unpack_ins); + mm->add_return({slice_ins}); + + auto prog = read_onnx("quantizelinear_mxfp4_odd_test.onnx"); + EXPECT(p.sort() == prog.sort()); +} diff --git a/test/onnx/quantizelinear_mxfp4_even_test.onnx b/test/onnx/quantizelinear_mxfp4_even_test.onnx new file mode 100644 index 00000000000..f2c2fb01b34 --- /dev/null +++ b/test/onnx/quantizelinear_mxfp4_even_test.onnx @@ -0,0 +1,26 @@ + quantizelinear_mxfp4_even_test: +P +0 +1out"QuantizeLinear* +axis* + +block_size * + output_dtypequantizelinear_mxfp4_even_testZ +0 + + +@ + +Z +1 + + + + +b +out + + +@ + +B \ No newline at end of file diff --git a/test/onnx/quantizelinear_mxfp4_odd_test.onnx b/test/onnx/quantizelinear_mxfp4_odd_test.onnx new file mode 100644 index 00000000000..52baba83ed8 --- /dev/null +++ b/test/onnx/quantizelinear_mxfp4_odd_test.onnx @@ -0,0 +1,26 @@ + quantizelinear_mxfp4_odd_test: +P +0 +1out"QuantizeLinear* +axis* + +block_size * + output_dtypequantizelinear_mxfp4_odd_testZ +0 + + +@ + +Z +1 + + + + +b +out + + +@ + +B \ No newline at end of file diff --git a/test/onnx/verify/dynamicscale_small_test.cpp b/test/onnx/verify/dynamicscale_small_test.cpp new file mode 100644 index 00000000000..70b5cc26e7e --- /dev/null +++ b/test/onnx/verify/dynamicscale_small_test.cpp @@ -0,0 +1,62 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include + +TEST_CASE(dynamicscale_small_test) +{ + migraphx::program p = read_onnx("dynamicscale_small_test.onnx"); + p.compile(migraphx::make_target("ref")); + std::vector input_lens{4, 4}; + auto input_type = migraphx::shape::float_type; + migraphx::shape data_shape{input_type, input_lens}; + std::vector data = {-100.f, + -12.f, + 32.f, + 819.f, + -6.f, + -5.75f, + -5.50f, + -5.25f, + -5.f, + -0.30f, + -1.40f, + -1.20f, + 2.0f, + 0.25f, + 0.33f, + 2.20f}; + migraphx::parameter_map pp; + pp["input"] = migraphx::argument(data_shape, data.data()); + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + // hand calculated values + std::vector gold = {128.f, 1.f, 1.f, 0.5f}; + + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} From b92bf17d3b66bb94caf9cae4f98962197596788e Mon Sep 17 00:00:00 2001 From: Chris Austen Date: Wed, 1 Oct 2025 22:41:02 -0400 Subject: [PATCH 08/13] Bump onnx to 1.18.0 (#4323) --- test/py/onnx_backend_test.py | 22 ++++++++++++++++++++++ test/py/requirements-onnx.txt | 2 +- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/test/py/onnx_backend_test.py b/test/py/onnx_backend_test.py index 2b395c55154..f0efd96e3d8 100644 --- a/test/py/onnx_backend_test.py +++ b/test/py/onnx_backend_test.py @@ -629,6 +629,25 @@ def disabled_tests_onnx_1_17_0(backend_test): backend_test.exclude(r'test_resize_upsample_sizes_nearest_not_smaller_cpu') +def disabled_tests_onnx_1_18_0(backend_test): + # src/onnx/onnx_parser.cpp:841: get_type: Prototensor data type 23 not supported + backend_test.exclude(r'test_cast_FLOAT16_to_FLOAT4E2M1_cpu') + backend_test.exclude(r'test_cast_FLOAT4E2M1_to_FLOAT16_cpu') + backend_test.exclude(r'test_cast_FLOAT4E2M1_to_FLOAT_cpu') + backend_test.exclude(r'test_cast_FLOAT_to_FLOAT4E2M1_cpu') + backend_test.exclude(r'test_dequantizelinear_float4e2m1_cpu') + backend_test.exclude(r'test_quantizelinear_float4e2m1_cpu') + # src/onnx/checks.cpp:35: check_arg_empty: PARSE_TopK: k input must be constant + backend_test.exclude(r'test_top_k_same_values_2d_cpu') + backend_test.exclude(r'test_top_k_same_values_cpu') + backend_test.exclude(r'test_top_k_same_values_largest_cpu') + backend_test.exclude(r'test_top_k_uint64_cpu') + #src/shape.cpp:367: lens: SHAPE: lens() called on a dynamic shape + backend_test.exclude(r'test_unique_length_1_cpu') + # + backend_test.exclude(r'test_averagepool_2d_ceil_last_window_starts_on_pad_cpu') + + def disabled_tests_int4(backend_test): # quantizelinear backend_test.exclude(r'test_quantizelinear_int4') @@ -1209,6 +1228,9 @@ def create_backend_test(testname=None, target_device=None): if version.parse(onnx.__version__) >= version.parse("1.17.0"): disabled_tests_onnx_1_17_0(backend_test) + if version.parse(onnx.__version__) >= version.parse("1.18.0"): + disabled_tests_onnx_1_18_0(backend_test) + # import all test cases at global scope to make # them visible to python.unittest. diff --git a/test/py/requirements-onnx.txt b/test/py/requirements-onnx.txt index d7cfa26772c..566c3db1adc 100644 --- a/test/py/requirements-onnx.txt +++ b/test/py/requirements-onnx.txt @@ -21,7 +21,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. ##################################################################################### -onnx==1.17.0;python_version>="3.11" +onnx==1.18.0;python_version>="3.11" onnx==1.14.1;python_version<"3.11" protobuf==4.25.8 numpy==1.26.4;python_version>="3.11" From b3b6dfdc2a6d1ea5f3c15ebe59d938ec41b9363c Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Thu, 2 Oct 2025 09:16:44 -0400 Subject: [PATCH 09/13] Fuse GEMM+GEMM with rocMLIR (#4261) rocmlir supports GEMM+GEMM fusion, so MIGraphX needs to support the fusion on our side. Solves #4230 --- Jenkinsfile | 2 +- docs/reference/MIGraphX-dev-env-vars.rst | 8 + src/include/migraphx/module.hpp | 1 + src/module.cpp | 25 + src/targets/gpu/fuse_mlir.cpp | 156 +++++ test/gpu/fuse_mlir.cpp | 842 +++++++++++++++++++++++ test/module_test.cpp | 46 ++ test/verify/CMakeLists.txt | 6 +- test/verify/test_conv_add_dot.cpp | 65 ++ test/verify/test_dot_add_dot.cpp | 51 ++ 10 files changed, 1200 insertions(+), 2 deletions(-) create mode 100644 test/verify/test_conv_add_dot.cpp create mode 100644 test/verify/test_dot_add_dot.cpp diff --git a/Jenkinsfile b/Jenkinsfile index cfa98d43a4d..6dd7994ed89 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -199,7 +199,7 @@ rocmtest clang_debug: rocmnode('mi200+') { cmake_build -> } }, mlir_debug: rocmnode('mi100+') { cmake_build -> stage('MLIR Debug') { - withEnv(['MIGRAPHX_ENABLE_EXTRA_MLIR=1', 'MIGRAPHX_MLIR_USE_SPECIFIC_OPS=fused,attention,convolution,dot,convolution_backwards', 'MIGRAPHX_ENABLE_MLIR_INPUT_FUSION=1', 'MIGRAPHX_MLIR_ENABLE_SPLITK=1', 'MIGRAPHX_ENABLE_MLIR_REDUCE_FUSION=1', 'MIGRAPHX_ENABLE_SPLIT_REDUCE=1','MIGRAPHX_DISABLE_LAYERNORM_FUSION=1']) { + withEnv(['MIGRAPHX_ENABLE_EXTRA_MLIR=1', 'MIGRAPHX_MLIR_USE_SPECIFIC_OPS=fused,attention,convolution,dot,convolution_backwards', 'MIGRAPHX_ENABLE_MLIR_INPUT_FUSION=1', 'MIGRAPHX_MLIR_ENABLE_SPLITK=1', 'MIGRAPHX_ENABLE_MLIR_REDUCE_FUSION=1', 'MIGRAPHX_ENABLE_MLIR_GEG_FUSION=1', 'MIGRAPHX_ENABLE_SPLIT_REDUCE=1','MIGRAPHX_DISABLE_LAYERNORM_FUSION=1']) { def sanitizers = "undefined" // Note: the -fno-sanitize= is copied from upstream LLVM_UBSAN_FLAGS. def debug_flags = "-g -O2 -fsanitize=${sanitizers} -fno-sanitize=vptr,function -fno-sanitize-recover=${sanitizers}" diff --git a/docs/reference/MIGraphX-dev-env-vars.rst b/docs/reference/MIGraphX-dev-env-vars.rst index 266ec1fbe92..f9c9f309ab3 100644 --- a/docs/reference/MIGraphX-dev-env-vars.rst +++ b/docs/reference/MIGraphX-dev-env-vars.rst @@ -144,6 +144,14 @@ Model performance tunable variables change the compilation behavior of a model. | Default: Reduction fusions are turned off. + * - | ``MIGRAPHX_ENABLE_MLIR_GEG_FUSION`` + | Turns on GEMM+GEMM fusions in MLIR. + + - | ``1``: Turns on G+G fusions. + | ``0``: Returns to default behavior. + + | Default: GEMM+GEMM fusions are turned off. + * - | ``MIGRAPHX_MLIR_ENABLE_SPLITK`` | Turns on Split-k performance configurations during MLIR tuning. diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index ccf67de0d85..ec109d98b6d 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -339,6 +339,7 @@ struct MIGRAPHX_EXPORT module ins_dep_map calc_implicit_deps() const; void repeat_while_changes(std::size_t n, const std::function& f); + void localized_sort(instruction_ref start_ins, instruction_ref end_ins); MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const module& m); MIGRAPHX_EXPORT friend bool operator==(const module& x, const module& y); diff --git a/src/module.cpp b/src/module.cpp index 50479933940..0b94a36c5a0 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -1600,6 +1600,31 @@ void module::repeat_while_changes(std::size_t n, const std::function& f) } } +// For topologically sorting a region in a module, canonically, such that the +// dependent chain between the two input instructions is last +void module::localized_sort(instruction_ref start_ins, instruction_ref end_ins) +{ + // get the chain of instructions between start_ins and end_ins, inclusive + auto fusion_ins = find_instructions_between(start_ins, end_ins, this); + + // move all instructions between start_ins & end_ins that are not in the fusion chain + // to the start_ins. In order, moving to the same destination, this will naturally preserve + // the preexisting topological order of the module + for(auto it = std::next(start_ins); it != end_ins;) + { + if(fusion_ins.count(it) == 0) + { + auto next = std::next(it); // move_instruction updates the iterator + this->move_instruction(it, start_ins); + it = next; + } + else + { + ++it; + } + } +} + bool operator==(const module& x, const module& y) { return to_string(x) == to_string(y); } std::ostream& operator<<(std::ostream& os, const module& m) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index b036fb4cb0e..83d30c50bba 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -48,6 +48,7 @@ namespace gpu { MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_EXTRA_MLIR); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_REDUCE_FUSION); +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_GEG_FUSION); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MLIR); /** * @brief Declares a new MIGraphX environment variable which forces to generate @@ -779,6 +780,152 @@ struct find_mlir_fused_ops } }; +/** + * Fuses rocMLIR conv/dot -> pointwise -> dot chain + * into a mlir_op with submodule. + */ +struct find_mlir_fused_geg_ops +{ + mlir_mode conv_mode = mlir_mode::none; + mlir_mode dot_mode = mlir_mode::none; + + /* + * Matches: + * mlir_dot_or_conv -> + * pointwise -> + * dot + */ + auto matcher() const + { + auto first_dot_or_conv = match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode)) + .bind("first_gemm_based_op"); + auto elemwise = + mlir_pointwise()(match::any_of[match::inputs()](first_dot_or_conv)).bind("elemwise"); + return is_mlir_dot(dot_mode)(match::any_of[match::inputs()](elemwise)) + .bind("second_gemm_op"); + } + + void apply(module_pass_manager& mpm, const match::matcher_result& r) const + { + auto second_gemm_ins = r.result; + auto elemwise_ins = r.instructions["elemwise"]; + auto first_gemm_ins = r.instructions["first_gemm_based_op"]; + + auto* elemwise_module = elemwise_ins->module_inputs().front(); + auto elemwise_inputs = elemwise_ins->inputs(); + + // only one input to elemwise should depend on first_gemm + if(std::any_of(elemwise_inputs.begin(), elemwise_inputs.end(), [&](const auto& i) { + return i != first_gemm_ins and reaches(first_gemm_ins, i); + })) + return; + + // only one input to second_gemm should depend on elemwise + auto second_gemm_inputs = second_gemm_ins->inputs(); + if(std::any_of(second_gemm_inputs.begin(), second_gemm_inputs.end(), [&](const auto& i) { + return i != elemwise_ins and reaches(elemwise_ins, i); + })) + return; + + std::unordered_map map_ins; + module_ref mm = + mpm.create_module("mlir_" + elemwise_ins->module_inputs().front()->name() + "_geg"); + mm->set_bypass(); + fuse_input_ops(mm, first_gemm_ins->inputs(), &map_ins); + + // need to track multi-user scenarios for both intermediates + bool first_gemm_has_multi_outs = first_gemm_ins->outputs().size() > 1; + bool elemwise_has_multi_outs = elemwise_ins->outputs().size() > 1; + + // add the first gemm to the module + std::vector first_gemm_mapped_inputs; + first_gemm_mapped_inputs.reserve(first_gemm_ins->inputs().size()); + std::transform(first_gemm_ins->inputs().begin(), + first_gemm_ins->inputs().end(), + std::back_inserter(first_gemm_mapped_inputs), + [&](auto input) { return map_ins.at(input); }); + auto first_gemm_in_module = + mm->add_instruction(first_gemm_ins->get_operator(), first_gemm_mapped_inputs); + map_ins[first_gemm_ins] = first_gemm_in_module; + + // fuse external inputs for the elemwise operation + fuse_input_ops(mm, elemwise_inputs, &map_ins); + + // fuse elemwise submodule + auto elemwise_rins = + mm->fuse(*elemwise_module, elemwise_inputs, &map_ins, &insert_pointwise); + assert(elemwise_rins.size() == 1); + map_ins[elemwise_ins] = elemwise_rins.front(); + + // fuse external inputs for the second gemm + fuse_input_ops(mm, second_gemm_inputs, &map_ins); + + // add the second gemm to the new module + std::vector second_gemm_mapped_inputs; + second_gemm_mapped_inputs.reserve(second_gemm_inputs.size()); + std::transform(second_gemm_inputs.begin(), + second_gemm_inputs.end(), + std::back_inserter(second_gemm_mapped_inputs), + [&](auto input) { return map_ins.at(input); }); + auto second_gemm_in_module = + mm->add_instruction(second_gemm_ins->get_operator(), second_gemm_mapped_inputs); + map_ins[second_gemm_ins] = second_gemm_in_module; + + // primary output is the last gemm, which should be the first output + std::vector return_vals; + return_vals.push_back(second_gemm_in_module); + + if(elemwise_has_multi_outs) + { + return_vals.push_back(map_ins[elemwise_ins]); + } + if(first_gemm_has_multi_outs) + { + return_vals.push_back(map_ins[first_gemm_ins]); + } + mm->add_return(return_vals); + auto inputs = find_inputs(map_ins, &mpm.get_module(), mm); + + // sort fusion section of module such that any external inputs are moved before the fusion + // so that we can safely place the fused mod in the multi-out case at the beginning of the + // chain + mpm.get_module().localized_sort(first_gemm_ins, second_gemm_ins); + + auto fused_ins = + mpm.get_module().insert_instruction(first_gemm_ins, + mlir_op{second_gemm_ins->get_operator()}, + mlir_contiguous(mpm, inputs), + {mm}); + + if(first_gemm_has_multi_outs or elemwise_has_multi_outs) + { + std::size_t output_idx = 0; + if(elemwise_has_multi_outs) + { + auto elemwise_result = mpm.get_module().insert_instruction( + first_gemm_ins, + migraphx::make_op("get_tuple_elem", {{"index", ++output_idx}}), + fused_ins); + mpm.get_module().replace_instruction(elemwise_ins, elemwise_result); + } + if(first_gemm_has_multi_outs) + { + mpm.get_module().replace_instruction( + first_gemm_ins, + migraphx::make_op("get_tuple_elem", {{"index", ++output_idx}}), + fused_ins); + } + mpm.get_module().replace_instruction( + second_gemm_ins, migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused_ins); + } + else + { + // simple single output case + mpm.get_module().replace_instruction(second_gemm_ins, fused_ins); + } + } +}; + template struct find_mlir_standalone_op { @@ -1300,6 +1447,15 @@ void fuse_mlir::apply(module_pass_manager& mpm) const match::find_matches(mpm, find_mlir_attention_op{}); mpm.run_pass(dead_code_elimination{}); + if(enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + { + match::find_matches( + mpm, + find_mlir_fused_geg_ops{.conv_mode = get_mode("fused_convolution", mlir_mode::fast), + .dot_mode = get_mode("fused_dot", mlir_mode::fast)}); + mpm.run_pass(dead_code_elimination{}); + } + match::find_matches( mpm, find_mlir_fused_ops{.conv_mode = get_mode("fused_convolution", mlir_mode::fast), diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index 57c401b8359..5532da8832b 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -40,6 +40,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_REDUCE_FUSION); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_USE_SPECIFIC_OPS); +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_GEG_FUSION); struct non_mlir_op { @@ -1773,6 +1774,847 @@ TEST_CASE(unpack_fp4_dot_odd) EXPECT(p1.sort() == p2.sort()); } +TEST_CASE(dot_add_dot) +{ + migraphx::shape s1{migraphx::shape::float_type, {2, 3}}; + migraphx::shape s2{migraphx::shape::float_type, {3, 4}}; + migraphx::shape s3{migraphx::shape::float_type, {2, 4}}; + migraphx::shape s4{migraphx::shape::float_type, {4, 2}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s1); + auto b = mm->add_parameter("b", s2); + auto x = mm->add_parameter("x", s3); + auto y = mm->add_parameter("y", s4); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); // {2,4} + auto add = + add_pointwise(p1, "main:pointwise0", {dot1, x}, single_pointwise("add")); // {2,4} + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, y); // {4,2} + mm->add_return({dot2}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s1); + auto b = mm->add_parameter("b", s2); + auto x = mm->add_parameter("x", s3); + auto y = mm->add_parameter("y", s4); + auto fused = add_mlir( + p2, "mlir_main:pointwise0_geg", {a, b, x, y}, [=](auto* pm, const auto& inputs) { + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[2]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, inputs[3]); + return std::make_tuple(dot2->get_operator(), dot2); + }); + mm->add_return({fused}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_add_dot_square) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = add_pointwise(p1, "main:pointwise0", {dot1, x}, single_pointwise("add")); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, y); + mm->add_return({dot2}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto fused = add_mlir( + p2, "mlir_main:pointwise0_geg", {a, b, x, y}, [=](auto* pm, const auto& inputs) { + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[2]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, inputs[3]); + return std::make_tuple(dot2->get_operator(), dot2); + }); + mm->add_return({fused}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_mul_dot) +{ + migraphx::shape s1{migraphx::shape::float_type, {3, 3}}; + migraphx::shape s2{migraphx::shape::float_type, {3, 4}}; + migraphx::shape s3{migraphx::shape::float_type, {4, 5}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s1); + auto b = mm->add_parameter("b", s2); + auto x = mm->add_parameter("x", s2); + auto y = mm->add_parameter("y", s3); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto mul = add_pointwise(p1, "main:pointwise0", {dot1, x}, single_pointwise("mul")); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), mul, y); + mm->add_return({dot2}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s1); + auto b = mm->add_parameter("b", s2); + auto x = mm->add_parameter("x", s2); + auto y = mm->add_parameter("y", s3); + auto fused = add_mlir( + p2, "mlir_main:pointwise0_geg", {a, b, x, y}, [=](auto* pm, const auto& inputs) { + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto mul = pm->add_instruction(migraphx::make_op("mul"), dot1, inputs[2]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), mul, inputs[3]); + return std::make_tuple(dot2->get_operator(), dot2); + }); + mm->add_return({fused}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(conv_add) +{ + migraphx::shape is{migraphx::shape::float_type, {4, 14, 122, 122}}; + migraphx::shape ys{migraphx::shape::float_type, {4, 56, 122, 122}}; + migraphx::shape ws{migraphx::shape::float_type, {56, 14, 1, 1}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", is); + auto y = mm->add_parameter("y", ys); + auto w = mm->add_parameter("w", ws); + auto conv = mm->add_instruction(migraphx::make_op("convolution"), x, w); + auto add = add_pointwise(p1, "main:pointwise0", {conv, y}, single_pointwise("add")); + mm->add_return({add}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", is); + auto y = mm->add_parameter("y", ys); + auto w = mm->add_parameter("w", ws); + auto fused = add_mlir(p2, + "mlir_main:pointwise0", + {x, w, y}, + {"x0", "x1", "x2"}, + [=](auto* pm, const auto& inputs) { + auto c = pm->add_instruction( + migraphx::make_op("convolution"), inputs[0], inputs[1]); + auto add = + pm->add_instruction(migraphx::make_op("add"), c, inputs[2]); + return std::make_tuple(c->get_operator(), add); + }); + mm->add_return({fused}); + } + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(conv_add_dot) +{ + migraphx::shape is{migraphx::shape::float_type, {2, 4, 8, 8}}; + migraphx::shape ys{migraphx::shape::float_type, {2, 8, 8, 8}}; + migraphx::shape ws{migraphx::shape::float_type, {8, 4, 1, 1}}; + migraphx::shape zs{migraphx::shape::float_type, {2, 8, 8, 4}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", is); + auto y = mm->add_parameter("y", ys); + auto w = mm->add_parameter("w", ws); + auto z = mm->add_parameter("z", zs); + auto conv = mm->add_instruction(migraphx::make_op("convolution"), x, w); + auto add = add_pointwise(p1, "main:pointwise0", {conv, y}, single_pointwise("add")); + auto dot = mm->add_instruction(migraphx::make_op("dot"), add, z); + + mm->add_return({dot}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", is); + auto y = mm->add_parameter("y", ys); + auto w = mm->add_parameter("w", ws); + auto z = mm->add_parameter("z", zs); + auto fused = + add_mlir(p2, + "mlir_main:pointwise0_geg", + {x, w, y, z}, + {"x0", "x1", "x2", "x3"}, + [=](auto* pm, const auto& inputs) { + auto c = pm->add_instruction( + migraphx::make_op("convolution"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), c, inputs[2]); + auto dot = pm->add_instruction(migraphx::make_op("dot"), add, inputs[3]); + return std::make_tuple(dot->get_operator(), dot); + }); + mm->add_return({fused}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_multi_user_add) +// G ->optional R -> E fusion +// G has two users, one external to fusion +{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = add_pointwise(p1, "main:pointwise0", {dot1, c}, single_pointwise("add")); + auto transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot1); + mm->add_return({add, transpose}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto fused = + add_mlir(p2, "mlir_main:pointwise0", {a, b, c}, [=](auto* pm, const auto& inputs) { + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[2]); + return std::make_tuple(dot1->get_operator(), + std::vector{add, dot1}); + }); + auto get_add = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto get_dot = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto transpose = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot); + mm->add_return({get_add, transpose}); + } + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_add_multi_user_dot) +// GEG fusion has two outputs, E has external user +{ + migraphx::shape s1{migraphx::shape::float_type, {3, 3}}; + migraphx::shape s2{migraphx::shape::float_type, {3, 5}}; + migraphx::shape s3{migraphx::shape::float_type, {5, 2}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s1); + auto b = mm->add_parameter("b", s2); + auto c = mm->add_parameter("c", s2); + auto d = mm->add_parameter("d", s3); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = add_pointwise(p1, "main:pointwise0", {dot1, c}, single_pointwise("add")); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, d); + auto transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), dot2); + mm->add_return({add, transpose}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s1); + auto b = mm->add_parameter("b", s2); + auto c = mm->add_parameter("c", s2); + auto d = mm->add_parameter("d", s3); + auto fused = add_mlir( + p2, "mlir_main:pointwise0_geg", {a, b, c, d}, [=](auto* pm, const auto& inputs) { + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[2]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, inputs[3]); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, add}); + }); + auto get_dot2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto get_add = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto transpose = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {1, 0}}}), get_dot2); + mm->add_return({get_add, transpose}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_add_multi_user_dot_with_transpose) +// GEG fusion has two outputs, E has external user +{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = add_pointwise(p1, "main:pointwise0", {dot1, c}, single_pointwise("add")); + auto d_t = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), d); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, d_t); + auto transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot2); + mm->add_return({add, transpose}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto fused = add_mlir( + p2, "mlir_main:pointwise0_geg", {a, b, c, d}, [=](auto* pm, const auto& inputs) { + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[2]); + auto d_t = pm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), inputs[3]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, d_t); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, add}); + }); + auto get_dot2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto get_add = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto transpose = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); + mm->add_return({get_add, transpose}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_add_multi_user_dot_two_externals) +// GEG fusion has two outputs, E has external user +{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = add_pointwise(p1, "main:pointwise0", {dot1, c}, single_pointwise("add")); + auto external_t1 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), d); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, d); + auto external_t2 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot2); + mm->add_return({add, external_t1, external_t2}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto fused = add_mlir( + p2, "mlir_main:pointwise0_geg", {a, b, c, d}, [=](auto* pm, const auto& inputs) { + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[2]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, inputs[3]); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, add}); + }); + auto get_dot2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto get_add = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto external_t1 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), d); + auto external_t2 = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); + mm->add_return({get_add, external_t1, external_t2}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_add_multi_user_dot_input_used_before) +// GEG fusion has two outputs, E has external user. +// Base case for testing inputs being defined within the span +// of will-be-fused ops +// This also shows the relu being fused, since it is a unary op +{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + + auto external_relu = add_pointwise(p1, "main:pointwise1", {d}, single_pointwise("relu")); + + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = add_pointwise(p1, "main:pointwise0", {dot1, c}, single_pointwise("add")); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, external_relu); + auto transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot2); + mm->add_return({add, transpose}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto fused = add_mlir( + p2, + "main:pointwise1:mlir_main:pointwise0_geg", + {d, a, b, c}, + [=](auto* pm, const auto& inputs) { + auto external_relu = pm->add_instruction(migraphx::make_op("relu"), inputs[0]); + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[1], inputs[2]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[3]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, external_relu); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, add}); + }); + auto get_dot2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto get_add = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto external_t = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); + mm->add_return({get_add, external_t}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_add_multi_user_dot_input_used_after) +// GEG fusion has two outputs, E has external user +// Testing inputs being defined within the span of will-be-fused ops +// This also shows the relu being fused, since it is a unary op. +// Result should be, and is, equivalent to the previous test +{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = add_pointwise(p1, "main:pointwise0", {dot1, c}, single_pointwise("add")); + auto external_relu = add_pointwise(p1, "main:pointwise1", {d}, single_pointwise("relu")); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, external_relu); + auto transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot2); + mm->add_return({add, transpose}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto fused = add_mlir( + p2, + "main:pointwise1:mlir_main:pointwise0_geg", + {d, a, b, c}, + [=](auto* pm, const auto& inputs) { + auto external_relu = pm->add_instruction(migraphx::make_op("relu"), inputs[0]); + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[1], inputs[2]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[3]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, external_relu); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, add}); + }); + auto get_dot2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto get_add = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto external_t = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); + mm->add_return({get_add, external_t}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_add_multi_user_dot_input_used_before_in_chain) +// GEG fusion has two outputs, E has external user +// Base case for inputs being defined within the span of will-be-fused ops, including +// longer chain of logic, for both cases of input fusion. When enabled, +// the mul gets fused into the GEG fusion. +{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + + auto external_relu = add_pointwise(p1, "main:pointwise1", {d}, single_pointwise("relu")); + auto external_mul = + add_pointwise(p1, "main:pointwise2", {external_relu, d}, single_pointwise("mul")); + + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = add_pointwise(p1, "main:pointwise0", {dot1, c}, single_pointwise("add")); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, external_mul); + auto transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot2); + mm->add_return({add, transpose}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto external_relu = add_pointwise(p2, "main:pointwise1", {d}, single_pointwise("relu")); + auto external_mul = + add_pointwise(p2, "main:pointwise2", {external_relu, d}, single_pointwise("mul")); + auto fused = + add_mlir(p2, + "mlir_main:pointwise0_geg", + {a, b, c, external_mul}, + [=](auto* pm, const auto& inputs) { + auto dot1 = + pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[2]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, inputs[3]); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, add}); + }); + auto get_dot2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto get_add = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto external_t = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); + mm->add_return({get_add, external_t}); + } + migraphx::program p3; + { + auto* mm = p3.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto external_relu = add_pointwise(p3, "main:pointwise1", {d}, single_pointwise("relu")); + auto fused = add_mlir( + p3, + "main:pointwise2:mlir_main:pointwise0_geg", + {external_relu, d, a, b, c}, + [=](auto* pm, const auto& inputs) { + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[2], inputs[3]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[4]); + auto mul = pm->add_instruction(migraphx::make_op("mul"), inputs[0], inputs[1]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, mul); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, add}); + }); + auto get_dot2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto get_add = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto external_t = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); + mm->add_return({get_add, external_t}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + if(migraphx::enabled(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION{})) + EXPECT(p1.sort() == p3.sort()); + else + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_add_multi_user_dot_input_used_after_in_chain) +// GEG fusion has two outputs, E has external user +// Testing inputs being defined within the span of will-be-fused ops, including +// longer chain of logic +{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = add_pointwise(p1, "main:pointwise0", {dot1, c}, single_pointwise("add")); + auto external_relu = add_pointwise(p1, "main:pointwise1", {d}, single_pointwise("relu")); + auto external_mul = + add_pointwise(p1, "main:pointwise2", {external_relu, d}, single_pointwise("mul")); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, external_mul); + auto transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot2); + mm->add_return({add, transpose}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto external_relu = add_pointwise(p2, "main:pointwise1", {d}, single_pointwise("relu")); + auto external_mul = + add_pointwise(p2, "main:pointwise2", {external_relu, d}, single_pointwise("mul")); + auto fused = + add_mlir(p2, + "mlir_main:pointwise0_geg", + {a, b, c, external_mul}, + [=](auto* pm, const auto& inputs) { + auto dot1 = + pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[2]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, inputs[3]); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, add}); + }); + auto get_dot2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto get_add = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto external_t = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); + mm->add_return({get_add, external_t}); + } + migraphx::program p3; + { + auto* mm = p3.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto external_relu = add_pointwise(p3, "main:pointwise1", {d}, single_pointwise("relu")); + auto fused = add_mlir( + p3, + "main:pointwise2:mlir_main:pointwise0_geg", + {external_relu, d, a, b, c}, + [=](auto* pm, const auto& inputs) { + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[2], inputs[3]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[4]); + auto mul = pm->add_instruction(migraphx::make_op("mul"), inputs[0], inputs[1]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, mul); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, add}); + }); + auto get_dot2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto get_add = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto external_t = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); + mm->add_return({get_add, external_t}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + if(migraphx::enabled(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION{})) + EXPECT(p1.sort() == p3.sort()); + else + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_pw_multi_user_dot) +// GEG fusion has two outputs, E has external user, E is multiple elemwise ops +{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto e = mm->add_parameter("e", s); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto elemwise = + add_pointwise(p1, "main:pointwise0", {dot1, c, d}, [=](auto* pm, const auto& inputs) { + auto add = + pm->add_instruction(migraphx::make_op("add"), inputs.at(0), inputs.at(1)); + return pm->add_instruction(migraphx::make_op("mul"), add, inputs.at(2)); + }); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), elemwise, e); + auto transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot2); + mm->add_return({elemwise, transpose}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto e = mm->add_parameter("e", s); + auto fused = add_mlir( + p2, "mlir_main:pointwise0_geg", {a, b, c, d, e}, [=](auto* pm, const auto& inputs) { + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[2]); + auto mul = pm->add_instruction(migraphx::make_op("mul"), add, inputs[3]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), mul, inputs[4]); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, mul}); + }); + auto get_dot2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto get_mul = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto transpose = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); + mm->add_return({get_mul, transpose}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_multi_user_add_dot) +// GEG fusion has two outputs (first G has external user) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = add_pointwise(p1, "main:pointwise0", {dot1, c}, single_pointwise("add")); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, d); + auto transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot1); + mm->add_return({dot2, transpose}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto fused = add_mlir( + p2, "mlir_main:pointwise0_geg", {a, b, c, d}, [=](auto* pm, const auto& inputs) { + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[2]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, inputs[3]); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, dot1}); + }); + auto get_dot1 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto get_dot2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto transpose = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot1); + mm->add_return({get_dot2, transpose}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_add_dot_both_multi_user) +// GEG fusion has three outputs (first G has external user, E has external user) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = add_pointwise(p1, "main:pointwise0", {dot1, c}, single_pointwise("add")); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, d); + auto transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot1); + mm->add_return({add, dot2, transpose}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto fused = add_mlir( + p2, "mlir_main:pointwise0_geg", {a, b, c, d}, [=](auto* pm, const auto& inputs) { + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[2]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, inputs[3]); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, add, dot1}); + }); + auto get_dot1 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), fused); + auto get_elemwise = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto get_dot2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto transpose = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot1); + mm->add_return({get_elemwise, get_dot2, transpose}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + int main(int argc, const char* argv[]) { if(migraphx::gpu::mlir_enabled()) diff --git a/test/module_test.cpp b/test/module_test.cpp index ac3079033de..270bfcd63e4 100644 --- a/test/module_test.cpp +++ b/test/module_test.cpp @@ -1369,4 +1369,50 @@ TEST_CASE(pathological_dfs_graph_sort) EXPECT(is_sorted(m)); } +TEST_CASE(localized_sort) +{ + migraphx::shape s1{migraphx::shape::float_type, {2, 3}}; + migraphx::shape s2{migraphx::shape::float_type, {3, 4}}; + migraphx::shape s3{migraphx::shape::float_type, {2, 4}}; + migraphx::shape s4{migraphx::shape::float_type, {4, 2}}; + migraphx::program p; + auto* mm = p.get_main_module(); + + auto a = mm->add_parameter("a", s1); + auto b = mm->add_parameter("b", s2); + auto c = mm->add_parameter("c", s3); + auto d = mm->add_parameter("d", s4); + + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); // {2, 4} + auto external_relu = + add_pointwise(p, "main:pointwise1", {d}, single_pointwise("relu")); // {4, 3} + auto external_mul = + add_pointwise(p, "main:pointwise2", {external_relu, d}, single_pointwise("mul")); // {4, 3} + auto add = add_pointwise(p, "main:pointwise0", {dot1, c}, single_pointwise("add")); // {2, 4} + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, external_mul); // {2, 3} + auto transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), dot2); + mm->add_return({add, transpose}); + + // Perform localized sort between dot1 and dot2 + mm->localized_sort(dot1, dot2); + + // Verify the module is still topologically sorted overall + EXPECT(is_sorted(*mm)); + + // Verify external operations moved before the fusion chain + EXPECT(std::distance(mm->begin(), external_relu) < std::distance(mm->begin(), dot1)); + EXPECT(std::distance(mm->begin(), external_mul) < std::distance(mm->begin(), dot1)); + + // Verify the fusion chain ordering is preserved: dot1 < add < dot2 + EXPECT(std::distance(mm->begin(), dot1) < std::distance(mm->begin(), add)); + EXPECT(std::distance(mm->begin(), add) < std::distance(mm->begin(), dot2)); + + // Verify external_mul is before dot1 (since dot2 depends on external_mul) + EXPECT(std::distance(mm->begin(), external_mul) < std::distance(mm->begin(), dot1)); + + // Verify transpose remains after dot2 + EXPECT(std::distance(mm->begin(), dot2) < std::distance(mm->begin(), transpose)); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/verify/CMakeLists.txt b/test/verify/CMakeLists.txt index bd57abea883..56b0a2cf2ba 100644 --- a/test/verify/CMakeLists.txt +++ b/test/verify/CMakeLists.txt @@ -1,7 +1,7 @@ ##################################################################################### # The MIT License (MIT) # -# Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -42,3 +42,7 @@ foreach(SECTION general reduce rnn conv gemm) ) endif() endforeach() + +# TODO: remove this when MIGraphX disables attn-like offloading to rocMLIR +# f32 not supported on navi3x +set_tests_properties(test_verify_rnn PROPERTIES ENVIRONMENT "MIGRAPHX_ENABLE_MLIR_GEG_FUSION=0") diff --git a/test/verify/test_conv_add_dot.cpp b/test/verify/test_conv_add_dot.cpp new file mode 100644 index 00000000000..a4ba2018f7f --- /dev/null +++ b/test/verify/test_conv_add_dot.cpp @@ -0,0 +1,65 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include +#include + +template +struct test_conv_add_dot : verify_program> +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}}); + auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}}); + auto bias_literal = + migraphx::literal{migraphx::shape{DType, {4}}, {2.0f, 2.0f, 2.0f, 2.0f}}; + auto bias = mm->add_literal(bias_literal); + auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights); + auto bcast_bias = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", conv->get_shape().lens()}}), + bias); + auto bias_add = mm->add_instruction(migraphx::make_op("add"), conv, bcast_bias); + + // Create a literal for dot (matmul) with shape {4, 3, 3, 3} + std::vector bias_add_lens = bias_add->get_shape().lens(); + // The shape is {4, 3, 3, 3}, so we want a rhs shape {4, 3, 3, 3} + migraphx::shape dot_rhs_shape{DType, bias_add_lens}; + std::vector dot_rhs_data(dot_rhs_shape.elements(), 1.0f); + auto dot_rhs = mm->add_literal(migraphx::literal{dot_rhs_shape, dot_rhs_data}); + + // Matmul (dot) with same shape, so this is elementwise matmul + auto dot = mm->add_instruction(migraphx::make_op("dot"), bias_add, dot_rhs); + mm->add_return({dot}); + return p; + } + std::string section() const { return "conv"; } +}; + +template struct test_conv_add_dot; +template struct test_conv_add_dot; diff --git a/test/verify/test_dot_add_dot.cpp b/test/verify/test_dot_add_dot.cpp new file mode 100644 index 00000000000..a087e60e018 --- /dev/null +++ b/test/verify/test_dot_add_dot.cpp @@ -0,0 +1,51 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +template +struct test_dot_add_dot : verify_program> +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{DType, {256, 256}}; + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = mm->add_instruction(migraphx::make_op("add"), dot1, c); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, d); + mm->add_return({dot2}); + return p; + } +}; + +template struct test_dot_add_dot; +template struct test_dot_add_dot; From fd0ad324eb90bed5ede622e14e07db9b93aecf3c Mon Sep 17 00:00:00 2001 From: Chris Austen Date: Mon, 6 Oct 2025 09:29:12 -0400 Subject: [PATCH 10/13] Updated SD3 example for change in optimum-onnx[onnxruntime] (#4344) --- examples/diffusion/python_stable_diffusion_3/README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/diffusion/python_stable_diffusion_3/README.md b/examples/diffusion/python_stable_diffusion_3/README.md index 62219692bc8..1726c1d9829 100644 --- a/examples/diffusion/python_stable_diffusion_3/README.md +++ b/examples/diffusion/python_stable_diffusion_3/README.md @@ -17,6 +17,7 @@ python3 -m venv sd_venv Install dependencies ```bash +pip install --upgrade pip pip install -r torch_requirements.txt pip install -r requirements.txt ``` @@ -37,7 +38,7 @@ huggingface-cli login Export the models to onnx. Currently, optimum does not have the changes required in their latest release. Please install from their development branch instead. ```bash -python -m pip install optimum[onnxruntime]@git+https://github.com/huggingface/optimum.git +pip install "optimum-onnx[onnxruntime]"@git+https://github.com/huggingface/optimum-onnx.git ``` Once optimum is built, use the following command to export the models: From 94bea56d4b3d9aaabc832c3079e55bda9e7004fd Mon Sep 17 00:00:00 2001 From: Aarushi Jain <142941703+aarushjain29@users.noreply.github.com> Date: Mon, 6 Oct 2025 11:41:31 -0500 Subject: [PATCH 11/13] Lower lrn to pooling (#4294) This PR implements a lowering transformation that converts LRN operations into a series of pooling and arithmetic operations. --- docs/reference/MIGraphX-dev-env-vars.rst | 7 ++ src/include/migraphx/rewrite_pooling.hpp | 3 +- src/rewrite_pooling.cpp | 89 +++++++++++++++++++++++- src/targets/gpu/target.cpp | 3 +- test/rewrite_pooling_test.cpp | 74 ++++++++++++++++++++ test/verify/test_lrn.cpp | 50 +++++++++++++ test/verify/test_relu_lrn.cpp | 2 +- 7 files changed, 222 insertions(+), 6 deletions(-) create mode 100644 test/verify/test_lrn.cpp diff --git a/docs/reference/MIGraphX-dev-env-vars.rst b/docs/reference/MIGraphX-dev-env-vars.rst index f9c9f309ab3..da52ce656b8 100644 --- a/docs/reference/MIGraphX-dev-env-vars.rst +++ b/docs/reference/MIGraphX-dev-env-vars.rst @@ -221,6 +221,13 @@ Model performance tunable variables change the compilation behavior of a model. | Default: No tuning is done for composable kernels. + * - | ``MIGRAPHX_REWRITE_LRN`` + | Turns on LRN-to-pooling lowering in the rewrite_pooling pass. + + - | ``1``: Turns on LRN-to-pooling lowering. + | ``0``: Returns to default behavior. + + | Default: LRN-to-pooling lowering is turned off. Matching ********** diff --git a/src/include/migraphx/rewrite_pooling.hpp b/src/include/migraphx/rewrite_pooling.hpp index ebef9834786..dd69bb7ca93 100644 --- a/src/include/migraphx/rewrite_pooling.hpp +++ b/src/include/migraphx/rewrite_pooling.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -38,6 +38,7 @@ struct module; */ struct MIGRAPHX_EXPORT rewrite_pooling { + bool rewrite_lrn = false; std::string name() const { return "rewrite_pooling"; } void apply(module& m) const; }; diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index 6100043e802..6edf6689397 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -30,6 +30,8 @@ #include #include +#include +#include #include namespace migraphx { @@ -55,6 +57,81 @@ static void replace_with_reduce(module& m, instruction_ref ins) } } +static void lower_lrn_to_pooling(module& m, instruction_ref ins) +{ + auto v = ins->get_operator().to_value(); + + float alpha = v.at("alpha").to(); + float beta = v.at("beta").to(); + float k = v.at("bias").to(); + int size = v.at("size").to(); + + auto x = ins->inputs().at(0); + const auto& xshape = x->get_shape(); + const auto& lens = xshape.lens(); + + if(lens.size() != 4 or size <= 0 or size > lens[1]) + { + return; + } + + auto x2 = m.insert_instruction(ins, make_op("mul"), x, x); + + std::vector perm = {0, 2, 3, 1}; + auto transpose1 = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), x2); + + auto transposed_shape = transpose1->get_shape(); + const auto& transposed_lens = transposed_shape.lens(); + + int64_t channel_dim = lens[1]; + std::vector calculated_pads(2); + calculate_padding(0, calculated_pads, channel_dim, 1, 1, size, true); + + auto avg = m.insert_instruction( + ins, + make_op("pooling", + {{"mode", op::pooling_mode::average}, + {"lengths", std::vector{1, size}}, + {"stride", std::vector{1, 1}}, + {"padding", std::vector{0, calculated_pads[0], 0, calculated_pads[1]}}, + {"dilations", std::vector{1, 1}}, + {"count_include_pad", true}}), + transpose1); + + auto avg_shape = avg->get_shape(); + const auto& avg_lens = avg_shape.lens(); + + if(avg_lens.size() != 4 or avg_lens[3] != transposed_lens[3]) + { + return; + } + + std::vector inv_perm = {0, 3, 1, 2}; + auto transpose2 = + m.insert_instruction(ins, make_op("transpose", {{"permutation", inv_perm}}), avg); + + auto final_shape = transpose2->get_shape(); + const auto& final_lens = final_shape.lens(); + + if(final_lens != lens) + return; + + auto k_lit = m.add_literal(k); + auto a_lit = m.add_literal(alpha); + auto b_lit = m.add_literal(beta); + + auto k_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), k_lit); + auto a_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), a_lit); + auto b_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), b_lit); + + auto alpha_avg = m.insert_instruction(ins, make_op("mul"), a_mb, transpose2); + auto den = m.insert_instruction(ins, make_op("add"), k_mb, alpha_avg); + auto denpow = m.insert_instruction(ins, make_op("pow"), den, b_mb); + auto y = m.insert_instruction(ins, make_op("div"), x, denpow); + + m.replace_instruction(ins, y); +} + static void replace_dilations_with_gather_pooling(module& m, instruction_ref ins) { // TODO remove this when MIOpen supports dilated pooling @@ -143,10 +220,16 @@ void rewrite_pooling::apply(module& m) const { for(auto ins : iterator_for(m)) { - if(ins->name() != "pooling") - continue; if(ins->inputs().empty()) continue; + if(rewrite_lrn and ins->name() == "lrn") + { + lower_lrn_to_pooling(m, ins); + continue; + } + if(ins->name() != "pooling") + continue; + auto&& s = ins->inputs().front()->get_shape(); auto&& op = any_cast(ins->get_operator()); bool same_kernel_as_shape = std::equal( diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index 85b3698c8ab..5844c934259 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -84,6 +84,7 @@ namespace gpu { MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_REWRITE_DOT) +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_REWRITE_LRN) #ifndef _WIN32 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK) #endif @@ -203,7 +204,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti insert_pad{{"convolution"}}, dead_code_elimination{}, inline_module{}, - rewrite_pooling{}, + rewrite_pooling{.rewrite_lrn = enabled(MIGRAPHX_REWRITE_LRN{})}, dead_code_elimination{}, rewrite_gelu{options.fast_math}, optimize_module{}, diff --git a/test/rewrite_pooling_test.cpp b/test/rewrite_pooling_test.cpp index 1ccafed179e..5377dc99715 100644 --- a/test/rewrite_pooling_test.cpp +++ b/test/rewrite_pooling_test.cpp @@ -34,6 +34,9 @@ #include +#include + +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_REWRITE_LRN); static void opt_pooling(migraphx::module& m) { migraphx::rewrite_pooling rp; @@ -309,6 +312,77 @@ TEST_CASE(rewrite_pooling_dialtions_test5) test_rewrite(migraphx::op::pooling_mode::max); } +TEST_CASE(test_lower_lrn_to_pooling) +{ + migraphx::module m1; + { + migraphx::shape input_shape{migraphx::shape::float_type, {1, 64, 55, 55}}; + auto input1 = m1.add_parameter("x", input_shape); + auto lrn1 = m1.add_instruction( + migraphx::make_op("lrn", + {{"alpha", 0.0001f}, {"beta", 0.75f}, {"bias", 1.0f}, {"size", 4}}), + input1); + m1.add_return({lrn1}); + } + // Apply the pass directly when the flag enabled + migraphx::rewrite_pooling rp{.rewrite_lrn = true}; + migraphx::dead_code_elimination dce; + rp.apply(m1); + dce.apply(m1); + + migraphx::module m2; + { + migraphx::shape input_shape{migraphx::shape::float_type, {1, 64, 55, 55}}; + auto input2 = m2.add_parameter("x", input_shape); + + auto x2 = m2.add_instruction(migraphx::make_op("mul"), input2, input2); + + auto transpose1 = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", std::vector{0, 2, 3, 1}}}), + x2); + + int64_t lrn_size = 4; + int64_t pad_left = (lrn_size - 1) / 2; + int64_t pad_right = lrn_size - 1 - pad_left; + std::vector expected_pads = {pad_left, pad_right}; + + auto avg = m2.add_instruction( + migraphx::make_op( + "pooling", + {{"mode", migraphx::op::pooling_mode::average}, + {"lengths", std::vector{1, lrn_size}}, + {"stride", std::vector{1, 1}}, + {"padding", std::vector{0, expected_pads[0], 0, expected_pads[1]}}, + {"dilations", std::vector{1, 1}}, + {"count_include_pad", true}}), + transpose1); + + auto transpose2 = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", std::vector{0, 3, 1, 2}}}), + avg); + + auto k_lit = m2.add_literal(1.0f); + auto a_lit = m2.add_literal(0.0001f); + auto b_lit = m2.add_literal(0.75f); + + auto k_mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_shape.lens()}}), k_lit); + auto a_mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_shape.lens()}}), a_lit); + auto b_mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_shape.lens()}}), b_lit); + + auto alpha_avg = m2.add_instruction(migraphx::make_op("mul"), a_mb, transpose2); + auto den = m2.add_instruction(migraphx::make_op("add"), k_mb, alpha_avg); + auto denpow = m2.add_instruction(migraphx::make_op("pow"), den, b_mb); + auto y = m2.add_instruction(migraphx::make_op("div"), input2, denpow); + + m2.add_return({y}); + } + + EXPECT(m1 == m2); +} + TEST_CASE(rewrite_avgpool_rank3_dil_test) { // 1D case 1, input is 3D diff --git a/test/verify/test_lrn.cpp b/test/verify/test_lrn.cpp new file mode 100644 index 00000000000..2eae1a202e5 --- /dev/null +++ b/test/verify/test_lrn.cpp @@ -0,0 +1,50 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +template +struct test_lrn : verify_program> +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter( + "x", migraphx::shape{migraphx::shape::float_type, {1, ChannelSize, 28, 28}}); + mm->add_instruction( + migraphx::make_op( + "lrn", {{"alpha", 0.0001}, {"beta", 0.75}, {"bias", 1.0}, {"size", LrnSize}}), + x); + return p; + } +}; + +template struct test_lrn<32, 6>; +template struct test_lrn<32, 5>; +template struct test_lrn<31, 8>; +template struct test_lrn<31, 5>; diff --git a/test/verify/test_relu_lrn.cpp b/test/verify/test_relu_lrn.cpp index feec2ab21c6..d201bd77317 100644 --- a/test/verify/test_relu_lrn.cpp +++ b/test/verify/test_relu_lrn.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal From 894f160227eb45bc2d21fae76a5e6ddc6ba17aa5 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Mon, 6 Oct 2025 15:04:58 -0400 Subject: [PATCH 12/13] Update onnxruntime main be655f69f6c8623eebcac094626d0c8545951e6d (#4348) --- test/onnx/.onnxrt-commit | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/onnx/.onnxrt-commit b/test/onnx/.onnxrt-commit index bdadff6093f..28f83240a4e 100644 --- a/test/onnx/.onnxrt-commit +++ b/test/onnx/.onnxrt-commit @@ -1 +1 @@ -cced33b2f5e338a4ee43cdc8695dba70d706e993 +be655f69f6c8623eebcac094626d0c8545951e6d From af12163821ef82f3d04e1d1b5b7b9a9acfd3d082 Mon Sep 17 00:00:00 2001 From: Chris Austen Date: Wed, 8 Oct 2025 09:31:28 -0400 Subject: [PATCH 13/13] Update docs/reference/MIGraphX-dev-env-vars.rst --- docs/reference/MIGraphX-dev-env-vars.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/reference/MIGraphX-dev-env-vars.rst b/docs/reference/MIGraphX-dev-env-vars.rst index da52ce656b8..d18b1b98d83 100644 --- a/docs/reference/MIGraphX-dev-env-vars.rst +++ b/docs/reference/MIGraphX-dev-env-vars.rst @@ -222,7 +222,8 @@ Model performance tunable variables change the compilation behavior of a model. | Default: No tuning is done for composable kernels. * - | ``MIGRAPHX_REWRITE_LRN`` - | Turns on LRN-to-pooling lowering in the rewrite_pooling pass. + | Turns on LRN-to-pooling lowering in the ``rewrite_pooling`` pass. + - | ``1``: Turns on LRN-to-pooling lowering. | ``0``: Returns to default behavior.