Skip to content

[AMD] Enable ds_read_tr* lowering for PartitionedSharedEncodingAttr#10062

Merged
antiagainst merged 3 commits into
triton-lang:mainfrom
plognjen:ds_transpose_partitioned
Apr 20, 2026
Merged

[AMD] Enable ds_read_tr* lowering for PartitionedSharedEncodingAttr#10062
antiagainst merged 3 commits into
triton-lang:mainfrom
plognjen:ds_transpose_partitioned

Conversation

@plognjen
Copy link
Copy Markdown
Contributor

Extends the AMD ds_read_tr* local-load lowering to accept PartitionedSharedEncodingAttr as the source encoding.
Previously the pattern bailed out as soon as it saw a partitioned shared encoding, forcing a slower generic
local-load lowering for all WMMA dot-operand loads from partitioned LDS buffers.

triton::gpu::LocalLoadOp op,
::triton::AMD::TargetInfo::LDSTransLoadParams ldsParams, Location loc,
LinearLayout cvt,
SmallVector<Value> &vals, // Input for stmatrix, output for ldmatrix
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed this comment on purpose, as it seems as an accidental copy/paste from nv path

Comment on lines +232 to +243
// One ds_read_tr* instruction produces `fullTile.getInDimSize(kReg)`
// consecutive register values from a single LDS base pointer. We only
// select a partition once per instruction, so all of those register
// positions must map to the same partition. For a LinearLayout that holds
// iff the low log2(elemsPerInstr) register bases contribute 0 to
// kPartition. Bail out if not, so a generic lowering can take over.
const unsigned numInstrRegBits =
llvm::Log2_32(fullTile.getInDimSize(kReg));
for (unsigned pos = 0; pos < numInstrRegBits; ++pos) {
if (partitionLayout.getBasis(kReg, pos, kPartition) != 0)
return failure();
}
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I should probably add same check in regular lowering path as well.

def test_runtime_partitioned_tdm_load(BLOCK_M, BLOCK_N, NUM_PARTITIONS, NUM_GROUPS, PARTITION_DIM, num_warps, M, N):
"""Test TDM async_load with PartitionedSharedLayout (global -> LDS)."""
@pytest.mark.parametrize("BLOCK_M,BLOCK_N,NUM_PARTITIONS,NUM_GROUPS,PARTITION_DIM", _PARTITIONED_TDM_PARAMS)
def test_runtime_partitioned_tdm_load(BLOCK_M, BLOCK_N, NUM_PARTITIONS, NUM_GROUPS, PARTITION_DIM):
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Repurposed this test so it can check end-to-end correctness of ds_transpose path with partitioned layout as well.

@plognjen
Copy link
Copy Markdown
Contributor Author

plognjen commented Apr 17, 2026

@lezcano @nzaghen can you take a look please?

Comment on lines -1499 to -1500
@pytest.mark.parametrize("num_warps", [4])
@pytest.mark.parametrize("M,N", [(256, 256)])
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to parametrize since it's a single value

Copy link
Copy Markdown
Contributor

@lezcano lezcano left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't carefully read the details, but the general structure looks reasonable to me. I'll let amd folks to have a proper look at the semantics

Copy link
Copy Markdown
Member

@antiagainst antiagainst left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall LGTM; just a few impl nits.

f"partitionDim={PARTITION_DIM}, numPartitions={NUM_PARTITIONS}, numGroups={NUM_GROUPS}")


@pytest.mark.skipif(not is_hip_gfx1250(), reason="Requires GFX1250")
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Compilation only tests don't need to be gated on is_hip_gfx1250.


block_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [4, 8], [num_warps, 1], [1, 0])
WMMA_LAYOUT: ttgl.constexpr = ttgl.amd.AMDWMMALayout(3, True, [[0, 1], [1, 0]], [], [16, 16, 32])
OPERAND_LAYOUT: ttgl.constexpr = ttgl.DotOperandLayout(1, WMMA_LAYOUT, 8)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: maybe using DOT_RHS_LAYOUT to be clearer.

auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
llvmElemTy, rewriter);
auto smemBase = smemObj.getBase();
SmallVector<Value> smemBases(smemObj.getBases().begin(),
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: llvm::to_vector(smemObj.getBases())?

// kPartition. Bail out if not, so a generic lowering can take over.
const unsigned numInstrRegBits =
llvm::Log2_32(fullTile.getInDimSize(kReg));
for (unsigned pos = 0; pos < numInstrRegBits; ++pos) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just checking partitionLayout.sublayoutIsZero({kReg}, {kPartition})?

Copy link
Copy Markdown
Contributor Author

@plognjen plognjen Apr 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, because this would check whole reg bases of partition layout, which would include repetitions. The point is that we want to check just first numInstrRegBits, which are number of register from fullTile layout, which is one instruction. It's fine for different repetitions (instructions) to be in different partitions, but we want to check if registers from single instruction are in different partition.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For reference, if want it without looking at the bases, you can do that by reshaping kReg into two dimensions, one of dim numInstrRegBits and a different one and check the sublayoutIsZero there. But tbh I wouldn't rewrite it, the current solution seems alright.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah makes sense. Thanks for the explanation.

@plognjen plognjen force-pushed the ds_transpose_partitioned branch from 154d54f to 7e89e33 Compare April 20, 2026 11:04
@antiagainst antiagainst enabled auto-merge (squash) April 20, 2026 17:43
@antiagainst antiagainst merged commit ee5bc26 into triton-lang:main Apr 20, 2026
15 of 18 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants