diff --git a/Jenkinsfile b/Jenkinsfile index ef9c2a49d5d..6dd7994ed89 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -39,6 +39,7 @@ def rocmtestnode(Map conf) { export MIGRAPHX_GPU_DEBUG=${gpu_debug} export CXX=${compiler} export CXXFLAGS='-Werror' + rocminfo env rm -rf build mkdir build @@ -66,12 +67,18 @@ def rocmtestnode(Map conf) { checkout scm } + def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3').trim() + def render_id = sh(returnStdout: true, script: 'getent group render | cut -d: -f3').trim() + def docker_opts = "--device=/dev/kfd --device=/dev/dri --cap-add SYS_PTRACE -v=${env.WORKSPACE}/../:/workspaces:rw,z" + docker_opts = docker_opts + " --group-add=${video_id} --group-add=${render_id} " + echo "Docker flags: ${docker_opts}" + gitStatusWrapper(credentialsId: "${env.migraphx_ci_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'AMDMIGraphX') { withCredentials([usernamePassword(credentialsId: 'docker_test_cred', passwordVariable: 'DOCKERHUB_PASS', usernameVariable: 'DOCKERHUB_USER')]) { sh "echo $DOCKERHUB_PASS | docker login --username $DOCKERHUB_USER --password-stdin" pre() sh "docker pull ${DOCKER_IMAGE}:${env.IMAGE_TAG}" - withDockerContainer(image: "${DOCKER_IMAGE}:${env.IMAGE_TAG}", args: "--device=/dev/kfd --device=/dev/dri --group-add video --cap-add SYS_PTRACE -v=${env.WORKSPACE}/../:/workspaces:rw,z ${docker_args}") { + withDockerContainer(image: "${DOCKER_IMAGE}:${env.IMAGE_TAG}", args: docker_opts + docker_args) { timeout(time: 4, unit: 'HOURS') { body(cmake_build) } @@ -192,7 +199,7 @@ rocmtest clang_debug: rocmnode('mi200+') { cmake_build -> } }, mlir_debug: rocmnode('mi100+') { cmake_build -> stage('MLIR Debug') { - withEnv(['MIGRAPHX_ENABLE_EXTRA_MLIR=1', 'MIGRAPHX_MLIR_USE_SPECIFIC_OPS=fused,attention,convolution,dot,convolution_backwards', 'MIGRAPHX_ENABLE_MLIR_INPUT_FUSION=1', 'MIGRAPHX_MLIR_ENABLE_SPLITK=1', 'MIGRAPHX_ENABLE_MLIR_REDUCE_FUSION=1', 'MIGRAPHX_ENABLE_SPLIT_REDUCE=1','MIGRAPHX_DISABLE_LAYERNORM_FUSION=1']) { + withEnv(['MIGRAPHX_ENABLE_EXTRA_MLIR=1', 'MIGRAPHX_MLIR_USE_SPECIFIC_OPS=fused,attention,convolution,dot,convolution_backwards', 'MIGRAPHX_ENABLE_MLIR_INPUT_FUSION=1', 'MIGRAPHX_MLIR_ENABLE_SPLITK=1', 'MIGRAPHX_ENABLE_MLIR_REDUCE_FUSION=1', 'MIGRAPHX_ENABLE_MLIR_GEG_FUSION=1', 'MIGRAPHX_ENABLE_SPLIT_REDUCE=1','MIGRAPHX_DISABLE_LAYERNORM_FUSION=1']) { def sanitizers = "undefined" // Note: the -fno-sanitize= is copied from upstream LLVM_UBSAN_FLAGS. def debug_flags = "-g -O2 -fsanitize=${sanitizers} -fno-sanitize=vptr,function -fno-sanitize-recover=${sanitizers}" diff --git a/dev-requirements.txt b/dev-requirements.txt index f4f5fd75f31..aac3941855a 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -21,7 +21,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. ##################################################################################### -ROCmSoftwarePlatform/rocm-recipes +ROCm/rocm-recipes facebook/zstd@v1.5.7 -X subdir -DCMAKE_DIR=build/cmake ccache@v4.1 -DENABLE_TESTING=OFF pcre,pfultz2/pcre@8.45 -H sha256:d6f7182602a775a7d500a0cedca6449af0400c6493951513046d17615ed0bf11 diff --git a/docs/reference/MIGraphX-dev-env-vars.rst b/docs/reference/MIGraphX-dev-env-vars.rst index 266ec1fbe92..d18b1b98d83 100644 --- a/docs/reference/MIGraphX-dev-env-vars.rst +++ b/docs/reference/MIGraphX-dev-env-vars.rst @@ -144,6 +144,14 @@ Model performance tunable variables change the compilation behavior of a model. 
| Default: Reduction fusions are turned off. + * - | ``MIGRAPHX_ENABLE_MLIR_GEG_FUSION`` + | Turns on GEMM+GEMM fusions in MLIR. + + - | ``1``: Turns on GEMM+GEMM fusions. + | ``0``: Returns to default behavior. + + | Default: GEMM+GEMM fusions are turned off. + * - | ``MIGRAPHX_MLIR_ENABLE_SPLITK`` | Turns on Split-k performance configurations during MLIR tuning. @@ -213,6 +221,14 @@ Model performance tunable variables change the compilation behavior of a model. | Default: No tuning is done for composable kernels. + * - | ``MIGRAPHX_REWRITE_LRN`` + | Turns on LRN-to-pooling lowering in the ``rewrite_pooling`` pass. + + - | ``1``: Turns on LRN-to-pooling lowering. + | ``0``: Returns to default behavior. + + | Default: LRN-to-pooling lowering is turned off. Matching ********** diff --git a/examples/diffusion/python_stable_diffusion_3/README.md b/examples/diffusion/python_stable_diffusion_3/README.md index 62219692bc8..1726c1d9829 100644 --- a/examples/diffusion/python_stable_diffusion_3/README.md +++ b/examples/diffusion/python_stable_diffusion_3/README.md @@ -17,6 +17,7 @@ python3 -m venv sd_venv Install dependencies ```bash +pip install --upgrade pip pip install -r torch_requirements.txt pip install -r requirements.txt ``` @@ -37,7 +38,7 @@ huggingface-cli login Export the models to onnx. Currently, optimum does not have the changes required in their latest release. Please install from their development branch instead. ```bash -python -m pip install optimum[onnxruntime]@git+https://github.com/huggingface/optimum.git +pip install "optimum-onnx[onnxruntime]"@git+https://github.com/huggingface/optimum-onnx.git ``` Once optimum is built, use the following command to export the models: diff --git a/rbuild.ini b/rbuild.ini index 3eb2fef6247..ef5eb2c5a8e 100644 --- a/rbuild.ini +++ b/rbuild.ini @@ -2,13 +2,13 @@ cxx = ${rocm_path}/llvm/bin/clang++ cc = ${rocm_path}/llvm/bin/clang deps = - ROCmSoftwarePlatform/rocm-recipes + ROCm/rocm-recipes -f requirements.txt [gh] ignore = danmar/cppcheck - ROCmSoftwarePlatform/rocMLIR + ROCm/rocMLIR deps = -f dev-requirements.txt oneapi-src/oneDNN@v1.7 diff --git a/requirements.txt b/requirements.txt index 95fe65bb50c..365524c3724 100644 --- a/requirements.txt +++ b/requirements.txt @@ -27,6 +27,6 @@ nlohmann/json@v3.8.0 -DCMAKE_POLICY_VERSION_MINIMUM=3.5 ROCm/half@rocm-5.6.0 pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off -DMSGPACK_BUILD_EXAMPLES=Off -DCMAKE_POLICY_VERSION_MINIMUM=3.5 -sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On +sqlite3@3.50.4 -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCm/composable_kernel@b7775add2d28251674d81e220cd4a857b90b997a -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCm/rocMLIR@33b0fc534532f8e8cb7bec2b5f7d20a69be2def5 -DBUILD_FAT_LIBROCKCOMPILER=On -DLLVM_INCLUDE_TESTS=Off diff --git a/src/driver/main.cpp b/src/driver/main.cpp index 31cb7472c27..4f402e06b5d 100644 --- a/src/driver/main.cpp +++ b/src/driver/main.cpp @@ -64,6 +64,9 @@ #include namespace { + +using dims_map = std::unordered_map<std::string, std::vector<std::size_t>>; + std::vector<std::string> get_unrecognized_migraphx_envs(const char* envp[], const std::map<std::string, std::string>& used_env) @@ -213,7 +216,7 @@ struct loader static auto parse_param_dims(const std::vector<std::string>& param_dims_info) { - std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims; + dims_map map_input_dims; std::string name = ""; for(auto&& x : param_dims_info) { @@ -502,16 +505,24 @@ struct program_params return map_load_args; } - auto generate(const program& p, const target& t, bool offload, unsigned batch)
+ auto generate(const program& p, + const target& t, + bool offload, + unsigned batch, + dims_map map_input_dims = {}) { parameter_map m; auto param_shapes = p.get_parameter_shapes(); std::unordered_map<std::string, shape> static_param_shapes; - std::transform( - param_shapes.cbegin(), - param_shapes.cend(), - std::inserter(static_param_shapes, static_param_shapes.end()), - [&](const auto& x) { return std::make_pair(x.first, x.second.to_static(batch)); }); + for(auto&& param : param_shapes) + { + if(contains(map_input_dims, param.first)) + static_param_shapes[param.first] = {param.second.type(), + map_input_dims[param.first]}; + else + static_param_shapes[param.first] = param.second.to_static(batch); + } + for(auto&& s : fill0) m[s] = fill_argument(static_param_shapes.at(s), 0); for(auto&& s : fill1) @@ -591,7 +602,8 @@ struct compiler auto params(const program& p) { - return parameters.generate(p, ct.get_target(), co.offload_copy, l.batch); + return parameters.generate( + p, ct.get_target(), co.offload_copy, l.batch, loader::parse_param_dims(l.param_dims)); } auto host_params(const program& p) @@ -730,7 +742,8 @@ struct verify : command std::cout << p << std::endl; auto t = c.ct.get_target(); - auto m = c.parameters.generate(p, t, true, c.l.batch); + auto m = + c.parameters.generate(p, t, true, c.l.batch, loader::parse_param_dims(c.l.param_dims)); if(c.to_fp16) { diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index ccf67de0d85..ec109d98b6d 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -339,6 +339,7 @@ struct MIGRAPHX_EXPORT module ins_dep_map calc_implicit_deps() const; void repeat_while_changes(std::size_t n, const std::function& f); + void localized_sort(instruction_ref start_ins, instruction_ref end_ins); MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const module& m); MIGRAPHX_EXPORT friend bool operator==(const module& x, const module& y); diff --git a/src/include/migraphx/raw_data.hpp b/src/include/migraphx/raw_data.hpp index 63f512a4948..ebd72acaab6 100644 --- a/src/include/migraphx/raw_data.hpp +++ b/src/include/migraphx/raw_data.hpp @@ -53,15 +53,15 @@ struct raw_data : raw_data_base friend Stream& operator<<(Stream& os, const Derived& d) { if(not d.empty()) - d.visit([&](auto x) { os << x; }, - [&](auto&& xs) { - for(auto&& x : xs) - { - os << "{ "; - os << x; - os << " }, "; - } - }); + d.fallback_visit([&](auto x) { os << x; }, - [&](auto&& xs) { + [&](auto&& xs) { + for(auto&& x : xs) + { + os << "{ "; + os << x; + os << " }, "; + } + }); return os; } @@ -123,9 +123,13 @@ struct raw_data : raw_data_base } else { - auto&& buffer = static_cast(*this).data(); + auto* buffer = static_cast(*this).data(); shape view_shape = {shape::uint8_type, {s.bytes()}}; - v(make_view(view_shape, reinterpret_cast(buffer))); + using byte_type = + std::conditional_t<std::is_const<std::remove_pointer_t<decltype(buffer)>>{}, + const byte*, + byte*>; + v(make_view(view_shape, reinterpret_cast<byte_type>(buffer))); } } diff --git a/src/include/migraphx/rewrite_pooling.hpp b/src/include/migraphx/rewrite_pooling.hpp index ebef9834786..dd69bb7ca93 100644 --- a/src/include/migraphx/rewrite_pooling.hpp +++ b/src/include/migraphx/rewrite_pooling.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -38,6 +38,7 @@ struct module; */ struct MIGRAPHX_EXPORT rewrite_pooling { + bool rewrite_lrn = false; std::string name() const { return "rewrite_pooling"; } void apply(module& m) const; }; diff --git a/src/module.cpp b/src/module.cpp index 50479933940..0b94a36c5a0 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -1600,6 +1600,31 @@ void module::repeat_while_changes(std::size_t n, const std::function& f) } } +// Topologically sorts a region of a module, canonically, so that the +// dependency chain between the two input instructions comes last +void module::localized_sort(instruction_ref start_ins, instruction_ref end_ins) +{ + // get the chain of instructions between start_ins and end_ins, inclusive + auto fusion_ins = find_instructions_between(start_ins, end_ins, this); + + // move every instruction between start_ins and end_ins that is not in the fusion chain + // to just before start_ins. Since they all move, in order, to the same destination, this + // naturally preserves the preexisting topological order of the module + for(auto it = std::next(start_ins); it != end_ins;) + { + if(fusion_ins.count(it) == 0) + { + auto next = std::next(it); // save next first, since move_instruction relocates it + this->move_instruction(it, start_ins); + it = next; + } + else + { + ++it; + } + } +} + bool operator==(const module& x, const module& y) { return to_string(x) == to_string(y); } std::ostream& operator<<(std::ostream& os, const module& m) diff --git a/src/onnx/onnx_parser.cpp b/src/onnx/onnx_parser.cpp index b81abe4839c..11c5b0bb307 100644 --- a/src/onnx/onnx_parser.cpp +++ b/src/onnx/onnx_parser.cpp @@ -831,6 +831,7 @@ shape::type_t get_type(int dtype) case 18: return shape::fp8e4m3fnuz_type; case 21: return shape::uint8_type; case 22: return shape::int8_type; + case 23: return shape::fp4x2_type; case 14: case 15: case 16: return shape::bf16_type; diff --git a/src/onnx/parse_dynamicscale.cpp b/src/onnx/parse_dynamicscale.cpp new file mode 100644 index 00000000000..a68b27b173f --- /dev/null +++ b/src/onnx/parse_dynamicscale.cpp @@ -0,0 +1,139 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE.
*/ + +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +/** + * Operator from Brevitas to calculate dynamic quantization scales. + */ +struct parse_dynamicscale : op_parser<parse_dynamicscale> +{ + + std::vector<op_desc> operators() const { return {{"DynamicScale"}}; }; + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& /*parser*/, + onnx_parser::node_info info, + const std::vector<instruction_ref>& args) const + { + const instruction_ref input = args.front(); + instruction_ref tmp_in = input; + const auto input_lens = input->get_shape().lens(); + if(args.size() != 1) + { + MIGRAPHX_THROW("DynamicScale: must have only 1 input"); + } + int block_axis = info.attributes.at("group_dim").i(); + block_axis = tune_axis(input->get_shape().ndim(), block_axis, "DynamicScale"); + int block_size = info.attributes.at("group_size").i(); + if(block_size != 32) + { + MIGRAPHX_THROW("DynamicScale: only group_size of 32 is supported"); + } + migraphx::shape::type_t output_type = get_type(info.attributes.at("output_dtype").i()); + + // TODO expand this to handle other MX types + if(output_type != migraphx::shape::fp4x2_type) + { + MIGRAPHX_THROW("DynamicScale: only supports MXFP4 type"); + } + + std::string scale_selection_method = info.attributes.at("scale_selection_method").s(); + if(scale_selection_method != "floor") + { + MIGRAPHX_THROW("DynamicScale: only supports floor scale selection"); + } + + std::string zero_point_selection_method = "None"; + if(contains(info.attributes, "zero_point_selection_method")) + zero_point_selection_method = info.attributes.at("zero_point_selection_method").s(); + + if(zero_point_selection_method != "None") + { + MIGRAPHX_THROW("DynamicScale: zero_point not supported"); + } + + // make reduction axes for calculating block scales + // tmp_lens != input_lens if runt block is padded + auto tmp_lens = input_lens; + auto block_dim = tmp_lens.at(block_axis); + std::size_t block_padding = + std::ceil(double(block_dim) / double(block_size)) * block_size - block_dim; + // handle runt block by padding + if(block_padding != 0) + { + std::vector<int64_t> pads_vec(2 * tmp_lens.size(), 0); + pads_vec.at(block_axis + tmp_lens.size()) = block_padding; + tmp_in = info.add_instruction(make_op("pad", {{"pads", pads_vec}}), tmp_in); + tmp_lens = tmp_in->get_shape().lens(); + } + // reshape block dimension to {num_blocks, block_size} + std::size_t num_blocks = tmp_lens.at(block_axis) / std::size_t(block_size); + std::vector<std::size_t> reduct_dims = tmp_lens; + reduct_dims.at(block_axis) = block_size; + reduct_dims.insert(reduct_dims.begin() + block_axis, num_blocks); + instruction_ref reshape_ins = + info.add_instruction(make_op("reshape", {{"dims", reduct_dims}}), tmp_in); + + // dynamic quantization for MX types: + // V_k = fp32 vector input of block size k + // B_k = pow(2, floor(log2(reduce_max(abs(V_k))))) # largest power of 2 less than V + // X_k = block scale k = B_k / (largest power of 2 in fp4e2m1) = B_k / 4 + // e.g. reduce_max(abs(V_k)) = 6.0 -> B_k = 2^floor(log2(6.0)) = 4 -> X_k = 4 / 4 = 1 + auto abs_ins = info.add_instruction(make_op("abs"), reshape_ins); + auto reduce_max_ins = + info.add_instruction(make_op("reduce_max", {{"axes", {block_axis + 1}}}), abs_ins); + auto log2_ins = info.add_instruction(make_op("log2"), reduce_max_ins); + auto floor_ins = info.add_instruction(make_op("floor"), log2_ins); + auto lit_2_ins = info.add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {2.f}}); + auto broadcast_lit_2_ins = info.add_instruction( + make_op("multibroadcast", {{"out_lens",
reduce_max_ins->get_shape().lens()}}), + lit_2_ins); + auto pow_ins = info.add_instruction(make_op("pow"), broadcast_lit_2_ins, floor_ins); + auto lit_4_ins = info.add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {4.f}}); + auto broadcast_lit_4_ins = info.add_instruction( + make_op("multibroadcast", {{"out_lens", reduce_max_ins->get_shape().lens()}}), + lit_4_ins); + auto block_scales_ins = info.add_instruction(make_op("div"), pow_ins, broadcast_lit_4_ins); + + // squeeze reduction axis for use in block quantized quantizelinear + block_scales_ins = info.add_instruction(make_op("squeeze", {{"axes", {block_axis + 1}}}), + block_scales_ins); + + return block_scales_ins; + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/onnx/parse_quantizelinear.cpp b/src/onnx/parse_quantizelinear.cpp index 92a773ae63d..03a6395308f 100644 --- a/src/onnx/parse_quantizelinear.cpp +++ b/src/onnx/parse_quantizelinear.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -91,6 +91,46 @@ struct parse_quantizelinear : op_parser args = transform_quantize_dequantize_linear_inputs( info, opd.onnx_name, block_size, axis, args); + if(output_type == migraphx::shape::fp4x2_type) + { + // Parsing in pack_fp4 and unpack_fp4 for the FP4 case + auto q_ins = info.add_instruction( + make_op("quantizelinear", {{"out_type", migraphx::shape::float_type}}), args); + + // packing axis set to fastest dimension + auto quantized_shape = q_ins->get_shape(); + const auto& qs_strides = quantized_shape.strides(); + if(qs_strides.empty()) + { + MIGRAPHX_THROW("QuantizeLinear: MX type quantized_shape has no strides"); + } + int fast_axis = + std::min_element(qs_strides.cbegin(), qs_strides.cend()) - qs_strides.cbegin(); + bool odd_fast_axis = (quantized_shape.lens().at(fast_axis) % 2 == 1); + if(odd_fast_axis) + { + // pad fastest dimension by 1 if it is odd + std::vector padding(2 * quantized_shape.ndim(), 0); + padding.at(fast_axis * 2 + 1) = 1; + q_ins = info.add_instruction(make_op("pad", {{"pads", padding}}), q_ins); + } + auto pack_ins = info.add_instruction(make_op("pack_fp4", {{"axis", fast_axis}}), + q_ins); // output is fp4x2_type + auto unpack_ins = info.add_instruction(make_op("unpack_fp4", {{"axis", fast_axis}}), + pack_ins); // output is fp8e4m3fn_type + if(odd_fast_axis) + { + // slice off padded values + unpack_ins = info.add_instruction( + make_op("slice", + {{"axes", {fast_axis}}, + {"starts", {0}}, + {"ends", {quantized_shape.lens().at(fast_axis)}}}), + unpack_ins); + } + return unpack_ins; + } + if(parser.opset_version < 19) { auto common_type = common_shape({args[0]->get_shape(), args[1]->get_shape()}).type(); diff --git a/src/propagate_constant.cpp b/src/propagate_constant.cpp index 50e065520a9..8b44064fa2f 100644 --- a/src/propagate_constant.cpp +++ b/src/propagate_constant.cpp @@ -40,7 +40,7 @@ static bool skip_propagate(instruction_ref ins) { if(contains({"contiguous", "dequantizelinear", "reshape"}, ins->name())) return skip_propagate(ins->inputs().front()); - if(ins->name() == "unpack_int4") + if(contains({"unpack_int4", "unpack_fp4"}, ins->name())) return true; auto&& s = ins->get_shape(); if(s.broadcasted() and 
s.element_space() < s.elements()) diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index 6100043e802..6edf6689397 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -30,6 +30,8 @@ #include #include +#include +#include #include namespace migraphx { @@ -55,6 +57,81 @@ static void replace_with_reduce(module& m, instruction_ref ins) } } +static void lower_lrn_to_pooling(module& m, instruction_ref ins) +{ + auto v = ins->get_operator().to_value(); + + float alpha = v.at("alpha").to<float>(); + float beta = v.at("beta").to<float>(); + float k = v.at("bias").to<float>(); + int size = v.at("size").to<int>(); + + auto x = ins->inputs().at(0); + const auto& xshape = x->get_shape(); + const auto& lens = xshape.lens(); + + if(lens.size() != 4 or size <= 0 or size > lens[1]) + { + return; + } + + auto x2 = m.insert_instruction(ins, make_op("mul"), x, x); + + std::vector<int64_t> perm = {0, 2, 3, 1}; + auto transpose1 = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), x2); + + auto transposed_shape = transpose1->get_shape(); + const auto& transposed_lens = transposed_shape.lens(); + + int64_t channel_dim = lens[1]; + std::vector<int64_t> calculated_pads(2); + calculate_padding(0, calculated_pads, channel_dim, 1, 1, size, true); + + auto avg = m.insert_instruction( + ins, + make_op("pooling", + {{"mode", op::pooling_mode::average}, + {"lengths", std::vector<int64_t>{1, size}}, + {"stride", std::vector<int64_t>{1, 1}}, + {"padding", std::vector<int64_t>{0, calculated_pads[0], 0, calculated_pads[1]}}, + {"dilations", std::vector<int64_t>{1, 1}}, + {"count_include_pad", true}}), + transpose1); + + auto avg_shape = avg->get_shape(); + const auto& avg_lens = avg_shape.lens(); + + if(avg_lens.size() != 4 or avg_lens[3] != transposed_lens[3]) + { + return; + } + + std::vector<int64_t> inv_perm = {0, 3, 1, 2}; + auto transpose2 = + m.insert_instruction(ins, make_op("transpose", {{"permutation", inv_perm}}), avg); + + auto final_shape = transpose2->get_shape(); + const auto& final_lens = final_shape.lens(); + + if(final_lens != lens) + return; + + auto k_lit = m.add_literal(k); + auto a_lit = m.add_literal(alpha); + auto b_lit = m.add_literal(beta); + + auto k_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), k_lit); + auto a_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), a_lit); + auto b_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), b_lit); + + auto alpha_avg = m.insert_instruction(ins, make_op("mul"), a_mb, transpose2); + auto den = m.insert_instruction(ins, make_op("add"), k_mb, alpha_avg); + auto denpow = m.insert_instruction(ins, make_op("pow"), den, b_mb); + auto y = m.insert_instruction(ins, make_op("div"), x, denpow); + + m.replace_instruction(ins, y); +} + static void replace_dilations_with_gather_pooling(module& m, instruction_ref ins) { // TODO remove this when MIOpen supports dilated pooling @@ -143,10 +220,16 @@ void rewrite_pooling::apply(module& m) const { for(auto ins : iterator_for(m)) { - if(ins->name() != "pooling") - continue; if(ins->inputs().empty()) continue; + if(rewrite_lrn and ins->name() == "lrn") + { + lower_lrn_to_pooling(m, ins); + continue; + } + if(ins->name() !=
"pooling") + continue; + auto&& s = ins->inputs().front()->get_shape(); auto&& op = any_cast<op::pooling>(ins->get_operator()); bool same_kernel_as_shape = std::equal( diff --git a/src/simplify_qdq.cpp b/src/simplify_qdq.cpp index f4569fe7a36..4169964d9d0 100644 --- a/src/simplify_qdq.cpp +++ b/src/simplify_qdq.cpp @@ -49,32 +49,36 @@ std::unordered_set<std::string> get_quantizable_op_names() return s; } -// Helper function to insert quantized versions of any broadcasts and transpose ops that -// occur between dequantizelinear and the quantized op -auto propagate_quantized_ins(module& m, - const instruction_ref dqins, - const instruction_ref qop_arg, - bool is_fp16_model = false) +std::vector<instruction_ref> get_between_ins(const instruction_ref dqins, + const instruction_ref qop_arg) { auto prev_ins = qop_arg; std::vector<instruction_ref> ins_between; - // matcher skips continguous, multi/broadcasts and transposes, collect all those - // instructions while(prev_ins != dqins) { ins_between.push_back(prev_ins); prev_ins = prev_ins->inputs().front(); } - auto qinp = dqins->inputs().front(); + return ins_between; +} + +// Helper function to insert quantized versions of any broadcasts and transpose ops that +// occur between dequantizelinear and the quantized op +auto propagate_quantized_ins(module& m, + const instruction_ref dqins, + instruction_ref input_ins, + std::vector<instruction_ref> ins_between, + bool is_fp16_model = false) +{ for(auto ins : reverse_iterator_for(ins_between)) { if((*ins)->name() == "convert" and is_fp16_model) { continue; } - qinp = m.insert_instruction(dqins, (*ins)->get_operator(), {qinp}); + input_ins = m.insert_instruction(dqins, (*ins)->get_operator(), {input_ins}); } - return qinp; + return input_ins; } struct match_find_quantizable_ops @@ -140,8 +144,13 @@ struct match_find_quantizable_ops assert(dq1->get_shape().type() == migraphx::shape::float_type); is_fp16_model = true; } - qop_args.at(0) = propagate_quantized_ins(m, dq1, qop_args[0], is_fp16_model); - qop_args.at(1) = propagate_quantized_ins(m, dq2, qop_args[1], is_fp16_model); + + auto qop_between_arg0 = get_between_ins(dq1, qop_args[0]); + auto qop_between_arg1 = get_between_ins(dq2, qop_args[1]); + qop_args.at(0) = + propagate_quantized_ins(m, dq1, dq1->inputs().front(), qop_between_arg0, is_fp16_model); + qop_args.at(1) = + propagate_quantized_ins(m, dq2, dq2->inputs().front(), qop_between_arg1, is_fp16_model); auto arg1_lens = qop_args[0]->get_shape().lens(); auto arg2_lens = qop_args[1]->get_shape().lens(); @@ -280,15 +289,13 @@ } }; -// Note: scales are not constant b/c of dynamic quantization. // Checks for block quantized scales by checking scales are not scalar or 1D.
-inline auto dynamic_block_dq(const std::string& scale) +inline auto block_dq(const std::string& scale) { // clang-format off return match::name("dequantizelinear")( match::nargs(2), match::arg(1)(match::skip_broadcasts(match::none_of( - match::is_constant(), match::scalar_shape, match::ndim(1) ).bind(scale)))); @@ -305,8 +312,8 @@ struct match_find_mx_quantizable_ops { auto matcher() const { - auto dq1 = match::arg(0)(skip_post_dq_ops(dynamic_block_dq("scale1").bind("dq1"))); - auto dq2 = match::arg(1)(skip_post_dq_ops(dynamic_block_dq("scale2").bind("dq2"))); + auto dq1 = match::arg(0)(skip_post_dq_ops(block_dq("scale1").bind("dq1"))); + auto dq2 = match::arg(1)(skip_post_dq_ops(block_dq("scale2").bind("dq2"))); return match::name("dot")(dq1, dq2); } @@ -328,10 +335,16 @@ struct match_find_mx_quantizable_ops assert(dq1->get_shape().type() == migraphx::shape::float_type); is_fp16_model = true; } - qop_args.at(0) = propagate_quantized_ins(m, dq1, qop_args[0], is_fp16_model); - qop_args.at(1) = propagate_quantized_ins(m, dq2, qop_args[1], is_fp16_model); - qop_args.push_back(scale1); - qop_args.push_back(scale2); + auto qop_between_arg0 = get_between_ins(dq1, qop_args[0]); + qop_args.at(0) = + propagate_quantized_ins(m, dq1, dq1->inputs().front(), qop_between_arg0, is_fp16_model); + auto qop_between_arg1 = get_between_ins(dq2, qop_args[1]); + qop_args.at(1) = + propagate_quantized_ins(m, dq2, dq2->inputs().front(), qop_between_arg1, is_fp16_model); + qop_args.push_back( + propagate_quantized_ins(m, dq1, scale1, qop_between_arg0, is_fp16_model)); + qop_args.push_back( + propagate_quantized_ins(m, dq2, scale2, qop_between_arg1, is_fp16_model)); if(qop->name() == "convolution") { diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index b036fb4cb0e..83d30c50bba 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -48,6 +48,7 @@ namespace gpu { MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_EXTRA_MLIR); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_REDUCE_FUSION); +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_GEG_FUSION); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MLIR); /** * @brief Declares a new MIGraphX environment variable which forces to generate @@ -779,6 +780,152 @@ struct find_mlir_fused_ops } }; +/** + * Fuses rocMLIR conv/dot -> pointwise -> dot chain + * into a mlir_op with submodule. 
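+ * For example, dot(add(dot(a, b), x), y) (the shape exercised by the dot_add_dot test + * added in this PR) is collapsed into a single mlir_op whose submodule computes both + * gemms and the pointwise add.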
+ */ +struct find_mlir_fused_geg_ops +{ + mlir_mode conv_mode = mlir_mode::none; + mlir_mode dot_mode = mlir_mode::none; + + /* + * Matches: + * mlir_dot_or_conv -> + * pointwise -> + * dot + */ + auto matcher() const + { + auto first_dot_or_conv = match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode)) + .bind("first_gemm_based_op"); + auto elemwise = + mlir_pointwise()(match::any_of[match::inputs()](first_dot_or_conv)).bind("elemwise"); + return is_mlir_dot(dot_mode)(match::any_of[match::inputs()](elemwise)) + .bind("second_gemm_op"); + } + + void apply(module_pass_manager& mpm, const match::matcher_result& r) const + { + auto second_gemm_ins = r.result; + auto elemwise_ins = r.instructions["elemwise"]; + auto first_gemm_ins = r.instructions["first_gemm_based_op"]; + + auto* elemwise_module = elemwise_ins->module_inputs().front(); + auto elemwise_inputs = elemwise_ins->inputs(); + + // bail if any input to elemwise other than first_gemm itself depends on first_gemm + if(std::any_of(elemwise_inputs.begin(), elemwise_inputs.end(), [&](const auto& i) { + return i != first_gemm_ins and reaches(first_gemm_ins, i); + })) + return; + + // bail if any input to second_gemm other than elemwise itself depends on elemwise + auto second_gemm_inputs = second_gemm_ins->inputs(); + if(std::any_of(second_gemm_inputs.begin(), second_gemm_inputs.end(), [&](const auto& i) { + return i != elemwise_ins and reaches(elemwise_ins, i); + })) + return; + + std::unordered_map<instruction_ref, instruction_ref> map_ins; + module_ref mm = + mpm.create_module("mlir_" + elemwise_ins->module_inputs().front()->name() + "_geg"); + mm->set_bypass(); + fuse_input_ops(mm, first_gemm_ins->inputs(), &map_ins); + + // need to track multi-user scenarios for both intermediates + bool first_gemm_has_multi_outs = first_gemm_ins->outputs().size() > 1; + bool elemwise_has_multi_outs = elemwise_ins->outputs().size() > 1; + + // add the first gemm to the module + std::vector<instruction_ref> first_gemm_mapped_inputs; + first_gemm_mapped_inputs.reserve(first_gemm_ins->inputs().size()); + std::transform(first_gemm_ins->inputs().begin(), + first_gemm_ins->inputs().end(), + std::back_inserter(first_gemm_mapped_inputs), + [&](auto input) { return map_ins.at(input); }); + auto first_gemm_in_module = + mm->add_instruction(first_gemm_ins->get_operator(), first_gemm_mapped_inputs); + map_ins[first_gemm_ins] = first_gemm_in_module; + + // fuse external inputs for the elemwise operation + fuse_input_ops(mm, elemwise_inputs, &map_ins); + + // fuse elemwise submodule + auto elemwise_rins = + mm->fuse(*elemwise_module, elemwise_inputs, &map_ins, &insert_pointwise); + assert(elemwise_rins.size() == 1); + map_ins[elemwise_ins] = elemwise_rins.front(); + + // fuse external inputs for the second gemm + fuse_input_ops(mm, second_gemm_inputs, &map_ins); + + // add the second gemm to the new module + std::vector<instruction_ref> second_gemm_mapped_inputs; + second_gemm_mapped_inputs.reserve(second_gemm_inputs.size()); + std::transform(second_gemm_inputs.begin(), + second_gemm_inputs.end(), + std::back_inserter(second_gemm_mapped_inputs), + [&](auto input) { return map_ins.at(input); }); + auto second_gemm_in_module = + mm->add_instruction(second_gemm_ins->get_operator(), second_gemm_mapped_inputs); + map_ins[second_gemm_ins] = second_gemm_in_module; + + // the second gemm is the primary result, so it goes first in the return values + std::vector<instruction_ref> return_vals; + return_vals.push_back(second_gemm_in_module); + + if(elemwise_has_multi_outs) + { + return_vals.push_back(map_ins[elemwise_ins]); + } + if(first_gemm_has_multi_outs) + { + return_vals.push_back(map_ins[first_gemm_ins]);
} + mm->add_return(return_vals); + auto inputs = find_inputs(map_ins, &mpm.get_module(), mm); + + // sort the fusion section of the module so that any external inputs are moved before + // the fusion, letting us safely place the fused module at the beginning of the chain + // in the multi-output case + mpm.get_module().localized_sort(first_gemm_ins, second_gemm_ins); + + auto fused_ins = + mpm.get_module().insert_instruction(first_gemm_ins, + mlir_op{second_gemm_ins->get_operator()}, + mlir_contiguous(mpm, inputs), + {mm}); + + if(first_gemm_has_multi_outs or elemwise_has_multi_outs) + { + std::size_t output_idx = 0; + if(elemwise_has_multi_outs) + { + auto elemwise_result = mpm.get_module().insert_instruction( + first_gemm_ins, + migraphx::make_op("get_tuple_elem", {{"index", ++output_idx}}), + fused_ins); + mpm.get_module().replace_instruction(elemwise_ins, elemwise_result); + } + if(first_gemm_has_multi_outs) + { + mpm.get_module().replace_instruction( + first_gemm_ins, + migraphx::make_op("get_tuple_elem", {{"index", ++output_idx}}), + fused_ins); + } + mpm.get_module().replace_instruction( + second_gemm_ins, migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused_ins); + } + else + { + // simple single output case + mpm.get_module().replace_instruction(second_gemm_ins, fused_ins); + } + } +}; + template struct find_mlir_standalone_op { @@ -1300,6 +1447,15 @@ void fuse_mlir::apply(module_pass_manager& mpm) const match::find_matches(mpm, find_mlir_attention_op{}); mpm.run_pass(dead_code_elimination{}); + if(enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + { + match::find_matches( + mpm, + find_mlir_fused_geg_ops{.conv_mode = get_mode("fused_convolution", mlir_mode::fast), + .dot_mode = get_mode("fused_dot", mlir_mode::fast)}); + mpm.run_pass(dead_code_elimination{}); + } + match::find_matches( mpm, find_mlir_fused_ops{.conv_mode = get_mode("fused_convolution", mlir_mode::fast), diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/index.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/index.hpp index b41cb3ef026..77da7283190 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/index.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/index.hpp @@ -230,7 +230,12 @@ struct index static constexpr void for_stride(index_int start, N n, Stride stride, F f) { MIGRAPHX_ASSERT(start < stride); - if constexpr(not is_integral<N>{} and not is_integral<Stride>{}) + + if constexpr(not is_integral<N>{} and n < 1) + { + return; + } + else if constexpr(not is_integral<N>{} and not is_integral<Stride>{}) { if constexpr(max_stride_iterations(n, stride) == 1) { diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index 85b3698c8ab..5844c934259 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -84,6 +84,7 @@ namespace gpu { MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_REWRITE_DOT) +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_REWRITE_LRN) #ifndef _WIN32 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK) #endif @@ -203,7 +204,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti insert_pad{{"convolution"}}, dead_code_elimination{}, inline_module{}, - rewrite_pooling{}, + rewrite_pooling{.rewrite_lrn = enabled(MIGRAPHX_REWRITE_LRN{})}, dead_code_elimination{}, rewrite_gelu{options.fast_math}, optimize_module{}, diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index 57c401b8359..5532da8832b 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp
@@ -40,6 +40,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_REDUCE_FUSION); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_USE_SPECIFIC_OPS); +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_GEG_FUSION); struct non_mlir_op { @@ -1773,6 +1774,847 @@ TEST_CASE(unpack_fp4_dot_odd) EXPECT(p1.sort() == p2.sort()); } +TEST_CASE(dot_add_dot) +{ + migraphx::shape s1{migraphx::shape::float_type, {2, 3}}; + migraphx::shape s2{migraphx::shape::float_type, {3, 4}}; + migraphx::shape s3{migraphx::shape::float_type, {2, 4}}; + migraphx::shape s4{migraphx::shape::float_type, {4, 2}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s1); + auto b = mm->add_parameter("b", s2); + auto x = mm->add_parameter("x", s3); + auto y = mm->add_parameter("y", s4); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); // {2,4} + auto add = + add_pointwise(p1, "main:pointwise0", {dot1, x}, single_pointwise("add")); // {2,4} + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, y); // {4,2} + mm->add_return({dot2}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s1); + auto b = mm->add_parameter("b", s2); + auto x = mm->add_parameter("x", s3); + auto y = mm->add_parameter("y", s4); + auto fused = add_mlir( + p2, "mlir_main:pointwise0_geg", {a, b, x, y}, [=](auto* pm, const auto& inputs) { + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[2]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, inputs[3]); + return std::make_tuple(dot2->get_operator(), dot2); + }); + mm->add_return({fused}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_add_dot_square) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = add_pointwise(p1, "main:pointwise0", {dot1, x}, single_pointwise("add")); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, y); + mm->add_return({dot2}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto fused = add_mlir( + p2, "mlir_main:pointwise0_geg", {a, b, x, y}, [=](auto* pm, const auto& inputs) { + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[2]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, inputs[3]); + return std::make_tuple(dot2->get_operator(), dot2); + }); + mm->add_return({fused}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_mul_dot) +{ + migraphx::shape s1{migraphx::shape::float_type, {3, 3}}; + migraphx::shape s2{migraphx::shape::float_type, {3, 4}}; + migraphx::shape s3{migraphx::shape::float_type, {4, 5}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = 
mm->add_parameter("a", s1); + auto b = mm->add_parameter("b", s2); + auto x = mm->add_parameter("x", s2); + auto y = mm->add_parameter("y", s3); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto mul = add_pointwise(p1, "main:pointwise0", {dot1, x}, single_pointwise("mul")); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), mul, y); + mm->add_return({dot2}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s1); + auto b = mm->add_parameter("b", s2); + auto x = mm->add_parameter("x", s2); + auto y = mm->add_parameter("y", s3); + auto fused = add_mlir( + p2, "mlir_main:pointwise0_geg", {a, b, x, y}, [=](auto* pm, const auto& inputs) { + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto mul = pm->add_instruction(migraphx::make_op("mul"), dot1, inputs[2]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), mul, inputs[3]); + return std::make_tuple(dot2->get_operator(), dot2); + }); + mm->add_return({fused}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(conv_add) +{ + migraphx::shape is{migraphx::shape::float_type, {4, 14, 122, 122}}; + migraphx::shape ys{migraphx::shape::float_type, {4, 56, 122, 122}}; + migraphx::shape ws{migraphx::shape::float_type, {56, 14, 1, 1}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", is); + auto y = mm->add_parameter("y", ys); + auto w = mm->add_parameter("w", ws); + auto conv = mm->add_instruction(migraphx::make_op("convolution"), x, w); + auto add = add_pointwise(p1, "main:pointwise0", {conv, y}, single_pointwise("add")); + mm->add_return({add}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", is); + auto y = mm->add_parameter("y", ys); + auto w = mm->add_parameter("w", ws); + auto fused = add_mlir(p2, + "mlir_main:pointwise0", + {x, w, y}, + {"x0", "x1", "x2"}, + [=](auto* pm, const auto& inputs) { + auto c = pm->add_instruction( + migraphx::make_op("convolution"), inputs[0], inputs[1]); + auto add = + pm->add_instruction(migraphx::make_op("add"), c, inputs[2]); + return std::make_tuple(c->get_operator(), add); + }); + mm->add_return({fused}); + } + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(conv_add_dot) +{ + migraphx::shape is{migraphx::shape::float_type, {2, 4, 8, 8}}; + migraphx::shape ys{migraphx::shape::float_type, {2, 8, 8, 8}}; + migraphx::shape ws{migraphx::shape::float_type, {8, 4, 1, 1}}; + migraphx::shape zs{migraphx::shape::float_type, {2, 8, 8, 4}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", is); + auto y = mm->add_parameter("y", ys); + auto w = mm->add_parameter("w", ws); + auto z = mm->add_parameter("z", zs); + auto conv = mm->add_instruction(migraphx::make_op("convolution"), x, w); + auto add = add_pointwise(p1, "main:pointwise0", {conv, y}, single_pointwise("add")); + auto dot = mm->add_instruction(migraphx::make_op("dot"), add, z); + + mm->add_return({dot}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", is); + auto y = mm->add_parameter("y", ys); + auto w = mm->add_parameter("w", ws); + auto z = mm->add_parameter("z", zs); + auto fused = + add_mlir(p2, + "mlir_main:pointwise0_geg", + {x, w, y, z}, + {"x0", "x1", "x2", "x3"}, + [=](auto* pm, const auto& inputs) { + auto c 
= pm->add_instruction( + migraphx::make_op("convolution"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), c, inputs[2]); + auto dot = pm->add_instruction(migraphx::make_op("dot"), add, inputs[3]); + return std::make_tuple(dot->get_operator(), dot); + }); + mm->add_return({fused}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_multi_user_add) +// G ->optional R -> E fusion +// G has two users, one external to fusion +{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = add_pointwise(p1, "main:pointwise0", {dot1, c}, single_pointwise("add")); + auto transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot1); + mm->add_return({add, transpose}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto fused = + add_mlir(p2, "mlir_main:pointwise0", {a, b, c}, [=](auto* pm, const auto& inputs) { + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[2]); + return std::make_tuple(dot1->get_operator(), + std::vector{add, dot1}); + }); + auto get_add = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto get_dot = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto transpose = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot); + mm->add_return({get_add, transpose}); + } + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_add_multi_user_dot) +// GEG fusion has two outputs, E has external user +{ + migraphx::shape s1{migraphx::shape::float_type, {3, 3}}; + migraphx::shape s2{migraphx::shape::float_type, {3, 5}}; + migraphx::shape s3{migraphx::shape::float_type, {5, 2}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s1); + auto b = mm->add_parameter("b", s2); + auto c = mm->add_parameter("c", s2); + auto d = mm->add_parameter("d", s3); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = add_pointwise(p1, "main:pointwise0", {dot1, c}, single_pointwise("add")); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, d); + auto transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), dot2); + mm->add_return({add, transpose}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s1); + auto b = mm->add_parameter("b", s2); + auto c = mm->add_parameter("c", s2); + auto d = mm->add_parameter("d", s3); + auto fused = add_mlir( + p2, "mlir_main:pointwise0_geg", {a, b, c, d}, [=](auto* pm, const auto& inputs) { + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[2]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, inputs[3]); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, add}); + }); + auto get_dot2 = + 
mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto get_add = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto transpose = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {1, 0}}}), get_dot2); + mm->add_return({get_add, transpose}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_add_multi_user_dot_with_transpose) +// GEG fusion has two outputs, E has external user +{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = add_pointwise(p1, "main:pointwise0", {dot1, c}, single_pointwise("add")); + auto d_t = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), d); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, d_t); + auto transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot2); + mm->add_return({add, transpose}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto fused = add_mlir( + p2, "mlir_main:pointwise0_geg", {a, b, c, d}, [=](auto* pm, const auto& inputs) { + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[2]); + auto d_t = pm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), inputs[3]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, d_t); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, add}); + }); + auto get_dot2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto get_add = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto transpose = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); + mm->add_return({get_add, transpose}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_add_multi_user_dot_two_externals) +// GEG fusion has two outputs, E has external user +{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = add_pointwise(p1, "main:pointwise0", {dot1, c}, single_pointwise("add")); + auto external_t1 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), d); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, d); + auto external_t2 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot2); + mm->add_return({add, external_t1, external_t2}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = 
mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto fused = add_mlir( + p2, "mlir_main:pointwise0_geg", {a, b, c, d}, [=](auto* pm, const auto& inputs) { + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[2]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, inputs[3]); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, add}); + }); + auto get_dot2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto get_add = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto external_t1 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), d); + auto external_t2 = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); + mm->add_return({get_add, external_t1, external_t2}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_add_multi_user_dot_input_used_before) +// GEG fusion has two outputs, E has external user. +// Base case for testing inputs being defined within the span +// of will-be-fused ops +// This also shows the relu being fused, since it is a unary op +{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + + auto external_relu = add_pointwise(p1, "main:pointwise1", {d}, single_pointwise("relu")); + + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = add_pointwise(p1, "main:pointwise0", {dot1, c}, single_pointwise("add")); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, external_relu); + auto transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot2); + mm->add_return({add, transpose}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto fused = add_mlir( + p2, + "main:pointwise1:mlir_main:pointwise0_geg", + {d, a, b, c}, + [=](auto* pm, const auto& inputs) { + auto external_relu = pm->add_instruction(migraphx::make_op("relu"), inputs[0]); + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[1], inputs[2]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[3]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, external_relu); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, add}); + }); + auto get_dot2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto get_add = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto external_t = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); + mm->add_return({get_add, external_t}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_add_multi_user_dot_input_used_after) +// GEG fusion has two outputs, E has external user +// Testing inputs being defined within the span of will-be-fused 
ops +// This also shows the relu being fused, since it is a unary op. +// Result should be, and is, equivalent to the previous test +{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = add_pointwise(p1, "main:pointwise0", {dot1, c}, single_pointwise("add")); + auto external_relu = add_pointwise(p1, "main:pointwise1", {d}, single_pointwise("relu")); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, external_relu); + auto transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot2); + mm->add_return({add, transpose}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto fused = add_mlir( + p2, + "main:pointwise1:mlir_main:pointwise0_geg", + {d, a, b, c}, + [=](auto* pm, const auto& inputs) { + auto external_relu = pm->add_instruction(migraphx::make_op("relu"), inputs[0]); + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[1], inputs[2]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[3]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, external_relu); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, add}); + }); + auto get_dot2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto get_add = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto external_t = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); + mm->add_return({get_add, external_t}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_add_multi_user_dot_input_used_before_in_chain) +// GEG fusion has two outputs, E has external user +// Base case for inputs being defined within the span of will-be-fused ops, including +// longer chain of logic, for both cases of input fusion. When enabled, +// the mul gets fused into the GEG fusion. 
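+// ("enabled" here refers to MIGRAPHX_ENABLE_MLIR_INPUT_FUSION, which this test checks +// when comparing against the expected programs at the end)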
+{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + + auto external_relu = add_pointwise(p1, "main:pointwise1", {d}, single_pointwise("relu")); + auto external_mul = + add_pointwise(p1, "main:pointwise2", {external_relu, d}, single_pointwise("mul")); + + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = add_pointwise(p1, "main:pointwise0", {dot1, c}, single_pointwise("add")); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, external_mul); + auto transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot2); + mm->add_return({add, transpose}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto external_relu = add_pointwise(p2, "main:pointwise1", {d}, single_pointwise("relu")); + auto external_mul = + add_pointwise(p2, "main:pointwise2", {external_relu, d}, single_pointwise("mul")); + auto fused = + add_mlir(p2, + "mlir_main:pointwise0_geg", + {a, b, c, external_mul}, + [=](auto* pm, const auto& inputs) { + auto dot1 = + pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[2]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, inputs[3]); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, add}); + }); + auto get_dot2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto get_add = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto external_t = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); + mm->add_return({get_add, external_t}); + } + migraphx::program p3; + { + auto* mm = p3.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto external_relu = add_pointwise(p3, "main:pointwise1", {d}, single_pointwise("relu")); + auto fused = add_mlir( + p3, + "main:pointwise2:mlir_main:pointwise0_geg", + {external_relu, d, a, b, c}, + [=](auto* pm, const auto& inputs) { + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[2], inputs[3]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[4]); + auto mul = pm->add_instruction(migraphx::make_op("mul"), inputs[0], inputs[1]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, mul); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, add}); + }); + auto get_dot2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto get_add = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto external_t = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); + mm->add_return({get_add, external_t}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + if(migraphx::enabled(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION{})) + EXPECT(p1.sort() == p3.sort()); + else + EXPECT(p1.sort() == p2.sort()); +} + 
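+// A reader's sketch (comments only; names follow the tests in this file): when an
+// intermediate of the fused dot->elemwise->dot chain has users outside the fusion,
+// the generated mlir module returns a tuple whose index 0 is the final dot and whose
+// remaining indices are the externally used intermediates, read back with
+// get_tuple_elem:
+//
+//   auto fused = add_mlir(p, "mlir_main:pointwise0_geg", {a, b, c, d}, ...);
+//   auto dot2  = mm->add_instruction(
+//       migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused);
+//   auto add   = mm->add_instruction(
+//       migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused);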
+TEST_CASE(dot_add_multi_user_dot_input_used_after_in_chain) +// GEG fusion has two outputs, E has external user +// Testing inputs being defined within the span of will-be-fused ops, including +// longer chain of logic +{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = add_pointwise(p1, "main:pointwise0", {dot1, c}, single_pointwise("add")); + auto external_relu = add_pointwise(p1, "main:pointwise1", {d}, single_pointwise("relu")); + auto external_mul = + add_pointwise(p1, "main:pointwise2", {external_relu, d}, single_pointwise("mul")); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, external_mul); + auto transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot2); + mm->add_return({add, transpose}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto external_relu = add_pointwise(p2, "main:pointwise1", {d}, single_pointwise("relu")); + auto external_mul = + add_pointwise(p2, "main:pointwise2", {external_relu, d}, single_pointwise("mul")); + auto fused = + add_mlir(p2, + "mlir_main:pointwise0_geg", + {a, b, c, external_mul}, + [=](auto* pm, const auto& inputs) { + auto dot1 = + pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[2]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, inputs[3]); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, add}); + }); + auto get_dot2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto get_add = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto external_t = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); + mm->add_return({get_add, external_t}); + } + migraphx::program p3; + { + auto* mm = p3.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto external_relu = add_pointwise(p3, "main:pointwise1", {d}, single_pointwise("relu")); + auto fused = add_mlir( + p3, + "main:pointwise2:mlir_main:pointwise0_geg", + {external_relu, d, a, b, c}, + [=](auto* pm, const auto& inputs) { + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[2], inputs[3]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[4]); + auto mul = pm->add_instruction(migraphx::make_op("mul"), inputs[0], inputs[1]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, mul); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, add}); + }); + auto get_dot2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto get_add = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto external_t = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); + mm->add_return({get_add, external_t}); + } + if(not 
migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + if(migraphx::enabled(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION{})) + EXPECT(p1.sort() == p3.sort()); + else + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_pw_multi_user_dot) +// GEG fusion has two outputs, E has external user, E is multiple elemwise ops +{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto e = mm->add_parameter("e", s); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto elemwise = + add_pointwise(p1, "main:pointwise0", {dot1, c, d}, [=](auto* pm, const auto& inputs) { + auto add = + pm->add_instruction(migraphx::make_op("add"), inputs.at(0), inputs.at(1)); + return pm->add_instruction(migraphx::make_op("mul"), add, inputs.at(2)); + }); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), elemwise, e); + auto transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot2); + mm->add_return({elemwise, transpose}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto e = mm->add_parameter("e", s); + auto fused = add_mlir( + p2, "mlir_main:pointwise0_geg", {a, b, c, d, e}, [=](auto* pm, const auto& inputs) { + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[2]); + auto mul = pm->add_instruction(migraphx::make_op("mul"), add, inputs[3]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), mul, inputs[4]); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, mul}); + }); + auto get_dot2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto get_mul = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto transpose = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); + mm->add_return({get_mul, transpose}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_multi_user_add_dot) +// GEG fusion has two outputs (first G has external user) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = add_pointwise(p1, "main:pointwise0", {dot1, c}, single_pointwise("add")); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, d); + auto transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot1); + mm->add_return({dot2, transpose}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto fused = add_mlir( + p2, "mlir_main:pointwise0_geg", {a, b, c, d}, [=](auto* pm, const auto& inputs) { + auto dot1 
= pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[2]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, inputs[3]); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, dot1}); + }); + auto get_dot1 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto get_dot2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto transpose = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot1); + mm->add_return({get_dot2, transpose}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_add_dot_both_multi_user) +// GEG fusion has three outputs (first G has external user, E has external user) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = add_pointwise(p1, "main:pointwise0", {dot1, c}, single_pointwise("add")); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, d); + auto transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot1); + mm->add_return({add, dot2, transpose}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto fused = add_mlir( + p2, "mlir_main:pointwise0_geg", {a, b, c, d}, [=](auto* pm, const auto& inputs) { + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[2]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, inputs[3]); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, add, dot1}); + }); + auto get_dot1 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), fused); + auto get_elemwise = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto get_dot2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto transpose = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot1); + mm->add_return({get_elemwise, get_dot2, transpose}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + int main(int argc, const char* argv[]) { if(migraphx::gpu::mlir_enabled()) diff --git a/test/module_test.cpp b/test/module_test.cpp index ac3079033de..270bfcd63e4 100644 --- a/test/module_test.cpp +++ b/test/module_test.cpp @@ -1369,4 +1369,50 @@ TEST_CASE(pathological_dfs_graph_sort) EXPECT(is_sorted(m)); } +TEST_CASE(localized_sort) +{ + migraphx::shape s1{migraphx::shape::float_type, {2, 3}}; + migraphx::shape s2{migraphx::shape::float_type, {3, 4}}; + migraphx::shape s3{migraphx::shape::float_type, {2, 4}}; + migraphx::shape s4{migraphx::shape::float_type, {4, 2}}; + migraphx::program p; + auto* mm = p.get_main_module(); + + auto a = mm->add_parameter("a", s1); + auto b = mm->add_parameter("b", s2); + auto c = 
mm->add_parameter("c", s3); + auto d = mm->add_parameter("d", s4); + + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); // {2, 4} + auto external_relu = + add_pointwise(p, "main:pointwise1", {d}, single_pointwise("relu")); // {4, 2} + auto external_mul = + add_pointwise(p, "main:pointwise2", {external_relu, d}, single_pointwise("mul")); // {4, 2} + auto add = add_pointwise(p, "main:pointwise0", {dot1, c}, single_pointwise("add")); // {2, 4} + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, external_mul); // {2, 2} + auto transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), dot2); + mm->add_return({add, transpose}); + + // Perform localized sort between dot1 and dot2 + mm->localized_sort(dot1, dot2); + + // Verify the module is still topologically sorted overall + EXPECT(is_sorted(*mm)); + + // Verify external operations moved before the fusion chain (dot2 depends on external_mul) + EXPECT(std::distance(mm->begin(), external_relu) < std::distance(mm->begin(), dot1)); + EXPECT(std::distance(mm->begin(), external_mul) < std::distance(mm->begin(), dot1)); + + // Verify the fusion chain ordering is preserved: dot1 < add < dot2 + EXPECT(std::distance(mm->begin(), dot1) < std::distance(mm->begin(), add)); + EXPECT(std::distance(mm->begin(), add) < std::distance(mm->begin(), dot2)); + + // Verify external_relu stays before external_mul (the mul consumes the relu) + EXPECT(std::distance(mm->begin(), external_relu) < std::distance(mm->begin(), external_mul)); + + // Verify transpose remains after dot2 + EXPECT(std::distance(mm->begin(), dot2) < std::distance(mm->begin(), transpose)); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/onnx/dynamicscale_even_test.onnx b/test/onnx/dynamicscale_even_test.onnx new file mode 100644 index 00000000000..37ade59662c --- /dev/null +++ b/test/onnx/dynamicscale_even_test.onnx @@ -0,0 +1,21 @@ + dynamicscale_even_test: + +inputoutput" DynamicScale* + group_dim* + +group_size * + output_dtype*" +scale_selection_method"floor*& +zero_point_selection_method"Nonedynamicscale_even_testZ +input + + +@ + +b +output + + +@ + +B \ No newline at end of file diff --git a/test/onnx/dynamicscale_odd_test.onnx b/test/onnx/dynamicscale_odd_test.onnx new file mode 100644 index 00000000000..5a4bfcc7a0d Binary files /dev/null and b/test/onnx/dynamicscale_odd_test.onnx differ diff --git a/test/onnx/dynamicscale_small_test.onnx b/test/onnx/dynamicscale_small_test.onnx new file mode 100644 index 00000000000..1e968936967 --- /dev/null +++ b/test/onnx/dynamicscale_small_test.onnx @@ -0,0 +1,17 @@ + dynamicscale_small_test: + +inputoutput" DynamicScale* + group_dim* + +group_size * + output_dtype*" +scale_selection_method"floor*& +zero_point_selection_method"Nonedynamicscale_small_testZ +input +  + +b +output +  + +B \ No newline at end of file diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index 5941a4a98fc..fa83962d291 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -9735,6 +9735,51 @@ def mxfixneuron_small_test(): return ([node], [in_tv], [out_tv]) + + +@onnx_test() +def dynamicscale_even_test(): + in_tv = helper.make_tensor_value_info('input', TensorProto.FLOAT, [3, 64, 4, 4]) + out_tv = helper.make_tensor_value_info('output', TensorProto.FLOAT, [3, 64, 4, 4]) + node = onnx.helper.make_node('DynamicScale', + inputs=['input'], + group_dim=1, + group_size=32, + output_dtype=23, + scale_selection_method='floor', + zero_point_selection_method='None', + outputs=['output']) + return ([node], 
[in_tv], [out_tv]) + + +@onnx_test() +def dynamicscale_odd_test(): + in_tv = helper.make_tensor_value_info('input', TensorProto.FLOAT, [71, 5, 5]) + out_tv = helper.make_tensor_value_info('output', TensorProto.FLOAT, [71, 5, 5]) + node = onnx.helper.make_node('DynamicScale', + inputs=['input'], + group_dim=0, + group_size=32, + output_dtype=23, + scale_selection_method='floor', + zero_point_selection_method='None', + outputs=['output']) + return ([node], [in_tv], [out_tv]) + + +@onnx_test() +def dynamicscale_small_test(): + in_tv = helper.make_tensor_value_info('input', TensorProto.FLOAT, [4, 4]) + out_tv = helper.make_tensor_value_info('output', TensorProto.FLOAT, [4, 4]) + node = onnx.helper.make_node('DynamicScale', + inputs=['input'], + group_dim=-1, + group_size=32, + output_dtype=23, + scale_selection_method='floor', + zero_point_selection_method='None', + outputs=['output']) + return ([node], [in_tv], [out_tv]) + + @onnx_test() def neg_test(): x = helper.make_tensor_value_info('0', TensorProto.INT64, [2, 3]) @@ -11418,6 +11463,42 @@ def quantizelinear_neg_axis_test(): return make_quantizelinear_axis_graph(-2) +@onnx_test() +def quantizelinear_mxfp4_even_test(): + arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3, 64, 4, 4]) + arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 2, 4, 4]) + arg_out = helper.make_tensor_value_info('out', TensorProto.FLOAT4E2M1, [3, 64, 4, 4]) + + node = onnx.helper.make_node( + 'QuantizeLinear', + inputs = ['0', '1'], + axis = 1, + block_size = 32, + output_dtype = 23, + outputs = ['out'], + ) + + return ([node], [arg0, arg1], [arg_out]) + +@onnx_test() +def quantizelinear_mxfp4_odd_test(): + arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3, 64, 4, 7]) + arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 2, 4, 7]) + arg_out = helper.make_tensor_value_info('out', TensorProto.FLOAT4E2M1, [3, 64, 4, 7]) + + node = onnx.helper.make_node( + 'QuantizeLinear', + inputs = ['0', '1'], + axis = 1, + block_size = 32, + output_dtype = 23, + outputs = ['out'], + ) + + return ([node], [arg0, arg1], [arg_out]) + + + @onnx_test() def randomnormal_test(): dtype = 11 diff --git a/test/onnx/parse/dynamicscale_test.cpp b/test/onnx/parse/dynamicscale_test.cpp new file mode 100644 index 00000000000..a617e0b2563 --- /dev/null +++ b/test/onnx/parse/dynamicscale_test.cpp @@ -0,0 +1,85 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include + +TEST_CASE(dynamicscale_even_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = + mm->add_parameter("input", migraphx::shape{migraphx::shape::float_type, {3, 64, 4, 4}}); + auto reduce_reshape = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 2, 32, 4, 4}}}), input); + auto abs_ins = mm->add_instruction(migraphx::make_op("abs"), reduce_reshape); + auto reduce_max_ins = + mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), abs_ins); + auto log2_ins = mm->add_instruction(migraphx::make_op("log2"), reduce_max_ins); + auto floor_ins = mm->add_instruction(migraphx::make_op("floor"), log2_ins); + auto lit_2_ins = mm->add_literal({migraphx::shape{migraphx::shape::float_type}, {2.f}}); + auto broadcast_lit_2 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", reduce_max_ins->get_shape().lens()}}), + lit_2_ins); + auto pow_ins = mm->add_instruction(migraphx::make_op("pow"), broadcast_lit_2, floor_ins); + auto lit_4_ins = mm->add_literal({migraphx::shape{migraphx::shape::float_type}, {4.f}}); + auto broadcast_lit_4 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", reduce_max_ins->get_shape().lens()}}), + lit_4_ins); + auto block_scales_ins = mm->add_instruction(migraphx::make_op("div"), pow_ins, broadcast_lit_4); + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), block_scales_ins); + + auto prog = optimize_onnx("dynamicscale_even_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(dynamicscale_odd_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = + mm->add_parameter("input", migraphx::shape{migraphx::shape::float_type, {71, 5, 5}}); + auto padded_input = + mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 0, 25, 0, 0}}}), input); + auto reduce_reshape = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 32, 5, 5}}}), padded_input); + auto abs_ins = mm->add_instruction(migraphx::make_op("abs"), reduce_reshape); + auto reduce_max_ins = + mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), abs_ins); + auto log2_ins = mm->add_instruction(migraphx::make_op("log2"), reduce_max_ins); + auto floor_ins = mm->add_instruction(migraphx::make_op("floor"), log2_ins); + auto lit_2_ins = mm->add_literal({migraphx::shape{migraphx::shape::float_type}, {2.f}}); + auto broadcast_lit_2 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", reduce_max_ins->get_shape().lens()}}), + lit_2_ins); + auto pow_ins = mm->add_instruction(migraphx::make_op("pow"), broadcast_lit_2, floor_ins); + auto lit_4_ins = mm->add_literal({migraphx::shape{migraphx::shape::float_type}, {4.f}}); + auto broadcast_lit_4 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", reduce_max_ins->get_shape().lens()}}), + lit_4_ins); + auto block_scales_ins = mm->add_instruction(migraphx::make_op("div"), pow_ins, broadcast_lit_4); + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), block_scales_ins); + + auto prog = optimize_onnx("dynamicscale_odd_test.onnx"); + EXPECT(p == prog); +} diff --git a/test/onnx/parse/quantizelinear_mx_type_test.cpp b/test/onnx/parse/quantizelinear_mx_type_test.cpp new file mode 
100644 index 00000000000..4beb1395533 --- /dev/null +++ b/test/onnx/parse/quantizelinear_mx_type_test.cpp @@ -0,0 +1,78 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include + +// even fastest dimension +TEST_CASE(quantizelinear_mxfp4_even_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {3, 64, 4, 4}}); + auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {3, 2, 4, 4}}); + auto l1_reshape = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), l1); + l1_reshape = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {3, 2, 32, 4, 4}}}), l1_reshape); + l1_reshape = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 64, 4, 4}}}), l1_reshape); + auto q_ins = mm->add_instruction( + migraphx::make_op("quantizelinear", {{"out_type", migraphx::shape::float_type}}), + l0, + l1_reshape); + auto pack_ins = mm->add_instruction(migraphx::make_op("pack_fp4", {{"axis", 3}}), q_ins); + auto unpack_ins = mm->add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), pack_ins); + mm->add_return({unpack_ins}); + + auto prog = read_onnx("quantizelinear_mxfp4_even_test.onnx"); + EXPECT(p.sort() == prog.sort()); +} + +// odd fastest dimension +TEST_CASE(quantizelinear_mxfp4_odd_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {3, 64, 4, 7}}); + auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {3, 2, 4, 7}}); + auto l1_reshape = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), l1); + l1_reshape = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {3, 2, 32, 4, 7}}}), l1_reshape); + l1_reshape = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 64, 4, 7}}}), l1_reshape); + auto q_ins = mm->add_instruction( + migraphx::make_op("quantizelinear", {{"out_type", migraphx::shape::float_type}}), + l0, + l1_reshape); + auto pad_ins = + mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 0, 0, 0, 0, 0, 1}}}), q_ins); + auto pack_ins = mm->add_instruction(migraphx::make_op("pack_fp4", {{"axis", 3}}), pad_ins); + auto unpack_ins = mm->add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), pack_ins); + auto slice_ins = mm->add_instruction( + 
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {7}}}), unpack_ins); + mm->add_return({slice_ins}); + + auto prog = read_onnx("quantizelinear_mxfp4_odd_test.onnx"); + EXPECT(p.sort() == prog.sort()); +} diff --git a/test/onnx/quantizelinear_mxfp4_even_test.onnx b/test/onnx/quantizelinear_mxfp4_even_test.onnx new file mode 100644 index 00000000000..f2c2fb01b34 --- /dev/null +++ b/test/onnx/quantizelinear_mxfp4_even_test.onnx @@ -0,0 +1,26 @@ + quantizelinear_mxfp4_even_test: +P +0 +1out"QuantizeLinear* +axis* + +block_size * + output_dtypequantizelinear_mxfp4_even_testZ +0 + + +@ + +Z +1 + + + + +b +out + + +@ + +B \ No newline at end of file diff --git a/test/onnx/quantizelinear_mxfp4_odd_test.onnx b/test/onnx/quantizelinear_mxfp4_odd_test.onnx new file mode 100644 index 00000000000..52baba83ed8 --- /dev/null +++ b/test/onnx/quantizelinear_mxfp4_odd_test.onnx @@ -0,0 +1,26 @@ + quantizelinear_mxfp4_odd_test: +P +0 +1out"QuantizeLinear* +axis* + +block_size * + output_dtypequantizelinear_mxfp4_odd_testZ +0 + + +@ + +Z +1 + + + + +b +out + + +@ + +B \ No newline at end of file diff --git a/test/onnx/verify/dynamicscale_small_test.cpp b/test/onnx/verify/dynamicscale_small_test.cpp new file mode 100644 index 00000000000..70b5cc26e7e --- /dev/null +++ b/test/onnx/verify/dynamicscale_small_test.cpp @@ -0,0 +1,62 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include +#include +#include + +TEST_CASE(dynamicscale_small_test) +{ + migraphx::program p = read_onnx("dynamicscale_small_test.onnx"); + p.compile(migraphx::make_target("ref")); + std::vector input_lens{4, 4}; + auto input_type = migraphx::shape::float_type; + migraphx::shape data_shape{input_type, input_lens}; + std::vector data = {-100.f, + -12.f, + 32.f, + 819.f, + -6.f, + -5.75f, + -5.50f, + -5.25f, + -5.f, + -0.30f, + -1.40f, + -1.20f, + 2.0f, + 0.25f, + 0.33f, + 2.20f}; + migraphx::parameter_map pp; + pp["input"] = migraphx::argument(data_shape, data.data()); + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + // hand calculated values + std::vector gold = {128.f, 1.f, 1.f, 0.5f}; + + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} diff --git a/test/propagate_constant_test.cpp b/test/propagate_constant_test.cpp index da7511e3482..72011fd8d83 100644 --- a/test/propagate_constant_test.cpp +++ b/test/propagate_constant_test.cpp @@ -535,4 +535,31 @@ TEST_CASE(block_dequantize_int4) EXPECT(m1.sort() == m2.sort()); } +TEST_CASE(pack_unpack_fp4) +{ + migraphx::shape s1{migraphx::shape::float_type, {4}}; + migraphx::shape s2{migraphx::shape::fp4x2_type, {2}}; + migraphx::module m1; + { + const std::vector vec = {1.f, 0.f, 2.f, 0.f}; + auto l = m1.add_literal(migraphx::literal(s1, vec)); + auto pack = m1.add_instruction(migraphx::make_op("pack_fp4"), l); + auto unpack = m1.add_instruction(migraphx::make_op("unpack_fp4"), pack); + m1.add_return({unpack}); + } + + run_pass(m1); + + migraphx::module m2; + { + using migraphx::shape; + const std::vector vec = {0x2, 0x4}; + auto l = m2.add_literal(migraphx::literal(s2, vec.data())); + auto unpack = m2.add_instruction(migraphx::make_op("unpack_fp4"), l); + m2.add_return({unpack}); + } + + EXPECT(m1 == m2); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/py/onnx_backend_test.py b/test/py/onnx_backend_test.py index 2b395c55154..f0efd96e3d8 100644 --- a/test/py/onnx_backend_test.py +++ b/test/py/onnx_backend_test.py @@ -629,6 +629,25 @@ def disabled_tests_onnx_1_17_0(backend_test): backend_test.exclude(r'test_resize_upsample_sizes_nearest_not_smaller_cpu') +def disabled_tests_onnx_1_18_0(backend_test): + # src/onnx/onnx_parser.cpp:841: get_type: Prototensor data type 23 not supported + backend_test.exclude(r'test_cast_FLOAT16_to_FLOAT4E2M1_cpu') + backend_test.exclude(r'test_cast_FLOAT4E2M1_to_FLOAT16_cpu') + backend_test.exclude(r'test_cast_FLOAT4E2M1_to_FLOAT_cpu') + backend_test.exclude(r'test_cast_FLOAT_to_FLOAT4E2M1_cpu') + backend_test.exclude(r'test_dequantizelinear_float4e2m1_cpu') + backend_test.exclude(r'test_quantizelinear_float4e2m1_cpu') + # src/onnx/checks.cpp:35: check_arg_empty: PARSE_TopK: k input must be constant + backend_test.exclude(r'test_top_k_same_values_2d_cpu') + backend_test.exclude(r'test_top_k_same_values_cpu') + backend_test.exclude(r'test_top_k_same_values_largest_cpu') + backend_test.exclude(r'test_top_k_uint64_cpu') + #src/shape.cpp:367: lens: SHAPE: lens() called on a dynamic shape + backend_test.exclude(r'test_unique_length_1_cpu') + # + backend_test.exclude(r'test_averagepool_2d_ceil_last_window_starts_on_pad_cpu') + + def disabled_tests_int4(backend_test): # quantizelinear backend_test.exclude(r'test_quantizelinear_int4') @@ -1209,6 +1228,9 @@ def create_backend_test(testname=None, target_device=None): if version.parse(onnx.__version__) >= 
version.parse("1.17.0"): disabled_tests_onnx_1_17_0(backend_test) + if version.parse(onnx.__version__) >= version.parse("1.18.0"): + disabled_tests_onnx_1_18_0(backend_test) + # import all test cases at global scope to make # them visible to python.unittest. diff --git a/test/py/requirements-onnx.txt b/test/py/requirements-onnx.txt index d7cfa26772c..566c3db1adc 100644 --- a/test/py/requirements-onnx.txt +++ b/test/py/requirements-onnx.txt @@ -21,7 +21,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. ##################################################################################### -onnx==1.17.0;python_version>="3.11" +onnx==1.18.0;python_version>="3.11" onnx==1.14.1;python_version<"3.11" protobuf==4.25.8 numpy==1.26.4;python_version>="3.11" diff --git a/test/rewrite_pooling_test.cpp b/test/rewrite_pooling_test.cpp index 1ccafed179e..5377dc99715 100644 --- a/test/rewrite_pooling_test.cpp +++ b/test/rewrite_pooling_test.cpp @@ -34,6 +34,9 @@ #include +#include + +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_REWRITE_LRN); static void opt_pooling(migraphx::module& m) { migraphx::rewrite_pooling rp; @@ -309,6 +312,77 @@ TEST_CASE(rewrite_pooling_dialtions_test5) test_rewrite(migraphx::op::pooling_mode::max); } +TEST_CASE(test_lower_lrn_to_pooling) +{ + migraphx::module m1; + { + migraphx::shape input_shape{migraphx::shape::float_type, {1, 64, 55, 55}}; + auto input1 = m1.add_parameter("x", input_shape); + auto lrn1 = m1.add_instruction( + migraphx::make_op("lrn", + {{"alpha", 0.0001f}, {"beta", 0.75f}, {"bias", 1.0f}, {"size", 4}}), + input1); + m1.add_return({lrn1}); + } + // Apply the pass directly when the flag enabled + migraphx::rewrite_pooling rp{.rewrite_lrn = true}; + migraphx::dead_code_elimination dce; + rp.apply(m1); + dce.apply(m1); + + migraphx::module m2; + { + migraphx::shape input_shape{migraphx::shape::float_type, {1, 64, 55, 55}}; + auto input2 = m2.add_parameter("x", input_shape); + + auto x2 = m2.add_instruction(migraphx::make_op("mul"), input2, input2); + + auto transpose1 = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", std::vector{0, 2, 3, 1}}}), + x2); + + int64_t lrn_size = 4; + int64_t pad_left = (lrn_size - 1) / 2; + int64_t pad_right = lrn_size - 1 - pad_left; + std::vector expected_pads = {pad_left, pad_right}; + + auto avg = m2.add_instruction( + migraphx::make_op( + "pooling", + {{"mode", migraphx::op::pooling_mode::average}, + {"lengths", std::vector{1, lrn_size}}, + {"stride", std::vector{1, 1}}, + {"padding", std::vector{0, expected_pads[0], 0, expected_pads[1]}}, + {"dilations", std::vector{1, 1}}, + {"count_include_pad", true}}), + transpose1); + + auto transpose2 = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", std::vector{0, 3, 1, 2}}}), + avg); + + auto k_lit = m2.add_literal(1.0f); + auto a_lit = m2.add_literal(0.0001f); + auto b_lit = m2.add_literal(0.75f); + + auto k_mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_shape.lens()}}), k_lit); + auto a_mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_shape.lens()}}), a_lit); + auto b_mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", input_shape.lens()}}), b_lit); + + auto alpha_avg = m2.add_instruction(migraphx::make_op("mul"), a_mb, transpose2); + auto den = m2.add_instruction(migraphx::make_op("add"), k_mb, alpha_avg); + auto denpow = m2.add_instruction(migraphx::make_op("pow"), den, b_mb); + auto y = 
m2.add_instruction(migraphx::make_op("div"), input2, denpow); + + m2.add_return({y}); + } + + EXPECT(m1 == m2); +} + TEST_CASE(rewrite_avgpool_rank3_dil_test) { // 1D case 1, input is 3D diff --git a/test/simplify_qdq_test.cpp b/test/simplify_qdq_test.cpp index bfc3d61386b..8d548224762 100644 --- a/test/simplify_qdq_test.cpp +++ b/test/simplify_qdq_test.cpp @@ -1815,6 +1815,104 @@ TEST_CASE(fp4x2_quant_dot_even) EXPECT(m1 == m2); } +TEST_CASE(fp4x2_quant_dot_trans_b) +{ + migraphx::shape shape_packed_a{migraphx::shape::fp4x2_type, {1, 3, 6, 12}}; + migraphx::shape shape_packed_b{migraphx::shape::fp4x2_type, {1, 3, 8, 12}}; + migraphx::shape shape_scales_a{migraphx::shape::float_type, {1, 3, 6, 24}}; + migraphx::shape shape_scales_b{migraphx::shape::float_type, {1, 3, 8, 24}}; + + migraphx::module m1; + { + auto packed_a = m1.add_parameter("input", shape_packed_a); + auto packed_b = m1.add_parameter("weights", shape_packed_b); + auto scale_a = m1.add_parameter("scale_a", shape_scales_a); + auto scale_b = m1.add_parameter("scale_b", shape_scales_b); + + auto unpack_a = + m1.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_a); + auto unpack_b = + m1.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_b); + auto dq_a = m1.add_instruction(migraphx::make_op("dequantizelinear"), unpack_a, scale_a); + auto dq_b = m1.add_instruction(migraphx::make_op("dequantizelinear"), unpack_b, scale_b); + auto trans_b = m1.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), dq_b); + auto dot = m1.add_instruction(migraphx::make_op("dot"), dq_a, trans_b); + m1.add_return({dot}); + } + + migraphx::module m2; + { + auto packed_a = m2.add_parameter("input", shape_packed_a); + auto packed_b = m2.add_parameter("weights", shape_packed_b); + auto scale_a = m2.add_parameter("scale_a", shape_scales_a); + auto scale_b = m2.add_parameter("scale_b", shape_scales_b); + + auto unpack_a = + m2.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_a); + auto unpack_b = + m2.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_b); + auto trans_b = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), unpack_b); + auto trans_scale_b = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), scale_b); + auto quant_dot = m2.add_instruction( + migraphx::make_op("quant_dot"), unpack_a, trans_b, scale_a, trans_scale_b); + m2.add_return({quant_dot}); + } + + run_pass(m1); + EXPECT(m1 == m2); +} + +TEST_CASE(fp4x2_quant_dot_const_b) +{ + migraphx::shape shape_packed_a{migraphx::shape::fp4x2_type, {1, 3, 6, 12}}; + migraphx::shape shape_packed_b{migraphx::shape::fp4x2_type, {1, 3, 24, 4}}; + migraphx::shape shape_packed_b_gen{migraphx::shape::uint8_type, {1, 3, 24, 4}}; + migraphx::shape shape_scales_a{migraphx::shape::float_type, {1, 3, 6, 24}}; + migraphx::shape shape_scales_b{migraphx::shape::float_type, {1, 3, 24, 8}}; + unsigned long seed = 826; + migraphx::literal b_lit = generate_literal(shape_packed_b_gen, seed); + migraphx::literal scale_b_lit = generate_literal(shape_scales_b, seed); + migraphx::module m1; + { + auto packed_a = m1.add_parameter("input", shape_packed_a); + // avoiding visit fp4x2_type + auto packed_b = m1.add_literal(shape_packed_b, b_lit.data()); + auto scale_a = m1.add_parameter("scale_a", shape_scales_a); + auto scale_b = m1.add_literal(scale_b_lit); + + auto unpack_a = + m1.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), 
packed_a); + auto unpack_b = + m1.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_b); + auto dq_a = m1.add_instruction(migraphx::make_op("dequantizelinear"), unpack_a, scale_a); + auto dq_b = m1.add_instruction(migraphx::make_op("dequantizelinear"), unpack_b, scale_b); + auto dot = m1.add_instruction(migraphx::make_op("dot"), dq_a, dq_b); + m1.add_return({dot}); + } + + migraphx::module m2; + { + auto packed_a = m2.add_parameter("input", shape_packed_a); + auto packed_b = m2.add_literal(shape_packed_b, b_lit.data()); + auto scale_a = m2.add_parameter("scale_a", shape_scales_a); + auto scale_b = m2.add_literal(scale_b_lit); + + auto unpack_a = + m2.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_a); + auto unpack_b = + m2.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_b); + auto quant_dot = m2.add_instruction( + migraphx::make_op("quant_dot"), unpack_a, unpack_b, scale_a, scale_b); + m2.add_return({quant_dot}); + } + + run_pass(m1); + EXPECT(m1 == m2); +} + // Test that unused qdq with pack_fp4, unpack_fp4 are removed TEST_CASE(fp4x2_even_remove_qdq) { diff --git a/test/verify/CMakeLists.txt b/test/verify/CMakeLists.txt index bd57abea883..56b0a2cf2ba 100644 --- a/test/verify/CMakeLists.txt +++ b/test/verify/CMakeLists.txt @@ -1,7 +1,7 @@ ##################################################################################### # The MIT License (MIT) # -# Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -42,3 +42,7 @@ foreach(SECTION general reduce rnn conv gemm) ) endif() endforeach() + +# TODO: remove this when MIGraphX disables attn-like offloading to rocMLIR +# f32 not supported on navi3x +set_tests_properties(test_verify_rnn PROPERTIES ENVIRONMENT "MIGRAPHX_ENABLE_MLIR_GEG_FUSION=0") diff --git a/test/verify/test_conv_add_dot.cpp b/test/verify/test_conv_add_dot.cpp new file mode 100644 index 00000000000..a4ba2018f7f --- /dev/null +++ b/test/verify/test_conv_add_dot.cpp @@ -0,0 +1,65 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include "verify_program.hpp" +#include +#include +#include +#include + +template +struct test_conv_add_dot : verify_program> +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}}); + auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}}); + auto bias_literal = + migraphx::literal{migraphx::shape{DType, {4}}, {2.0f, 2.0f, 2.0f, 2.0f}}; + auto bias = mm->add_literal(bias_literal); + auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights); + auto bcast_bias = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", conv->get_shape().lens()}}), + bias); + auto bias_add = mm->add_instruction(migraphx::make_op("add"), conv, bcast_bias); + + // Create a ones literal to use as the dot (matmul) rhs + std::vector bias_add_lens = bias_add->get_shape().lens(); + // bias_add has the conv output shape {4, 4, 1, 1}, so the rhs uses the same lens + migraphx::shape dot_rhs_shape{DType, bias_add_lens}; + std::vector dot_rhs_data(dot_rhs_shape.elements(), 1.0f); + auto dot_rhs = mm->add_literal(migraphx::literal{dot_rhs_shape, dot_rhs_data}); + + // Batched matmul (dot) over the trailing 1x1 dims, which is effectively elementwise + auto dot = mm->add_instruction(migraphx::make_op("dot"), bias_add, dot_rhs); + mm->add_return({dot}); + return p; + } + std::string section() const { return "conv"; } +}; + +template struct test_conv_add_dot; +template struct test_conv_add_dot; diff --git a/test/verify/test_dot_add_dot.cpp b/test/verify/test_dot_add_dot.cpp new file mode 100644 index 00000000000..a087e60e018 --- /dev/null +++ b/test/verify/test_dot_add_dot.cpp @@ -0,0 +1,51 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include "verify_program.hpp" +#include +#include +#include + +template +struct test_dot_add_dot : verify_program> +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{DType, {256, 256}}; + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = mm->add_instruction(migraphx::make_op("add"), dot1, c); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, d); + mm->add_return({dot2}); + return p; + } +}; + +template struct test_dot_add_dot; +template struct test_dot_add_dot; diff --git a/test/verify/test_lrn.cpp b/test/verify/test_lrn.cpp new file mode 100644 index 00000000000..2eae1a202e5 --- /dev/null +++ b/test/verify/test_lrn.cpp @@ -0,0 +1,50 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +template +struct test_lrn : verify_program> +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter( + "x", migraphx::shape{migraphx::shape::float_type, {1, ChannelSize, 28, 28}}); + mm->add_instruction( + migraphx::make_op( + "lrn", {{"alpha", 0.0001}, {"beta", 0.75}, {"bias", 1.0}, {"size", LrnSize}}), + x); + return p; + } +}; + +template struct test_lrn<32, 6>; +template struct test_lrn<32, 5>; +template struct test_lrn<31, 8>; +template struct test_lrn<31, 5>; diff --git a/test/verify/test_relu_lrn.cpp b/test/verify/test_relu_lrn.cpp index feec2ab21c6..d201bd77317 100644 --- a/test/verify/test_relu_lrn.cpp +++ b/test/verify/test_relu_lrn.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/test/verify/test_topk.cpp b/test/verify/test_topk.cpp index ad5ac3f2d6a..e2d089953e5 100644 --- a/test/verify/test_topk.cpp +++ b/test/verify/test_topk.cpp @@ -72,3 +72,6 @@ template struct test_topk; template struct test_topk; template struct test_topk; template struct test_topk; + +template struct test_topk; +template struct test_topk;
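Reviewer note, summarizing the relationships the new tests encode (a sketch, not new source). test_lower_lrn_to_pooling expects LRN to lower to, in effect,
  y = x / (bias + alpha * avg_pool(x * x, window = {1, size}))^beta
with the averaging window slid along the channel axis via transposes, pads split as pad_left = (size - 1) / 2 and pad_right = size - 1 - pad_left, and count_include_pad = true. The dynamicscale tests encode
  scale = 2^floor(log2(max |x| per group)) / 4
e.g. the row {-100, -12, 32, 819} in dynamicscale_small_test gives floor(log2(819)) = 9, so 2^9 / 4 = 128, matching the first gold value.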