[AMD] Enable ds_read_tr* lowering for PartitionedSharedEncodingAttr#10062
Conversation
| triton::gpu::LocalLoadOp op, | ||
| ::triton::AMD::TargetInfo::LDSTransLoadParams ldsParams, Location loc, | ||
| LinearLayout cvt, | ||
| SmallVector<Value> &vals, // Input for stmatrix, output for ldmatrix |
There was a problem hiding this comment.
removed this comment on purpose, as it seems as an accidental copy/paste from nv path
| // One ds_read_tr* instruction produces `fullTile.getInDimSize(kReg)` | ||
| // consecutive register values from a single LDS base pointer. We only | ||
| // select a partition once per instruction, so all of those register | ||
| // positions must map to the same partition. For a LinearLayout that holds | ||
| // iff the low log2(elemsPerInstr) register bases contribute 0 to | ||
| // kPartition. Bail out if not, so a generic lowering can take over. | ||
| const unsigned numInstrRegBits = | ||
| llvm::Log2_32(fullTile.getInDimSize(kReg)); | ||
| for (unsigned pos = 0; pos < numInstrRegBits; ++pos) { | ||
| if (partitionLayout.getBasis(kReg, pos, kPartition) != 0) | ||
| return failure(); | ||
| } |
There was a problem hiding this comment.
I should probably add same check in regular lowering path as well.
| def test_runtime_partitioned_tdm_load(BLOCK_M, BLOCK_N, NUM_PARTITIONS, NUM_GROUPS, PARTITION_DIM, num_warps, M, N): | ||
| """Test TDM async_load with PartitionedSharedLayout (global -> LDS).""" | ||
| @pytest.mark.parametrize("BLOCK_M,BLOCK_N,NUM_PARTITIONS,NUM_GROUPS,PARTITION_DIM", _PARTITIONED_TDM_PARAMS) | ||
| def test_runtime_partitioned_tdm_load(BLOCK_M, BLOCK_N, NUM_PARTITIONS, NUM_GROUPS, PARTITION_DIM): |
There was a problem hiding this comment.
Repurposed this test so it can check end-to-end correctness of ds_transpose path with partitioned layout as well.
| @pytest.mark.parametrize("num_warps", [4]) | ||
| @pytest.mark.parametrize("M,N", [(256, 256)]) |
There was a problem hiding this comment.
No need to parametrize since it's a single value
lezcano
left a comment
There was a problem hiding this comment.
I didn't carefully read the details, but the general structure looks reasonable to me. I'll let amd folks to have a proper look at the semantics
antiagainst
left a comment
There was a problem hiding this comment.
Overall LGTM; just a few impl nits.
| f"partitionDim={PARTITION_DIM}, numPartitions={NUM_PARTITIONS}, numGroups={NUM_GROUPS}") | ||
|
|
||
|
|
||
| @pytest.mark.skipif(not is_hip_gfx1250(), reason="Requires GFX1250") |
There was a problem hiding this comment.
Compilation only tests don't need to be gated on is_hip_gfx1250.
|
|
||
| block_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [4, 8], [num_warps, 1], [1, 0]) | ||
| WMMA_LAYOUT: ttgl.constexpr = ttgl.amd.AMDWMMALayout(3, True, [[0, 1], [1, 0]], [], [16, 16, 32]) | ||
| OPERAND_LAYOUT: ttgl.constexpr = ttgl.DotOperandLayout(1, WMMA_LAYOUT, 8) |
There was a problem hiding this comment.
Nit: maybe using DOT_RHS_LAYOUT to be clearer.
| auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), | ||
| llvmElemTy, rewriter); | ||
| auto smemBase = smemObj.getBase(); | ||
| SmallVector<Value> smemBases(smemObj.getBases().begin(), |
There was a problem hiding this comment.
Nit: llvm::to_vector(smemObj.getBases())?
| // kPartition. Bail out if not, so a generic lowering can take over. | ||
| const unsigned numInstrRegBits = | ||
| llvm::Log2_32(fullTile.getInDimSize(kReg)); | ||
| for (unsigned pos = 0; pos < numInstrRegBits; ++pos) { |
There was a problem hiding this comment.
This is just checking partitionLayout.sublayoutIsZero({kReg}, {kPartition})?
There was a problem hiding this comment.
no, because this would check whole reg bases of partition layout, which would include repetitions. The point is that we want to check just first numInstrRegBits, which are number of register from fullTile layout, which is one instruction. It's fine for different repetitions (instructions) to be in different partitions, but we want to check if registers from single instruction are in different partition.
There was a problem hiding this comment.
For reference, if want it without looking at the bases, you can do that by reshaping kReg into two dimensions, one of dim numInstrRegBits and a different one and check the sublayoutIsZero there. But tbh I wouldn't rewrite it, the current solution seems alright.
There was a problem hiding this comment.
yeah makes sense. Thanks for the explanation.
154d54f to
7e89e33
Compare
Extends the AMD ds_read_tr* local-load lowering to accept PartitionedSharedEncodingAttr as the source encoding.
Previously the pattern bailed out as soon as it saw a partitioned shared encoding, forcing a slower generic
local-load lowering for all WMMA dot-operand loads from partitioned LDS buffers.