[Backend] Refactor OptimizeDescriptorEncoding to common path #9709

Merged
antiagainst merged 6 commits into triton-lang:main from sriakrish:nfc-refactor-opt-desc-encoding on Mar 20, 2026

Conversation

@sriakrish
Contributor

  • Move the common utilities of this pass to DescriptorUtils
  • Provide functors to customize backend-specific decisions

New contributor declaration

  • I am not making a trivial change, such as fixing a typo in a comment.

  • I have written a PR description following these
    rules.

  • I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • Select one of the following.

    • I have added tests.
      • /test for lit tests
      • /unittest for C++ tests
      • /python/test for end-to-end tests
    • This PR does not need a test because it is an NFC refactoring.
  • Select one of the following.

    • I have not added any lit tests.
    • The lit tests I have added follow these best practices,
      including the "tests should be minimal" section. (Usually running Python code
      and using the instructions it generates is not minimal.)

@sriakrish sriakrish requested a review from ptillet as a code owner March 13, 2026 02:09
@sriakrish
Contributor Author

sriakrish commented Mar 13, 2026

We are working on improving our handling of tensor descriptors. The work is split into two parts:

  • Part 1 (this PR): NFC refactor moving the common utilities to DescriptorUtils and keeping backend-specific decisions, such as choosing a fallback layout and finding an encoding from uses, in the pass
  • Part 2: Add an OptimizeDescriptorEncoding pass for the AMD backend, reusing the common functions from DescriptorUtils

cc: @antiagainst

@ThomasRaoux
Collaborator

What part of the logic will be different for AMD? The logic in this pass seems fairly generic except for the part about NVMMASharedEncodingAttr, which is a layout specific to TMAs. I wonder if we really need a fully separate pass.

@sriakrish
Contributor Author

sriakrish commented Mar 13, 2026

> What part of the logic will be different for AMD? The logic in this pass seems fairly generic except for the part about NVMMASharedEncodingAttr, which is a layout specific to TMAs. I wonder if we really need a fully separate pass.

Our current implementation requires these four functions:

  1. getFallbackSharedLayout: we assign padded shared layouts as the fallback in most cases, determined by hardware features, and use swizzled layouts only under certain conditions.
  2. updateEncodingForShape: consequently, this function handles both padded and swizzled layouts.
  3. findEncodingFromUsers: here we walk the uses of the descriptor load, check whether they lead to dot operands, and derive padded layouts. It essentially moves the padding decisions into this pass.
  4. forcedToDefault: we need this one because of ReinterpretTensorDescOp; we force the default encoding only on call and return ops.
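The four hooks above could be grouped into a backend interface roughly like the following. This is a hypothetical sketch with stand-in types (the real pass works with MLIR attributes and ops, not these toy structs); names mirror the comment, signatures are illustrative only:

```cpp
#include <cstdint>
#include <vector>

// Stand-in types; the real code uses MLIR's Attribute, RankedTensorType, etc.
struct Encoding { bool padded = false; };
struct DescriptorLoad { std::vector<int64_t> shape; };

// Hypothetical interface grouping the four backend-specific decisions.
class DescriptorEncodingBackend {
public:
  virtual ~DescriptorEncodingBackend() = default;
  // 1. Pick a fallback shared layout (e.g. padded on AMD in most cases).
  virtual Encoding getFallbackSharedLayout(const DescriptorLoad &ld) = 0;
  // 2. Adjust an encoding when the descriptor's block shape changes.
  virtual Encoding updateEncodingForShape(Encoding enc,
                                          const std::vector<int64_t> &shape) = 0;
  // 3. Walk the load's users to derive an encoding (e.g. from dot operands).
  virtual Encoding findEncodingFromUsers(const DescriptorLoad &ld) = 0;
  // 4. Force the default encoding at boundaries such as call/return ops.
  virtual bool forcedToDefault(const DescriptorLoad &ld) = 0;
};

// A toy AMD-flavored implementation: prefers padded layouts as the fallback.
class ToyAmdBackend : public DescriptorEncodingBackend {
public:
  Encoding getFallbackSharedLayout(const DescriptorLoad &) override {
    return Encoding{/*padded=*/true};
  }
  Encoding updateEncodingForShape(Encoding enc,
                                  const std::vector<int64_t> &) override {
    return enc; // shape-independent in this toy version
  }
  Encoding findEncodingFromUsers(const DescriptorLoad &ld) override {
    return getFallbackSharedLayout(ld);
  }
  bool forcedToDefault(const DescriptorLoad &) override { return false; }
};
```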

@sriakrish sriakrish force-pushed the nfc-refactor-opt-desc-encoding branch from 323776d to c63d6db Compare March 14, 2026 14:58
@sriakrish sriakrish marked this pull request as draft March 14, 2026 22:42
@sriakrish sriakrish force-pushed the nfc-refactor-opt-desc-encoding branch from c63d6db to ca53687 Compare March 19, 2026 23:40
@sriakrish sriakrish marked this pull request as ready for review March 19, 2026 23:42
@antiagainst antiagainst changed the title [NFC] Refactor OptimizeDescriptorEncoding [Backend][NFC] Refactor OptimizeDescriptorEncoding to common path Mar 19, 2026
@antiagainst antiagainst changed the title [Backend][NFC] Refactor OptimizeDescriptorEncoding to common path [Backend] Refactor OptimizeDescriptorEncoding to common path Mar 19, 2026
@sriakrish
Contributor Author

@ThomasRaoux @antiagainst

I have updated this PR. We now require only two callbacks:

  1. buildFallbackSharedEncoding - builds the backend-specific fallback encoding
  2. isCompatibleEncoding - checks whether an encoding is compatible with the backend

The rest of the infrastructure is common.

updateEncodingForShape is now shared between both backends.
findLoadEncodingFromUsers is also common to both backends. It first checks for a discardable attribute tt.desired_encoding on descriptor loads; the rest of it is unchanged. On the AMD side, we populate this attribute with a padded encoding derived from its usage in dot, and let findLoadEncodingFromUsers simply pick it up when it is available.
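The lookup order described above (attribute hint first, then user walk) could be sketched as follows. This is a self-contained toy with stand-in types; the real findLoadEncodingFromUsers operates on MLIR ops and attributes, and the `tt.desired_encoding` name comes from the discussion:

```cpp
#include <map>
#include <optional>
#include <string>
#include <vector>

// Stand-in for an encoding attribute; the real code uses MLIR Attributes.
using Encoding = std::string;

struct LoadOp {
  // Discardable attributes on the op, keyed by name (stand-in for MLIR's
  // attribute dictionary).
  std::map<std::string, Encoding> attrs;
  // Encodings implied by the load's users (e.g. dot operands).
  std::vector<Encoding> userEncodings;
};

// Honor a pre-populated "tt.desired_encoding" hint first, then fall back
// to walking the users.
std::optional<Encoding> findLoadEncodingFromUsers(const LoadOp &load) {
  if (auto it = load.attrs.find("tt.desired_encoding"); it != load.attrs.end())
    return it->second;
  for (const Encoding &enc : load.userEncodings)
    return enc; // first user-derived encoding wins in this toy version
  return std::nullopt;
}
```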

```cpp
  assignMemoryLayouts(m);

  // Fallback shared encoding callback
  auto buildFallbackSharedEncoding =
```
Member

Rather than using a lambda, maybe just create it as a free function, like isTMACompatibleEncoding, to be symmetric.

```cpp
        ctx, swizEnc.getVec(), swizEnc.getPerPhase(), swizEnc.getMaxPhase(),
        order, newCgaEnc);
  }
  if (auto paddedEnc = dyn_cast<ttg::PaddedSharedEncodingAttr>(encoding)) {
```
Member

I guess we can remove this part and add it later together with the AMD logic.

```cpp
/// Utility class to assign memory layouts to tensor descriptors in a module.
class AssignDescriptorMemoryLayouts {
public:
  AssignDescriptorMemoryLayouts() = default;
```
Member

Is this constructor needed?

@sriakrish
Contributor Author

@antiagainst Thank you, PR updated with suggested revisions

Collaborator

@ThomasRaoux left a comment

Makes sense, but I still think it should be refactored a little bit; hopefully my comments make sense to you.

Comment on lines +16 to +25
```cpp
struct DescriptorAnalysisCallbacks {
  /// Callback to check for compatible shared encoding
  llvm::function_ref<bool(Attribute)> isCompatibleSharedEncoding;

  /// create a fallback encoding given the shape, order, cga layout and
  /// element type
  llvm::function_ref<Attribute(mlir::MLIRContext *, ArrayRef<int64_t>,
                               ArrayRef<unsigned>, CGAEncodingAttr, Type)>
      buildFallbackSharedEncoding;
};
```
Collaborator

A struct with a bunch of callbacks seems like a convoluted way to make virtual functions. How about we make those virtual functions in AssignDescriptorMemoryLayouts or DescriptorAnalysisCallbacks?

```cpp
#include "llvm/ADT/PriorityWorklist.h"
#include <unordered_set>

namespace ttg = mlir::triton::gpu;
```
Collaborator

Overall the idea to separate things out makes sense, but on the style side I think calling this file Utils is misleading. At this point it is really the whole implementation of the pass rather than a set of generic utility functions.

My suggestion is to make a generic class with some overridable functions and inherit from it in a target-specific way.

I know the result is very similar, but I think mixing up passes and utils is going to be confusing.

Contributor Author

Thank you for the feedback. Agreed, I have renamed the file.

I have a new revision which addresses the following:

  • A base class AssignDescriptorMemoryLayouts now implements the core logic for layout assignment and provides virtual methods for backends to override. Moved the core logic functions into the class.
  • Renamed the file from DescriptorUtils to DescriptorMemoryLayouts.
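The shape of the final design could be sketched like this. It is a hypothetical, self-contained toy with stand-in types, not the actual Triton code: the base class owns the core loop and exposes two virtual hooks (the two callbacks named earlier), which an AMD subclass overrides:

```cpp
#include <string>
#include <vector>

using Encoding = std::string; // stand-in for an MLIR encoding attribute

// Base class: owns the core layout-assignment logic and exposes two
// virtual hooks for backends. Names follow the discussion; the signatures
// are illustrative only.
class AssignDescriptorMemoryLayouts {
public:
  virtual ~AssignDescriptorMemoryLayouts() = default;

  // Core driver: keep a use's encoding when the backend accepts it,
  // otherwise substitute the backend's fallback.
  std::vector<Encoding> run(const std::vector<Encoding> &uses) {
    std::vector<Encoding> out;
    for (const Encoding &enc : uses)
      out.push_back(isCompatibleSharedEncoding(enc)
                        ? enc
                        : buildFallbackSharedEncoding());
    return out;
  }

protected:
  virtual bool isCompatibleSharedEncoding(const Encoding &enc) = 0;
  virtual Encoding buildFallbackSharedEncoding() = 0;
};

// Toy AMD specialization: accepts padded layouts and falls back to one.
class AmdAssignDescriptorMemoryLayouts : public AssignDescriptorMemoryLayouts {
protected:
  bool isCompatibleSharedEncoding(const Encoding &enc) override {
    return enc == "padded";
  }
  Encoding buildFallbackSharedEncoding() override { return "padded"; }
};
```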

```cpp
#include <unordered_set>

namespace mlir::triton::gpu {
// Forward declarations
```
Collaborator

nit: meaningless comment

Contributor Author

Fixed it.

- Move common utilities of this to DescriptorUtils
- Provide functors to customize backend-specific decisions
- Introduce only two callbacks
  - Callback for building the backend fallback layout
  - Callback for checking backend-compatible encoding
- Move updateEncodingForShape to a common place
- Adapt updateEncodingForShape to handle padded layouts
- Fix comment
@sriakrish sriakrish force-pushed the nfc-refactor-opt-desc-encoding branch from 92ff3a7 to e028faf Compare March 20, 2026 19:00
@antiagainst antiagainst merged commit d48908a into triton-lang:main Mar 20, 2026
9 checks passed
raymondtay pushed a commit to raymondtay/triton that referenced this pull request Mar 22, 2026
…lang#9709)

A base class `AssignDescriptorMemoryLayouts` now implements
the core logic for layout assignment and provides virtual methods
for backend overloads. Moved the core logic functions to the class.
jvican pushed a commit to jvican/triton that referenced this pull request Mar 27, 2026
plognjen pushed a commit to plognjen/triton that referenced this pull request Apr 14, 2026

3 participants