diff --git a/Jenkinsfile b/Jenkinsfile index cfa98d43a4d..6dd7994ed89 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -199,7 +199,7 @@ rocmtest clang_debug: rocmnode('mi200+') { cmake_build -> } }, mlir_debug: rocmnode('mi100+') { cmake_build -> stage('MLIR Debug') { - withEnv(['MIGRAPHX_ENABLE_EXTRA_MLIR=1', 'MIGRAPHX_MLIR_USE_SPECIFIC_OPS=fused,attention,convolution,dot,convolution_backwards', 'MIGRAPHX_ENABLE_MLIR_INPUT_FUSION=1', 'MIGRAPHX_MLIR_ENABLE_SPLITK=1', 'MIGRAPHX_ENABLE_MLIR_REDUCE_FUSION=1', 'MIGRAPHX_ENABLE_SPLIT_REDUCE=1','MIGRAPHX_DISABLE_LAYERNORM_FUSION=1']) { + withEnv(['MIGRAPHX_ENABLE_EXTRA_MLIR=1', 'MIGRAPHX_MLIR_USE_SPECIFIC_OPS=fused,attention,convolution,dot,convolution_backwards', 'MIGRAPHX_ENABLE_MLIR_INPUT_FUSION=1', 'MIGRAPHX_MLIR_ENABLE_SPLITK=1', 'MIGRAPHX_ENABLE_MLIR_REDUCE_FUSION=1', 'MIGRAPHX_ENABLE_MLIR_GEG_FUSION=1', 'MIGRAPHX_ENABLE_SPLIT_REDUCE=1','MIGRAPHX_DISABLE_LAYERNORM_FUSION=1']) { def sanitizers = "undefined" // Note: the -fno-sanitize= is copied from upstream LLVM_UBSAN_FLAGS. def debug_flags = "-g -O2 -fsanitize=${sanitizers} -fno-sanitize=vptr,function -fno-sanitize-recover=${sanitizers}" diff --git a/docs/reference/MIGraphX-dev-env-vars.rst b/docs/reference/MIGraphX-dev-env-vars.rst index 266ec1fbe92..f9c9f309ab3 100644 --- a/docs/reference/MIGraphX-dev-env-vars.rst +++ b/docs/reference/MIGraphX-dev-env-vars.rst @@ -144,6 +144,14 @@ Model performance tunable variables change the compilation behavior of a model. | Default: Reduction fusions are turned off. + * - | ``MIGRAPHX_ENABLE_MLIR_GEG_FUSION`` + | Turns on GEMM+GEMM fusions in MLIR. + + - | ``1``: Turns on G+G fusions. + | ``0``: Returns to default behavior. + + | Default: GEMM+GEMM fusions are turned off. + * - | ``MIGRAPHX_MLIR_ENABLE_SPLITK`` | Turns on Split-k performance configurations during MLIR tuning. diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index ccf67de0d85..ec109d98b6d 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -339,6 +339,7 @@ struct MIGRAPHX_EXPORT module ins_dep_map calc_implicit_deps() const; void repeat_while_changes(std::size_t n, const std::function& f); + void localized_sort(instruction_ref start_ins, instruction_ref end_ins); MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const module& m); MIGRAPHX_EXPORT friend bool operator==(const module& x, const module& y); diff --git a/src/module.cpp b/src/module.cpp index 50479933940..0b94a36c5a0 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -1600,6 +1600,31 @@ void module::repeat_while_changes(std::size_t n, const std::function& f) } } +// For topologically sorting a region in a module, canonically, such that the +// dependent chain between the two input instructions is last +void module::localized_sort(instruction_ref start_ins, instruction_ref end_ins) +{ + // get the chain of instructions between start_ins and end_ins, inclusive + auto fusion_ins = find_instructions_between(start_ins, end_ins, this); + + // move all instructions between start_ins & end_ins that are not in the fusion chain + // to the start_ins. 
In order, moving to the same destination, this will naturally preserve + // the preexisting topological order of the module + for(auto it = std::next(start_ins); it != end_ins;) + { + if(fusion_ins.count(it) == 0) + { + auto next = std::next(it); // move_instruction updates the iterator + this->move_instruction(it, start_ins); + it = next; + } + else + { + ++it; + } + } +} + bool operator==(const module& x, const module& y) { return to_string(x) == to_string(y); } std::ostream& operator<<(std::ostream& os, const module& m) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index b036fb4cb0e..83d30c50bba 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -48,6 +48,7 @@ namespace gpu { MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_EXTRA_MLIR); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_REDUCE_FUSION); +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_GEG_FUSION); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MLIR); /** * @brief Declares a new MIGraphX environment variable which forces to generate @@ -779,6 +780,152 @@ struct find_mlir_fused_ops } }; +/** + * Fuses rocMLIR conv/dot -> pointwise -> dot chain + * into a mlir_op with submodule. + */ +struct find_mlir_fused_geg_ops +{ + mlir_mode conv_mode = mlir_mode::none; + mlir_mode dot_mode = mlir_mode::none; + + /* + * Matches: + * mlir_dot_or_conv -> + * pointwise -> + * dot + */ + auto matcher() const + { + auto first_dot_or_conv = match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode)) + .bind("first_gemm_based_op"); + auto elemwise = + mlir_pointwise()(match::any_of[match::inputs()](first_dot_or_conv)).bind("elemwise"); + return is_mlir_dot(dot_mode)(match::any_of[match::inputs()](elemwise)) + .bind("second_gemm_op"); + } + + void apply(module_pass_manager& mpm, const match::matcher_result& r) const + { + auto second_gemm_ins = r.result; + auto elemwise_ins = r.instructions["elemwise"]; + auto first_gemm_ins = r.instructions["first_gemm_based_op"]; + + auto* elemwise_module = elemwise_ins->module_inputs().front(); + auto elemwise_inputs = elemwise_ins->inputs(); + + // only one input to elemwise should depend on first_gemm + if(std::any_of(elemwise_inputs.begin(), elemwise_inputs.end(), [&](const auto& i) { + return i != first_gemm_ins and reaches(first_gemm_ins, i); + })) + return; + + // only one input to second_gemm should depend on elemwise + auto second_gemm_inputs = second_gemm_ins->inputs(); + if(std::any_of(second_gemm_inputs.begin(), second_gemm_inputs.end(), [&](const auto& i) { + return i != elemwise_ins and reaches(elemwise_ins, i); + })) + return; + + std::unordered_map map_ins; + module_ref mm = + mpm.create_module("mlir_" + elemwise_ins->module_inputs().front()->name() + "_geg"); + mm->set_bypass(); + fuse_input_ops(mm, first_gemm_ins->inputs(), &map_ins); + + // need to track multi-user scenarios for both intermediates + bool first_gemm_has_multi_outs = first_gemm_ins->outputs().size() > 1; + bool elemwise_has_multi_outs = elemwise_ins->outputs().size() > 1; + + // add the first gemm to the module + std::vector first_gemm_mapped_inputs; + first_gemm_mapped_inputs.reserve(first_gemm_ins->inputs().size()); + std::transform(first_gemm_ins->inputs().begin(), + first_gemm_ins->inputs().end(), + std::back_inserter(first_gemm_mapped_inputs), + [&](auto input) { return map_ins.at(input); }); + auto first_gemm_in_module = + mm->add_instruction(first_gemm_ins->get_operator(), first_gemm_mapped_inputs); + 
map_ins[first_gemm_ins] = first_gemm_in_module; + + // fuse external inputs for the elemwise operation + fuse_input_ops(mm, elemwise_inputs, &map_ins); + + // fuse elemwise submodule + auto elemwise_rins = + mm->fuse(*elemwise_module, elemwise_inputs, &map_ins, &insert_pointwise); + assert(elemwise_rins.size() == 1); + map_ins[elemwise_ins] = elemwise_rins.front(); + + // fuse external inputs for the second gemm + fuse_input_ops(mm, second_gemm_inputs, &map_ins); + + // add the second gemm to the new module + std::vector second_gemm_mapped_inputs; + second_gemm_mapped_inputs.reserve(second_gemm_inputs.size()); + std::transform(second_gemm_inputs.begin(), + second_gemm_inputs.end(), + std::back_inserter(second_gemm_mapped_inputs), + [&](auto input) { return map_ins.at(input); }); + auto second_gemm_in_module = + mm->add_instruction(second_gemm_ins->get_operator(), second_gemm_mapped_inputs); + map_ins[second_gemm_ins] = second_gemm_in_module; + + // primary output is the last gemm, which should be the first output + std::vector return_vals; + return_vals.push_back(second_gemm_in_module); + + if(elemwise_has_multi_outs) + { + return_vals.push_back(map_ins[elemwise_ins]); + } + if(first_gemm_has_multi_outs) + { + return_vals.push_back(map_ins[first_gemm_ins]); + } + mm->add_return(return_vals); + auto inputs = find_inputs(map_ins, &mpm.get_module(), mm); + + // sort fusion section of module such that any external inputs are moved before the fusion + // so that we can safely place the fused mod in the multi-out case at the beginning of the + // chain + mpm.get_module().localized_sort(first_gemm_ins, second_gemm_ins); + + auto fused_ins = + mpm.get_module().insert_instruction(first_gemm_ins, + mlir_op{second_gemm_ins->get_operator()}, + mlir_contiguous(mpm, inputs), + {mm}); + + if(first_gemm_has_multi_outs or elemwise_has_multi_outs) + { + std::size_t output_idx = 0; + if(elemwise_has_multi_outs) + { + auto elemwise_result = mpm.get_module().insert_instruction( + first_gemm_ins, + migraphx::make_op("get_tuple_elem", {{"index", ++output_idx}}), + fused_ins); + mpm.get_module().replace_instruction(elemwise_ins, elemwise_result); + } + if(first_gemm_has_multi_outs) + { + mpm.get_module().replace_instruction( + first_gemm_ins, + migraphx::make_op("get_tuple_elem", {{"index", ++output_idx}}), + fused_ins); + } + mpm.get_module().replace_instruction( + second_gemm_ins, migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused_ins); + } + else + { + // simple single output case + mpm.get_module().replace_instruction(second_gemm_ins, fused_ins); + } + } +}; + template struct find_mlir_standalone_op { @@ -1300,6 +1447,15 @@ void fuse_mlir::apply(module_pass_manager& mpm) const match::find_matches(mpm, find_mlir_attention_op{}); mpm.run_pass(dead_code_elimination{}); + if(enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + { + match::find_matches( + mpm, + find_mlir_fused_geg_ops{.conv_mode = get_mode("fused_convolution", mlir_mode::fast), + .dot_mode = get_mode("fused_dot", mlir_mode::fast)}); + mpm.run_pass(dead_code_elimination{}); + } + match::find_matches( mpm, find_mlir_fused_ops{.conv_mode = get_mode("fused_convolution", mlir_mode::fast), diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index 57c401b8359..5532da8832b 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -40,6 +40,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_REDUCE_FUSION); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_USE_SPECIFIC_OPS); 
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_GEG_FUSION); struct non_mlir_op { @@ -1773,6 +1774,847 @@ TEST_CASE(unpack_fp4_dot_odd) EXPECT(p1.sort() == p2.sort()); } +TEST_CASE(dot_add_dot) +{ + migraphx::shape s1{migraphx::shape::float_type, {2, 3}}; + migraphx::shape s2{migraphx::shape::float_type, {3, 4}}; + migraphx::shape s3{migraphx::shape::float_type, {2, 4}}; + migraphx::shape s4{migraphx::shape::float_type, {4, 2}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s1); + auto b = mm->add_parameter("b", s2); + auto x = mm->add_parameter("x", s3); + auto y = mm->add_parameter("y", s4); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); // {2,4} + auto add = + add_pointwise(p1, "main:pointwise0", {dot1, x}, single_pointwise("add")); // {2,4} + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, y); // {4,2} + mm->add_return({dot2}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s1); + auto b = mm->add_parameter("b", s2); + auto x = mm->add_parameter("x", s3); + auto y = mm->add_parameter("y", s4); + auto fused = add_mlir( + p2, "mlir_main:pointwise0_geg", {a, b, x, y}, [=](auto* pm, const auto& inputs) { + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[2]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, inputs[3]); + return std::make_tuple(dot2->get_operator(), dot2); + }); + mm->add_return({fused}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_add_dot_square) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = add_pointwise(p1, "main:pointwise0", {dot1, x}, single_pointwise("add")); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, y); + mm->add_return({dot2}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto fused = add_mlir( + p2, "mlir_main:pointwise0_geg", {a, b, x, y}, [=](auto* pm, const auto& inputs) { + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[2]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, inputs[3]); + return std::make_tuple(dot2->get_operator(), dot2); + }); + mm->add_return({fused}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_mul_dot) +{ + migraphx::shape s1{migraphx::shape::float_type, {3, 3}}; + migraphx::shape s2{migraphx::shape::float_type, {3, 4}}; + migraphx::shape s3{migraphx::shape::float_type, {4, 5}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s1); + auto b = mm->add_parameter("b", s2); + auto x = mm->add_parameter("x", s2); + auto y = mm->add_parameter("y", s3); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + 
auto mul = add_pointwise(p1, "main:pointwise0", {dot1, x}, single_pointwise("mul")); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), mul, y); + mm->add_return({dot2}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s1); + auto b = mm->add_parameter("b", s2); + auto x = mm->add_parameter("x", s2); + auto y = mm->add_parameter("y", s3); + auto fused = add_mlir( + p2, "mlir_main:pointwise0_geg", {a, b, x, y}, [=](auto* pm, const auto& inputs) { + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto mul = pm->add_instruction(migraphx::make_op("mul"), dot1, inputs[2]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), mul, inputs[3]); + return std::make_tuple(dot2->get_operator(), dot2); + }); + mm->add_return({fused}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(conv_add) +{ + migraphx::shape is{migraphx::shape::float_type, {4, 14, 122, 122}}; + migraphx::shape ys{migraphx::shape::float_type, {4, 56, 122, 122}}; + migraphx::shape ws{migraphx::shape::float_type, {56, 14, 1, 1}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", is); + auto y = mm->add_parameter("y", ys); + auto w = mm->add_parameter("w", ws); + auto conv = mm->add_instruction(migraphx::make_op("convolution"), x, w); + auto add = add_pointwise(p1, "main:pointwise0", {conv, y}, single_pointwise("add")); + mm->add_return({add}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", is); + auto y = mm->add_parameter("y", ys); + auto w = mm->add_parameter("w", ws); + auto fused = add_mlir(p2, + "mlir_main:pointwise0", + {x, w, y}, + {"x0", "x1", "x2"}, + [=](auto* pm, const auto& inputs) { + auto c = pm->add_instruction( + migraphx::make_op("convolution"), inputs[0], inputs[1]); + auto add = + pm->add_instruction(migraphx::make_op("add"), c, inputs[2]); + return std::make_tuple(c->get_operator(), add); + }); + mm->add_return({fused}); + } + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(conv_add_dot) +{ + migraphx::shape is{migraphx::shape::float_type, {2, 4, 8, 8}}; + migraphx::shape ys{migraphx::shape::float_type, {2, 8, 8, 8}}; + migraphx::shape ws{migraphx::shape::float_type, {8, 4, 1, 1}}; + migraphx::shape zs{migraphx::shape::float_type, {2, 8, 8, 4}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", is); + auto y = mm->add_parameter("y", ys); + auto w = mm->add_parameter("w", ws); + auto z = mm->add_parameter("z", zs); + auto conv = mm->add_instruction(migraphx::make_op("convolution"), x, w); + auto add = add_pointwise(p1, "main:pointwise0", {conv, y}, single_pointwise("add")); + auto dot = mm->add_instruction(migraphx::make_op("dot"), add, z); + + mm->add_return({dot}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", is); + auto y = mm->add_parameter("y", ys); + auto w = mm->add_parameter("w", ws); + auto z = mm->add_parameter("z", zs); + auto fused = + add_mlir(p2, + "mlir_main:pointwise0_geg", + {x, w, y, z}, + {"x0", "x1", "x2", "x3"}, + [=](auto* pm, const auto& inputs) { + auto c = pm->add_instruction( + migraphx::make_op("convolution"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), c, inputs[2]); + auto dot = pm->add_instruction(migraphx::make_op("dot"), 
add, inputs[3]); + return std::make_tuple(dot->get_operator(), dot); + }); + mm->add_return({fused}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_multi_user_add) +// G ->optional R -> E fusion +// G has two users, one external to fusion +{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = add_pointwise(p1, "main:pointwise0", {dot1, c}, single_pointwise("add")); + auto transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot1); + mm->add_return({add, transpose}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto fused = + add_mlir(p2, "mlir_main:pointwise0", {a, b, c}, [=](auto* pm, const auto& inputs) { + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[2]); + return std::make_tuple(dot1->get_operator(), + std::vector{add, dot1}); + }); + auto get_add = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto get_dot = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto transpose = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot); + mm->add_return({get_add, transpose}); + } + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_add_multi_user_dot) +// GEG fusion has two outputs, E has external user +{ + migraphx::shape s1{migraphx::shape::float_type, {3, 3}}; + migraphx::shape s2{migraphx::shape::float_type, {3, 5}}; + migraphx::shape s3{migraphx::shape::float_type, {5, 2}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s1); + auto b = mm->add_parameter("b", s2); + auto c = mm->add_parameter("c", s2); + auto d = mm->add_parameter("d", s3); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = add_pointwise(p1, "main:pointwise0", {dot1, c}, single_pointwise("add")); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, d); + auto transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), dot2); + mm->add_return({add, transpose}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s1); + auto b = mm->add_parameter("b", s2); + auto c = mm->add_parameter("c", s2); + auto d = mm->add_parameter("d", s3); + auto fused = add_mlir( + p2, "mlir_main:pointwise0_geg", {a, b, c, d}, [=](auto* pm, const auto& inputs) { + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[2]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, inputs[3]); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, add}); + }); + auto get_dot2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto get_add = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto transpose = 
mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {1, 0}}}), get_dot2); + mm->add_return({get_add, transpose}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_add_multi_user_dot_with_transpose) +// GEG fusion has two outputs, E has external user +{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = add_pointwise(p1, "main:pointwise0", {dot1, c}, single_pointwise("add")); + auto d_t = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), d); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, d_t); + auto transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot2); + mm->add_return({add, transpose}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto fused = add_mlir( + p2, "mlir_main:pointwise0_geg", {a, b, c, d}, [=](auto* pm, const auto& inputs) { + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[2]); + auto d_t = pm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), inputs[3]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, d_t); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, add}); + }); + auto get_dot2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto get_add = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto transpose = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); + mm->add_return({get_add, transpose}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_add_multi_user_dot_two_externals) +// GEG fusion has two outputs, E has external user +{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = add_pointwise(p1, "main:pointwise0", {dot1, c}, single_pointwise("add")); + auto external_t1 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), d); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, d); + auto external_t2 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot2); + mm->add_return({add, external_t1, external_t2}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto fused = add_mlir( + p2, "mlir_main:pointwise0_geg", {a, b, c, d}, [=](auto* pm, const auto& 
inputs) { + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[2]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, inputs[3]); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, add}); + }); + auto get_dot2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto get_add = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto external_t1 = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), d); + auto external_t2 = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); + mm->add_return({get_add, external_t1, external_t2}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_add_multi_user_dot_input_used_before) +// GEG fusion has two outputs, E has external user. +// Base case for testing inputs being defined within the span +// of will-be-fused ops +// This also shows the relu being fused, since it is a unary op +{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + + auto external_relu = add_pointwise(p1, "main:pointwise1", {d}, single_pointwise("relu")); + + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = add_pointwise(p1, "main:pointwise0", {dot1, c}, single_pointwise("add")); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, external_relu); + auto transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot2); + mm->add_return({add, transpose}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto fused = add_mlir( + p2, + "main:pointwise1:mlir_main:pointwise0_geg", + {d, a, b, c}, + [=](auto* pm, const auto& inputs) { + auto external_relu = pm->add_instruction(migraphx::make_op("relu"), inputs[0]); + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[1], inputs[2]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[3]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, external_relu); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, add}); + }); + auto get_dot2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto get_add = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto external_t = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); + mm->add_return({get_add, external_t}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_add_multi_user_dot_input_used_after) +// GEG fusion has two outputs, E has external user +// Testing inputs being defined within the span of will-be-fused ops +// This also shows the relu being fused, since it is a unary op. 
+// Result should be, and is, equivalent to the previous test +{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = add_pointwise(p1, "main:pointwise0", {dot1, c}, single_pointwise("add")); + auto external_relu = add_pointwise(p1, "main:pointwise1", {d}, single_pointwise("relu")); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, external_relu); + auto transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot2); + mm->add_return({add, transpose}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto fused = add_mlir( + p2, + "main:pointwise1:mlir_main:pointwise0_geg", + {d, a, b, c}, + [=](auto* pm, const auto& inputs) { + auto external_relu = pm->add_instruction(migraphx::make_op("relu"), inputs[0]); + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[1], inputs[2]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[3]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, external_relu); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, add}); + }); + auto get_dot2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto get_add = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto external_t = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); + mm->add_return({get_add, external_t}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_add_multi_user_dot_input_used_before_in_chain) +// GEG fusion has two outputs, E has external user +// Base case for inputs being defined within the span of will-be-fused ops, including +// longer chain of logic, for both cases of input fusion. When enabled, +// the mul gets fused into the GEG fusion. 
+{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + + auto external_relu = add_pointwise(p1, "main:pointwise1", {d}, single_pointwise("relu")); + auto external_mul = + add_pointwise(p1, "main:pointwise2", {external_relu, d}, single_pointwise("mul")); + + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = add_pointwise(p1, "main:pointwise0", {dot1, c}, single_pointwise("add")); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, external_mul); + auto transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot2); + mm->add_return({add, transpose}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto external_relu = add_pointwise(p2, "main:pointwise1", {d}, single_pointwise("relu")); + auto external_mul = + add_pointwise(p2, "main:pointwise2", {external_relu, d}, single_pointwise("mul")); + auto fused = + add_mlir(p2, + "mlir_main:pointwise0_geg", + {a, b, c, external_mul}, + [=](auto* pm, const auto& inputs) { + auto dot1 = + pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[2]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, inputs[3]); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, add}); + }); + auto get_dot2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto get_add = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto external_t = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); + mm->add_return({get_add, external_t}); + } + migraphx::program p3; + { + auto* mm = p3.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto external_relu = add_pointwise(p3, "main:pointwise1", {d}, single_pointwise("relu")); + auto fused = add_mlir( + p3, + "main:pointwise2:mlir_main:pointwise0_geg", + {external_relu, d, a, b, c}, + [=](auto* pm, const auto& inputs) { + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[2], inputs[3]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[4]); + auto mul = pm->add_instruction(migraphx::make_op("mul"), inputs[0], inputs[1]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, mul); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, add}); + }); + auto get_dot2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto get_add = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto external_t = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); + mm->add_return({get_add, external_t}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + if(migraphx::enabled(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION{})) + EXPECT(p1.sort() == p3.sort()); + else + EXPECT(p1.sort() == p2.sort()); +} + 
+TEST_CASE(dot_add_multi_user_dot_input_used_after_in_chain) +// GEG fusion has two outputs, E has external user +// Testing inputs being defined within the span of will-be-fused ops, including +// longer chain of logic +{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = add_pointwise(p1, "main:pointwise0", {dot1, c}, single_pointwise("add")); + auto external_relu = add_pointwise(p1, "main:pointwise1", {d}, single_pointwise("relu")); + auto external_mul = + add_pointwise(p1, "main:pointwise2", {external_relu, d}, single_pointwise("mul")); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, external_mul); + auto transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot2); + mm->add_return({add, transpose}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto external_relu = add_pointwise(p2, "main:pointwise1", {d}, single_pointwise("relu")); + auto external_mul = + add_pointwise(p2, "main:pointwise2", {external_relu, d}, single_pointwise("mul")); + auto fused = + add_mlir(p2, + "mlir_main:pointwise0_geg", + {a, b, c, external_mul}, + [=](auto* pm, const auto& inputs) { + auto dot1 = + pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[2]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, inputs[3]); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, add}); + }); + auto get_dot2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto get_add = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto external_t = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); + mm->add_return({get_add, external_t}); + } + migraphx::program p3; + { + auto* mm = p3.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto external_relu = add_pointwise(p3, "main:pointwise1", {d}, single_pointwise("relu")); + auto fused = add_mlir( + p3, + "main:pointwise2:mlir_main:pointwise0_geg", + {external_relu, d, a, b, c}, + [=](auto* pm, const auto& inputs) { + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[2], inputs[3]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[4]); + auto mul = pm->add_instruction(migraphx::make_op("mul"), inputs[0], inputs[1]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, mul); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, add}); + }); + auto get_dot2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto get_add = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto external_t = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); + mm->add_return({get_add, external_t}); + } + if(not 
migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + if(migraphx::enabled(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION{})) + EXPECT(p1.sort() == p3.sort()); + else + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_pw_multi_user_dot) +// GEG fusion has two outputs, E has external user, E is multiple elemwise ops +{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto e = mm->add_parameter("e", s); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto elemwise = + add_pointwise(p1, "main:pointwise0", {dot1, c, d}, [=](auto* pm, const auto& inputs) { + auto add = + pm->add_instruction(migraphx::make_op("add"), inputs.at(0), inputs.at(1)); + return pm->add_instruction(migraphx::make_op("mul"), add, inputs.at(2)); + }); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), elemwise, e); + auto transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot2); + mm->add_return({elemwise, transpose}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto e = mm->add_parameter("e", s); + auto fused = add_mlir( + p2, "mlir_main:pointwise0_geg", {a, b, c, d, e}, [=](auto* pm, const auto& inputs) { + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[2]); + auto mul = pm->add_instruction(migraphx::make_op("mul"), add, inputs[3]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), mul, inputs[4]); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, mul}); + }); + auto get_dot2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto get_mul = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto transpose = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot2); + mm->add_return({get_mul, transpose}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_multi_user_add_dot) +// GEG fusion has two outputs (first G has external user) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = add_pointwise(p1, "main:pointwise0", {dot1, c}, single_pointwise("add")); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, d); + auto transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot1); + mm->add_return({dot2, transpose}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto fused = add_mlir( + p2, "mlir_main:pointwise0_geg", {a, b, c, d}, [=](auto* pm, const auto& inputs) { + auto dot1 
= pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[2]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, inputs[3]); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, dot1}); + }); + auto get_dot1 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto get_dot2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto transpose = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot1); + mm->add_return({get_dot2, transpose}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_add_dot_both_multi_user) +// GEG fusion has three outputs (first G has external user, E has external user) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = add_pointwise(p1, "main:pointwise0", {dot1, c}, single_pointwise("add")); + auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, d); + auto transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot1); + mm->add_return({add, dot2, transpose}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto c = mm->add_parameter("c", s); + auto d = mm->add_parameter("d", s); + auto fused = add_mlir( + p2, "mlir_main:pointwise0_geg", {a, b, c, d}, [=](auto* pm, const auto& inputs) { + auto dot1 = pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto add = pm->add_instruction(migraphx::make_op("add"), dot1, inputs[2]); + auto dot2 = pm->add_instruction(migraphx::make_op("dot"), add, inputs[3]); + return std::make_tuple(dot1->get_operator(), + std::vector{dot2, add, dot1}); + }); + auto get_dot1 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), fused); + auto get_elemwise = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto get_dot2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto transpose = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), get_dot1); + mm->add_return({get_elemwise, get_dot2, transpose}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + int main(int argc, const char* argv[]) { if(migraphx::gpu::mlir_enabled()) diff --git a/test/module_test.cpp b/test/module_test.cpp index ac3079033de..270bfcd63e4 100644 --- a/test/module_test.cpp +++ b/test/module_test.cpp @@ -1369,4 +1369,50 @@ TEST_CASE(pathological_dfs_graph_sort) EXPECT(is_sorted(m)); } +TEST_CASE(localized_sort) +{ + migraphx::shape s1{migraphx::shape::float_type, {2, 3}}; + migraphx::shape s2{migraphx::shape::float_type, {3, 4}}; + migraphx::shape s3{migraphx::shape::float_type, {2, 4}}; + migraphx::shape s4{migraphx::shape::float_type, {4, 2}}; + migraphx::program p; + auto* mm = p.get_main_module(); + + auto a = mm->add_parameter("a", s1); + auto b = mm->add_parameter("b", s2); + auto c = 
mm->add_parameter("c", s3);
+    auto d = mm->add_parameter("d", s4);
+
+    auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b); // {2, 4}
+    auto external_relu =
+        add_pointwise(p, "main:pointwise1", {d}, single_pointwise("relu")); // {4, 2}
+    auto external_mul =
+        add_pointwise(p, "main:pointwise2", {external_relu, d}, single_pointwise("mul")); // {4, 2}
+    auto add  = add_pointwise(p, "main:pointwise0", {dot1, c}, single_pointwise("add")); // {2, 4}
+    auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, external_mul); // {2, 2}
+    auto transpose =
+        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), dot2);
+    mm->add_return({add, transpose});
+
+    // Perform localized sort between dot1 and dot2
+    mm->localized_sort(dot1, dot2);
+
+    // Verify the module is still topologically sorted overall
+    EXPECT(is_sorted(*mm));
+
+    // Verify external operations moved before the fusion chain
+    EXPECT(std::distance(mm->begin(), external_relu) < std::distance(mm->begin(), dot1));
+    EXPECT(std::distance(mm->begin(), external_mul) < std::distance(mm->begin(), dot1));
+
+    // Verify the fusion chain ordering is preserved: dot1 < add < dot2
+    EXPECT(std::distance(mm->begin(), dot1) < std::distance(mm->begin(), add));
+    EXPECT(std::distance(mm->begin(), add) < std::distance(mm->begin(), dot2));
+
+    // Verify external_mul is before dot1 (since dot2 depends on external_mul)
+    EXPECT(std::distance(mm->begin(), external_mul) < std::distance(mm->begin(), dot1));
+
+    // Verify transpose remains after dot2
+    EXPECT(std::distance(mm->begin(), dot2) < std::distance(mm->begin(), transpose));
+}
+
 int main(int argc, const char* argv[]) { test::run(argc, argv); }
diff --git a/test/verify/CMakeLists.txt b/test/verify/CMakeLists.txt
index bd57abea883..56b0a2cf2ba 100644
--- a/test/verify/CMakeLists.txt
+++ b/test/verify/CMakeLists.txt
@@ -1,7 +1,7 @@
 #####################################################################################
 # The MIT License (MIT)
 #
-# Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
 #
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to deal
@@ -42,3 +42,7 @@ foreach(SECTION general reduce rnn conv gemm)
     )
   endif()
 endforeach()
+
+# TODO: remove this when MIGraphX disables attn-like offloading to rocMLIR
+# f32 not supported on navi3x
+set_tests_properties(test_verify_rnn PROPERTIES ENVIRONMENT "MIGRAPHX_ENABLE_MLIR_GEG_FUSION=0")
diff --git a/test/verify/test_conv_add_dot.cpp b/test/verify/test_conv_add_dot.cpp
new file mode 100644
index 00000000000..a4ba2018f7f
--- /dev/null
+++ b/test/verify/test_conv_add_dot.cpp
@@ -0,0 +1,65 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+#include "verify_program.hpp"
+#include
+#include
+#include
+#include
+
+template <migraphx::shape::type_t DType>
+struct test_conv_add_dot : verify_program<test_conv_add_dot<DType>>
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        auto* mm     = p.get_main_module();
+        auto input   = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}});
+        auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}});
+        auto bias_literal =
+            migraphx::literal{migraphx::shape{DType, {4}}, {2.0f, 2.0f, 2.0f, 2.0f}};
+        auto bias = mm->add_literal(bias_literal);
+        auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
+        auto bcast_bias = mm->add_instruction(
+            migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", conv->get_shape().lens()}}),
+            bias);
+        auto bias_add = mm->add_instruction(migraphx::make_op("add"), conv, bcast_bias);
+
+        // Create a literal for the dot (matmul) rhs with the same shape as bias_add
+        std::vector<std::size_t> bias_add_lens = bias_add->get_shape().lens();
+        // bias_add has shape {4, 4, 1, 1}, so the rhs uses the same lens
+        migraphx::shape dot_rhs_shape{DType, bias_add_lens};
+        std::vector<float> dot_rhs_data(dot_rhs_shape.elements(), 1.0f);
+        auto dot_rhs = mm->add_literal(migraphx::literal{dot_rhs_shape, dot_rhs_data});
+
+        // Dot on matching {4, 4, 1, 1} shapes: batched 1x1 matmuls, effectively elementwise
+        auto dot = mm->add_instruction(migraphx::make_op("dot"), bias_add, dot_rhs);
+        mm->add_return({dot});
+        return p;
+    }
+    std::string section() const { return "conv"; }
+};
+
+template struct test_conv_add_dot;
+template struct test_conv_add_dot;
diff --git a/test/verify/test_dot_add_dot.cpp b/test/verify/test_dot_add_dot.cpp
new file mode 100644
index 00000000000..a087e60e018
--- /dev/null
+++ b/test/verify/test_dot_add_dot.cpp
@@ -0,0 +1,51 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+#include "verify_program.hpp"
+#include
+#include
+#include
+
+template <migraphx::shape::type_t DType>
+struct test_dot_add_dot : verify_program<test_dot_add_dot<DType>>
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        auto* mm = p.get_main_module();
+        migraphx::shape s{DType, {256, 256}};
+        auto a    = mm->add_parameter("a", s);
+        auto b    = mm->add_parameter("b", s);
+        auto c    = mm->add_parameter("c", s);
+        auto d    = mm->add_parameter("d", s);
+        auto dot1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
+        auto add  = mm->add_instruction(migraphx::make_op("add"), dot1, c);
+        auto dot2 = mm->add_instruction(migraphx::make_op("dot"), add, d);
+        mm->add_return({dot2});
+        return p;
+    }
+};
+
+template struct test_dot_add_dot;
+template struct test_dot_add_dot;