Refactor and fix creation of ElementWise Region for Gemm+Gemm like ops #1960
Pull Request Overview
This PR refactors the creation of ElementWise Region for Gemm+Gemm like operations to fix a bug where the Depth-First Search (DFS) algorithm revisited operations multiple times, causing incorrect block argument indexing and eventual out-of-bounds access.
- Replaced the recursive DFS approach with a cached visitor pattern to eliminate redundant visits
- Introduced ElementwiseRegionFinder struct to encapsulate region finding and rewrite logic
- Updated all affected pattern matchers to use the new cached approach
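The effect of caching on a traversal can be illustrated with a minimal sketch. It is not the code from TosaToRock.cpp: `ToyOp`, `visitUncached`, `visitCached`, `makeDiamond`, and `countVisits` are hypothetical names, and the graph is a four-node diamond in which op 0 stands in for the first GEMM and is reachable from op 3 along two paths.

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for an operation: it only records the indices of
// the ops that produce its operands.
struct ToyOp {
  std::vector<std::size_t> operands;
};

// Uncached DFS: every path that reaches an op triggers another visit.
void visitUncached(const std::vector<ToyOp> &ops, std::size_t id,
                   std::vector<int> &visits) {
  assert(id < ops.size());
  ++visits[id];
  for (std::size_t operand : ops[id].operands)
    visitUncached(ops, operand, visits);
}

// Cached DFS: an op that is already in `seen` is skipped.
void visitCached(const std::vector<ToyOp> &ops, std::size_t id,
                 std::unordered_set<std::size_t> &seen,
                 std::vector<int> &visits) {
  assert(id < ops.size());
  if (!seen.insert(id).second)
    return; // already visited through another path
  ++visits[id];
  for (std::size_t operand : ops[id].operands)
    visitCached(ops, operand, seen, visits);
}

// Diamond: op 3 uses ops 1 and 2, and both of those use op 0.
std::vector<ToyOp> makeDiamond() {
  std::vector<ToyOp> ops(4);
  ops[1].operands.push_back(0);
  ops[2].operands.push_back(0);
  ops[3].operands.push_back(1);
  ops[3].operands.push_back(2);
  return ops;
}

// Returns the per-op visit counts for a traversal that starts at op 3.
std::vector<int> countVisits(bool useCache) {
  const std::vector<ToyOp> ops = makeDiamond();
  std::vector<int> visits(ops.size(), 0);
  if (useCache) {
    std::unordered_set<std::size_t> seen;
    visitCached(ops, 3, seen, visits);
  } else {
    visitUncached(ops, 3, visits);
  }
  return visits;
}
```

In this toy graph, `countVisits(false)` reports two visits for op 0, while `countVisits(true)` reports one visit for every op. Any per-visit side effect, such as adding a block argument, is duplicated in the same way when no cache is used.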
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| mlir/test/fusion/pr-e2e/gemm-gemm/mixr-gemm-gemm-multiple-traces-to-first-gemm.mlir | Adds end-to-end test case that exposes the multiple traces bug |
| mlir/lib/Conversion/TosaToRock/TosaToRock.cpp | Refactors element-wise region finding from recursive DFS to cached visitor pattern |
I tested commit 6d11fa2 with my branch of MIGraphX (which has the GEMM+GEMM fusion on the MIGraphX side) against the test case, and it compiles successfully end to end. As discussed elsewhere, commit bfb7605 contains a faulty Copilot fix.
justinrosner left a comment
Maybe this is something that can be addressed by a future ticket, but do you think it would be worthwhile creating some variant of the df_iterator framework for our Rock ops similar to what LLVM has? This could potentially help with code reusability when implementing custom DFS traversals of our ops.
}
// Right now, this is a bit restricted that we only allow reshape-like
// ops between in the elementwise tree that get fused to the fusion point.
// TODO: however, the latest code gridwise-gemm-to-blockwise should tackle
Is there a ticket opened for this TODO so that we don't lose track of it?
I don't think there is an issue for it. I copied the earlier TODO as is.
We haven't yet hit any case where this limitation prevents an attention kernel from being generated.
Also, in most (or all) of the cases I've seen, tensor.expand and tensor.collapse along with tosa.add cover all the invertible transforms, e.g. broadcast, reshapes, squeeze/unsqueeze.
Yes, that sounds like a good idea, but we are only doing DFS in TosaToRock.cpp at this point, afaict. We do use graph traversals with bufferDependencyAnalysis in some places. We will have to think about if it would be worth using
#1960) * Refactor matching logic for elemwise tree
Motivation
Fixes https://github.com/ROCm/rocMLIR-internal/issues/1969
Technical Details
The current logic uses DFS without caching, so it revisits some ops multiple times. As a result, a block argument for the first GEMM is added multiple times, but only one index is kept for the firstGemmIndex block argument, namely the last one.
This leads to a mismatch between the sizes of preSoftmaxElementwiseInputs() and the block argument list, which eventually causes out-of-bounds indexing into preSoftmaxElementwiseInputs().
This PR refactors the logic so that the ops and values found during the "match" phase are cached and used directly during the "rewrite" phase.
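The bookkeeping problem described above can be sketched with a toy model. This is an illustration under stated assumptions, not the rocMLIR implementation: `MatchResult`, `matchWithoutCache`, `matchWithCache`, and `sampleLeaves` are hypothetical names, values are modeled as strings, and the traversal is assumed to reach the value named "firstGemm" twice.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical result of the "match" phase.
struct MatchResult {
  std::vector<std::string> blockArgs; // one entry per block argument created
  std::size_t firstGemmIndex = 0;     // index of the first-GEMM block argument
};

// Values reached by the traversal, in visit order. The first GEMM result is
// reached twice because two paths lead to it (an assumption of this sketch).
std::vector<std::string> sampleLeaves() {
  return {"firstGemm", "bias", "firstGemm"};
}

// Without a cache, every visit creates a new block argument, and only the
// index of the last first-GEMM argument survives.
MatchResult matchWithoutCache(const std::vector<std::string> &leaves) {
  MatchResult result;
  for (const std::string &leaf : leaves) {
    result.blockArgs.push_back(leaf);
    if (leaf == "firstGemm")
      result.firstGemmIndex = result.blockArgs.size() - 1;
  }
  return result;
}

// With a cache, a value seen before reuses its existing block argument.
MatchResult matchWithCache(const std::vector<std::string> &leaves) {
  MatchResult result;
  std::unordered_map<std::string, std::size_t> argIndex;
  for (const std::string &leaf : leaves) {
    auto inserted = argIndex.emplace(leaf, result.blockArgs.size());
    if (inserted.second)
      result.blockArgs.push_back(leaf);
    if (leaf == "firstGemm")
      result.firstGemmIndex = inserted.first->second;
  }
  // One block argument per distinct value.
  assert(result.blockArgs.size() == argIndex.size());
  return result;
}
```

With the sample input, the uncached version creates three block arguments for two distinct values, whereas the cached version creates two. A "rewrite" phase that sizes its inputs by distinct values would therefore disagree with the uncached argument list, which is the kind of mismatch the description attributes to the old logic.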
Test Plan
Added an E2E test that exposes the problem; it passes with this change.