@umangyadav
Member

@umangyadav umangyadav commented Aug 26, 2025

Motivation

Fixes https://github.com/ROCm/rocMLIR-internal/issues/1969

Technical Details

The current logic uses DFS without caching, so it revisits some ops multiple times. As a result, a blockArgument for the firstGemm is added multiple times, but only the index of one of them (the last one) is kept as firstGemmIndex.

This causes a mismatch between the size of preSoftmaxElementwiseInputs() and the block arguments list, which eventually leads to out-of-bounds indexing into preSoftmaxElementwiseInputs().

This PR refactors the logic so that ops/values found during the "match" phase are cached and used directly during the "rewrite" phase.
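
To make the failure mode concrete, here is a minimal sketch of how an uncached DFS duplicates an entry when two paths reach the same node, and how a visited set avoids it. The `Node`, `Collected`, `collectUncached`, `collectCached`, and `runDiamondExample` names are hypothetical stand-ins for illustration only; this is not the code in TosaToRock.cpp and it uses no MLIR types.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical toy op: a name, a flag marking the first GEMM, and operands.
struct Node {
  std::string name;
  bool isFirstGemm;
  std::vector<const Node *> operands;
};

// Hypothetical result of the traversal.
struct Collected {
  std::vector<std::string> blockArgs;
  int firstGemmIndex = -1;
};

// Uncached DFS: each path that reaches the first GEMM appends another
// block argument and overwrites firstGemmIndex, so only the last index survives.
void collectUncached(const Node *n, Collected &out) {
  assert(n && "node must not be null");
  if (n->isFirstGemm) {
    out.firstGemmIndex = static_cast<int>(out.blockArgs.size());
    out.blockArgs.push_back(n->name);
    return;
  }
  for (const Node *operand : n->operands)
    collectUncached(operand, out);
}

// Cached DFS: the visited set guarantees each node is processed only once.
void collectCached(const Node *n, std::unordered_set<const Node *> &visited,
                   Collected &out) {
  assert(n && "node must not be null");
  if (!visited.insert(n).second)
    return; // already handled via another path
  if (n->isFirstGemm) {
    out.firstGemmIndex = static_cast<int>(out.blockArgs.size());
    out.blockArgs.push_back(n->name);
    return;
  }
  for (const Node *operand : n->operands)
    collectCached(operand, visited, out);
}

struct Counts {
  std::size_t uncached;
  std::size_t cached;
};

// Diamond: add(mul(gemm0), exp(gemm0)), so two paths lead to gemm0.
Counts runDiamondExample() {
  Node gemm{"gemm0", true, {}};
  Node mulOp{"mul", false, {&gemm}};
  Node expOp{"exp", false, {&gemm}};
  Node addOp{"add", false, {&mulOp, &expOp}};

  Collected uncached;
  collectUncached(&addOp, uncached);

  Collected cached;
  std::unordered_set<const Node *> visited;
  collectCached(&addOp, visited, cached);

  return {uncached.blockArgs.size(), cached.blockArgs.size()};
}
```

In this toy diamond the uncached traversal records the first GEMM twice while the cached one records it once, which mirrors the size mismatch described above.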

Test Plan

Added an E2E test that exposes the problem; it passes with this change.

@umangyadav umangyadav requested a review from causten as a code owner August 26, 2025 20:47
@umangyadav umangyadav self-assigned this Aug 26, 2025
@umangyadav umangyadav requested review from Copilot, djramic, justinrosner and stefankoncarevic and removed request for causten August 26, 2025 20:49
Contributor

Copilot AI left a comment

Pull Request Overview

This PR refactors the creation of ElementWise Region for Gemm+Gemm like operations to fix a bug where the Depth-First Search (DFS) algorithm revisited operations multiple times, causing incorrect block argument indexing and eventual out-of-bounds access.

  • Replaced the recursive DFS approach with a cached visitor pattern to eliminate redundant visits
  • Introduced ElementwiseRegionFinder struct to encapsulate region finding and rewrite logic
  • Updated all affected pattern matchers to use the new cached approach

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

File | Description
mlir/test/fusion/pr-e2e/gemm-gemm/mixr-gemm-gemm-multiple-traces-to-first-gemm.mlir | Adds end-to-end test case that exposes the multiple traces bug
mlir/lib/Conversion/TosaToRock/TosaToRock.cpp | Refactors element-wise region finding from recursive DFS to cached visitor pattern

@bdevorem
Contributor

bdevorem commented Aug 26, 2025

I tested commit 6d11fa2 against the test case with my branch of MIGraphX (which has the GEMM+GEMM fusion on the MIGraphX side), and it compiles successfully end to end:

[2025-08-26 22:31:40]
[ MIGraphX Version: 2.14.0.cb9bc5d01-dirty ] Complete(0.659944s): ./build/bin/driver compile test.mxr

As discussed elsewhere, commit bfb7605 has a faulty copilot fix.

Contributor

@justinrosner justinrosner left a comment

Maybe this is something that can be addressed by a future ticket, but do you think it would be worthwhile creating some variant of the df_iterator framework for our Rock ops similar to what LLVM has? This could potentially help with code reusability when implementing custom DFS traversals of our ops.

}
// Right now, this is a bit restricted that we only allow reshape-like
// ops between in the elementwise tree that get fused to the fusion point.
// TODO: however, the latest code gridwise-gemm-to-blockwise should tackle
Contributor

Is there a ticket opened for this TODO so that we don't lose track of it?

Member Author

I don't think there is an issue for it. I copy-pasted the earlier TODO as is.

We haven't yet hit any case where this limitation prevents an attention kernel from being generated.

Also, in most (or all) of the cases I've seen, tensor.expand and tensor.collapse along with tosa.add cover all the invertible transforms, e.g. broadcasts, reshapes, squeeze/unsqueeze.

@umangyadav
Member Author

Maybe this is something that can be addressed by a future ticket, but do you think it would be worthwhile creating some variant of the df_iterator framework for our Rock ops similar to what LLVM has? This could potentially help with code reusability when implementing custom DFS traversals of our ops.

Yes, sounds like a good idea, but as far as I can tell we are only doing DFS in TosaToRock.cpp at this point. We do use graph traversals with bufferDependencyAnalysis in some places. We will have to think about whether it would be worth using df_iterator.
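
For reference, a reusable traversal helper of the kind discussed here could look roughly like the sketch below. It is plain standard C++ and deliberately does not use LLVM's df_iterator (which lives in llvm/ADT/DepthFirstIterator.h and requires a GraphTraits specialization); `depthFirstOrder` and `diamondOrder` are hypothetical names introduced only for this illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <unordered_set>
#include <vector>

// Hypothetical generic pre-order DFS: visits each reachable node exactly once.
// `children` returns the successors of a node; NodeT must be hashable.
template <typename NodeT>
std::vector<NodeT> depthFirstOrder(
    NodeT root, const std::function<std::vector<NodeT>(NodeT)> &children) {
  std::vector<NodeT> order;
  std::unordered_set<NodeT> visited;
  std::vector<NodeT> stack;
  stack.push_back(root);
  while (!stack.empty()) {
    NodeT n = stack.back();
    stack.pop_back();
    if (!visited.insert(n).second)
      continue; // reached earlier via another path
    order.push_back(n);
    std::vector<NodeT> kids = children(n);
    // Push in reverse so the first child is visited first.
    for (auto it = kids.rbegin(); it != kids.rend(); ++it)
      stack.push_back(*it);
  }
  return order;
}

// Usage on a small diamond graph: 0 -> {1, 2}, 1 -> {3}, 2 -> {3}.
std::vector<int> diamondOrder() {
  const std::vector<std::vector<int>> adj = {{1, 2}, {3}, {3}, {}};
  return depthFirstOrder<int>(0, [&adj](int n) {
    assert(n >= 0 && static_cast<std::size_t>(n) < adj.size());
    return adj[n];
  });
}
```

Node 3 is reachable through both 1 and 2 but appears only once in the result, which is the property the cached traversal in this PR relies on.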

@umangyadav umangyadav merged commit 97a0085 into develop Aug 28, 2025
16 of 22 checks passed
@umangyadav umangyadav deleted the fixGemmBlockArgs branch August 28, 2025 20:02
umangyadav added a commit that referenced this pull request Aug 28, 2025
umangyadav added a commit that referenced this pull request Sep 2, 2025
…emm like ops (#1965)

* Refactor and fix creation of ElementWise Region for Gemm+Gemm like ops (#1960)

* Refactor matching logic for elemwise tree

* Fix merge issues