Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
367f9e2
single output gemm elem gemm fusion w mlir
bdevorem Aug 26, 2025
5e4bd88
dead code elim after pass
bdevorem Aug 26, 2025
b6c9b53
make modular; add tag feature to mlir_ops
bdevorem Aug 29, 2025
917992b
working on multi output
bdevorem Aug 29, 2025
6b4b672
handle multi out case: GE single return val, multi user
bdevorem Aug 29, 2025
3211837
fix test case
bdevorem Aug 30, 2025
5f1613d
multi return from GE fusion supported in GEG fusion
bdevorem Aug 30, 2025
48bea0c
clang format
bdevorem Aug 30, 2025
84b8953
env var, more formatting, comments, remove change to requirements
bdevorem Aug 30, 2025
2d5b25d
Revert "multi return from GE fusion supported in GEG fusion"
bdevorem Sep 10, 2025
5010310
remove tag, handle multi output for distinct intermediates (instead o…
bdevorem Sep 10, 2025
312bfb0
format/tidy/cppcheck warnings
bdevorem Sep 10, 2025
ac1e0a6
clang format again...
bdevorem Sep 10, 2025
ed84079
trace flag to jenkins, clang format again
bdevorem Sep 11, 2025
b6a40dc
remove dead code elim
bdevorem Sep 12, 2025
63cb9b6
Revert "trace flag to jenkins, clang format again"
bdevorem Sep 12, 2025
4c15561
remove unnecessary newlines
bdevorem Sep 12, 2025
4359853
this is gonna be a lot of output but print the mod before the assertion
bdevorem Sep 12, 2025
2c81edd
handle bug where when replacing intermediates with the get_tuple_elem…
bdevorem Sep 12, 2025
7a3f2a4
start inserting instructions before first gemm in the fusion in order…
bdevorem Sep 12, 2025
68b9194
clang format kmn
bdevorem Sep 12, 2025
88bac90
disable geg fusion for rnn tests
bdevorem Sep 22, 2025
42e786c
verify tests
bdevorem Sep 22, 2025
b25b3c4
fix merge conflict
bdevorem Sep 22, 2025
fe0c308
add comment
bdevorem Sep 22, 2025
316cd30
rename test to match contents; clang format
bdevorem Sep 22, 2025
11ca8ee
formatting, fix winky dims in conv test
bdevorem Sep 23, 2025
47604b5
more formatting...
bdevorem Sep 23, 2025
510da97
handle the case where intermediate external inputs are not placed abo…
bdevorem Sep 24, 2025
fade73a
add conv test to conv group
bdevorem Sep 24, 2025
3257bb2
review comments; adds sort for localized sections of modules
bdevorem Sep 30, 2025
a2bc265
clang-format
bdevorem Sep 30, 2025
7a22ab3
make some tests rectangular
bdevorem Sep 30, 2025
6b5d7cb
clang-format
bdevorem Sep 30, 2025
047d7ac
edit tests to handle both if input fusion enabled or disabled
bdevorem Sep 30, 2025
699e02b
clang-tidy
bdevorem Sep 30, 2025
35005c3
test for localized module sort
bdevorem Oct 1, 2025
f43e0a1
Merge branch 'develop' into bdevorem/fuse-geg
causten Oct 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ rocmtest clang_debug: rocmnode('mi200+') { cmake_build ->
}
}, mlir_debug: rocmnode('mi100+') { cmake_build ->
stage('MLIR Debug') {
withEnv(['MIGRAPHX_ENABLE_EXTRA_MLIR=1', 'MIGRAPHX_MLIR_USE_SPECIFIC_OPS=fused,attention,convolution,dot,convolution_backwards', 'MIGRAPHX_ENABLE_MLIR_INPUT_FUSION=1', 'MIGRAPHX_MLIR_ENABLE_SPLITK=1', 'MIGRAPHX_ENABLE_MLIR_REDUCE_FUSION=1', 'MIGRAPHX_ENABLE_SPLIT_REDUCE=1','MIGRAPHX_DISABLE_LAYERNORM_FUSION=1']) {
withEnv(['MIGRAPHX_ENABLE_EXTRA_MLIR=1', 'MIGRAPHX_MLIR_USE_SPECIFIC_OPS=fused,attention,convolution,dot,convolution_backwards', 'MIGRAPHX_ENABLE_MLIR_INPUT_FUSION=1', 'MIGRAPHX_MLIR_ENABLE_SPLITK=1', 'MIGRAPHX_ENABLE_MLIR_REDUCE_FUSION=1', 'MIGRAPHX_ENABLE_MLIR_GEG_FUSION=1', 'MIGRAPHX_ENABLE_SPLIT_REDUCE=1','MIGRAPHX_DISABLE_LAYERNORM_FUSION=1']) {
def sanitizers = "undefined"
// Note: the -fno-sanitize= is copied from upstream LLVM_UBSAN_FLAGS.
def debug_flags = "-g -O2 -fsanitize=${sanitizers} -fno-sanitize=vptr,function -fno-sanitize-recover=${sanitizers}"
Expand Down
8 changes: 8 additions & 0 deletions docs/reference/MIGraphX-dev-env-vars.rst
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,14 @@ Model performance tunable variables change the compilation behavior of a model.

| Default: Reduction fusions are turned off.

* - | ``MIGRAPHX_ENABLE_MLIR_GEG_FUSION``
| Turns on GEMM+GEMM fusions in MLIR.

- | ``1``: Turns on GEMM+GEMM fusions.
| ``0``: Returns to default behavior.

| Default: GEMM+GEMM fusions are turned off.

* - | ``MIGRAPHX_MLIR_ENABLE_SPLITK``
| Turns on Split-k performance configurations during MLIR tuning.

Expand Down
1 change: 1 addition & 0 deletions src/include/migraphx/module.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ struct MIGRAPHX_EXPORT module
ins_dep_map calc_implicit_deps() const;

void repeat_while_changes(std::size_t n, const std::function<void()>& f);
void localized_sort(instruction_ref start_ins, instruction_ref end_ins);

MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const module& m);
MIGRAPHX_EXPORT friend bool operator==(const module& x, const module& y);
Expand Down
25 changes: 25 additions & 0 deletions src/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1600,6 +1600,31 @@ void module::repeat_while_changes(std::size_t n, const std::function<void()>& f)
}
}

// Canonically topologically sorts a region of this module, delimited by the
// two given instructions, so that the dependency chain connecting start_ins to
// end_ins ends up contiguous at the end of the region.
void module::localized_sort(instruction_ref start_ins, instruction_ref end_ins)
{
    // All instructions on the dependency chain between start_ins and end_ins,
    // inclusive of both endpoints.
    const auto chain = find_instructions_between(start_ins, end_ins, this);

    // Sweep the region once; every instruction that is not on the chain is
    // moved up to start_ins. Because each hoisted instruction targets the same
    // destination in encounter order, the preexisting topological order of the
    // module is naturally preserved.
    auto it = std::next(start_ins);
    while(it != end_ins)
    {
        // Capture the successor first: move_instruction relocates `it`.
        const auto following = std::next(it);
        if(chain.count(it) == 0)
            this->move_instruction(it, start_ins);
        it = following;
    }
}

bool operator==(const module& x, const module& y) { return to_string(x) == to_string(y); }

std::ostream& operator<<(std::ostream& os, const module& m)
Expand Down
156 changes: 156 additions & 0 deletions src/targets/gpu/fuse_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_EXTRA_MLIR);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_REDUCE_FUSION);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_GEG_FUSION);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MLIR);
/**
* @brief Declares a new MIGraphX environment variable which forces to generate
Expand Down Expand Up @@ -779,6 +780,152 @@ struct find_mlir_fused_ops
}
};

/**
 * Fuses a rocMLIR conv/dot -> pointwise -> dot (GEMM+Elementwise+GEMM, "GEG")
 * chain into a single mlir_op carrying a fused submodule. Intermediate results
 * that have additional external users are preserved as extra tuple outputs of
 * the fused op.
 */
struct find_mlir_fused_geg_ops
{
// mlir_mode::none disables matching for that op kind; the pass sets these
// from the fused_convolution/fused_dot modes when GEG fusion is enabled.
mlir_mode conv_mode = mlir_mode::none;
mlir_mode dot_mode = mlir_mode::none;

/*
 * Matches:
 * mlir_dot_or_conv <binds to "first_gemm_based_op"> ->
 * pointwise <binds to "elemwise"> ->
 * dot <matcher result, binds to "second_gemm_op">
 */
auto matcher() const
{
auto first_dot_or_conv = match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode))
.bind("first_gemm_based_op");
auto elemwise =
mlir_pointwise()(match::any_of[match::inputs()](first_dot_or_conv)).bind("elemwise");
return is_mlir_dot(dot_mode)(match::any_of[match::inputs()](elemwise))
.bind("second_gemm_op");
}

// Builds the fused submodule (first gemm/conv, then the pointwise body, then
// the second gemm), wires up external inputs, and replaces the matched chain
// in the parent module with a single mlir_op. Bails out (no rewrite) if the
// chain is not the only dependency path between the matched instructions.
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto second_gemm_ins = r.result;
auto elemwise_ins = r.instructions["elemwise"];
auto first_gemm_ins = r.instructions["first_gemm_based_op"];

auto* elemwise_module = elemwise_ins->module_inputs().front();
auto elemwise_inputs = elemwise_ins->inputs();

// only one input to elemwise should depend on first_gemm; a second path
// would create a cycle once the chain is collapsed into one instruction
if(std::any_of(elemwise_inputs.begin(), elemwise_inputs.end(), [&](const auto& i) {
return i != first_gemm_ins and reaches(first_gemm_ins, i);
}))
return;

// only one input to second_gemm should depend on elemwise (same reasoning)
auto second_gemm_inputs = second_gemm_ins->inputs();
if(std::any_of(second_gemm_inputs.begin(), second_gemm_inputs.end(), [&](const auto& i) {
return i != elemwise_ins and reaches(elemwise_ins, i);
}))
return;

// map_ins maps parent-module instructions to their copies in the new submodule
std::unordered_map<instruction_ref, instruction_ref> map_ins;
module_ref mm =
mpm.create_module("mlir_" + elemwise_ins->module_inputs().front()->name() + "_geg");
mm->set_bypass();
fuse_input_ops(mm, first_gemm_ins->inputs(), &map_ins);

// need to track multi-user scenarios for both intermediates: any extra user
// outside the chain means that intermediate must also be returned from the
// fused module
bool first_gemm_has_multi_outs = first_gemm_ins->outputs().size() > 1;
bool elemwise_has_multi_outs = elemwise_ins->outputs().size() > 1;

// add the first gemm to the module, rewiring its inputs through map_ins
std::vector<instruction_ref> first_gemm_mapped_inputs;
first_gemm_mapped_inputs.reserve(first_gemm_ins->inputs().size());
std::transform(first_gemm_ins->inputs().begin(),
first_gemm_ins->inputs().end(),
std::back_inserter(first_gemm_mapped_inputs),
[&](auto input) { return map_ins.at(input); });
auto first_gemm_in_module =
mm->add_instruction(first_gemm_ins->get_operator(), first_gemm_mapped_inputs);
map_ins[first_gemm_ins] = first_gemm_in_module;

// fuse external inputs for the elemwise operation
fuse_input_ops(mm, elemwise_inputs, &map_ins);

// fuse elemwise submodule; its single return value stands in for elemwise_ins
auto elemwise_rins =
mm->fuse(*elemwise_module, elemwise_inputs, &map_ins, &insert_pointwise);
assert(elemwise_rins.size() == 1);
map_ins[elemwise_ins] = elemwise_rins.front();

// fuse external inputs for the second gemm
fuse_input_ops(mm, second_gemm_inputs, &map_ins);

// add the second gemm to the new module
std::vector<instruction_ref> second_gemm_mapped_inputs;
second_gemm_mapped_inputs.reserve(second_gemm_inputs.size());
std::transform(second_gemm_inputs.begin(),
second_gemm_inputs.end(),
std::back_inserter(second_gemm_mapped_inputs),
[&](auto input) { return map_ins.at(input); });
auto second_gemm_in_module =
mm->add_instruction(second_gemm_ins->get_operator(), second_gemm_mapped_inputs);
map_ins[second_gemm_ins] = second_gemm_in_module;

// primary output is the last gemm, which should be the first output;
// multi-user intermediates follow in order: [second_gemm, elemwise?, first_gemm?].
// The get_tuple_elem indices assigned below must match this layout.
std::vector<instruction_ref> return_vals;
return_vals.push_back(second_gemm_in_module);

if(elemwise_has_multi_outs)
{
return_vals.push_back(map_ins[elemwise_ins]);
}
if(first_gemm_has_multi_outs)
{
return_vals.push_back(map_ins[first_gemm_ins]);
}
mm->add_return(return_vals);
auto inputs = find_inputs(map_ins, &mpm.get_module(), mm);

// sort fusion section of module such that any external inputs are moved before the fusion
// so that we can safely place the fused mod in the multi-out case at the beginning of the
// chain
mpm.get_module().localized_sort(first_gemm_ins, second_gemm_ins);

// insert the fused op where the first gemm was, so get_tuple_elem results
// inserted below dominate all remaining users of the intermediates
auto fused_ins =
mpm.get_module().insert_instruction(first_gemm_ins,
mlir_op{second_gemm_ins->get_operator()},
mlir_contiguous(mpm, inputs),
{mm});

if(first_gemm_has_multi_outs or elemwise_has_multi_outs)
{
// pre-incrementing output_idx yields index 1 for elemwise (when present)
// and 1 or 2 for first_gemm, matching the return_vals order above
std::size_t output_idx = 0;
if(elemwise_has_multi_outs)
{
auto elemwise_result = mpm.get_module().insert_instruction(
first_gemm_ins,
migraphx::make_op("get_tuple_elem", {{"index", ++output_idx}}),
fused_ins);
mpm.get_module().replace_instruction(elemwise_ins, elemwise_result);
}
if(first_gemm_has_multi_outs)
{
mpm.get_module().replace_instruction(
first_gemm_ins,
migraphx::make_op("get_tuple_elem", {{"index", ++output_idx}}),
fused_ins);
}
// the fused op's first tuple element is always the second gemm's result
mpm.get_module().replace_instruction(
second_gemm_ins, migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused_ins);
}
else
{
// simple single output case
mpm.get_module().replace_instruction(second_gemm_ins, fused_ins);
}
}
};

template <auto Matcher>
struct find_mlir_standalone_op
{
Expand Down Expand Up @@ -1300,6 +1447,15 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
match::find_matches(mpm, find_mlir_attention_op{});
mpm.run_pass(dead_code_elimination{});

if(enabled(MIGRAPHX_ENABLE_MLIR_GEG_FUSION{}))
{
match::find_matches(
mpm,
find_mlir_fused_geg_ops{.conv_mode = get_mode("fused_convolution", mlir_mode::fast),
.dot_mode = get_mode("fused_dot", mlir_mode::fast)});
mpm.run_pass(dead_code_elimination{});
}

match::find_matches(
mpm,
find_mlir_fused_ops{.conv_mode = get_mode("fused_convolution", mlir_mode::fast),
Expand Down
Loading
Loading