[AMD][gfx1250] Support warp usage hints in TDM copy #10056
Conversation
Add an optional warp_bases attribute to AsyncTDMCopyGlobalToLocalOp that enables TDM warp specialization: only a subset of warps (activeWarps) perform TDM copies while the remaining warps get pred=0 (hardware no-op).
- MLIR op definition: add warp_bases as OptionalAttr<DenseI64ArrayAttr>
- Verifier: validate power-of-two, contiguous prefix, greedy distribution
- Pybind: pass warp_bases from Python to MLIR
- Python API: add warp_bases param to async_load with validation
- Lowering: re-encode per-warp tile dims in TDM descriptor for activeWarps; emit layout_pred (warpId < activeWarps) ANDed with user pred
- Example: add gemm_tdm_specialized_pipelined_warp_pipelined_kernel with --4warp-tdm CLI flag
…or warp specialization

Replace the layout_pred approach (ANDing pred with warpId < activeWarps in fillTDMDescriptor) with a conditional branch in emitTDMIntrinsic that skips the entire TDM emission for inactive warps. This avoids computing dead descriptor values and ensures tensorcnt is not incremented for inactive warps. Also swap warp pipeline stage priorities in the specialized GEMM example (the compute stage gets higher priority).
…cation for warp specialization" This reverts commit 17fffd9.
Add verifier negative tests (wrong size, non-contiguous prefix, greedy mismatch) and lowering tests (predication logic, partitioned layout instruction count) for the warp_bases attribute. Rename "warp specialization" to "partial TDM copy" in all TDM warp_bases-related comments and docs to better describe the mechanism.
activeWarps=0 exclusively means "warp_bases absent, all warps active." When warp_bases is present, activeWarps is at least 1 (2^0 for all-zero rows). This distinction matters for understanding the conditional logic in fillTDMDescriptor and emitTDMLoadStore.
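The sentinel convention above can be sketched in Python (the helper name `active_warps_count` is illustrative, not from the PR):

```python
def active_warps_count(warp_bases):
    # Sentinel convention (illustrative helper):
    # None -> 0 means "warp_bases absent, all warps active".
    # Present -> 2^(number of non-zero rows), which is at least 1
    # (2^0) even when every row is all zeros.
    if warp_bases is None:
        return 0
    nonzero_rows = sum(any(v != 0 for v in row) for row in warp_bases)
    return 1 << nonzero_rows
```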
# Partial TDM copy variant: only a subset of warps issue TDM copies.
# Duplicate warps get pred=0 (hardware no-op), freeing TDM bandwidth.
Just passing by, as I haven't looked at the PR in detail:
I don't think we want to expose the concept of warps at the Gluon level. I think there is a layering problem.
Yes, that was the main concept we tried hard not to break.
This change does not expose anything about warps, except for warp_bases, which is an element of the LinearLayout.
It only adds the ability to declare that the regions pointed to by warp0–3 and warp4–7 overlap, for example.
When regions are duplicated, the tdm_copy for those warps is automatically disabled.
Does this still sound unacceptable?
In my mental model this should not break the block programming--all warps are still collectively programmed and they go through uniform control flow paths.
As Jungwook pointed out, this is just exposing control of which warp is responsible for which elements in the tensor; we have such controls for threads, warps, etc. in blocked layouts, linear layouts, and so on. It gives the ability to declare that two warps cover the same elements, so one warp can effectively mask off its corresponding TDM load given the duplicated load. (The masking off is achieved by using a predicate, so that wave still sees and executes the TDM instruction per se; it's not like warp specialization.)
OK, maybe I need to spend more time understanding this. The TDM copy is from global to shared memory, so there shouldn't be any warp concept in the linear layout?
For AMD, TDM is a warp-level instruction. A single Gluon gfx1250.tdm.async_load op is under the hood done by all the warps collectively, each warp taking a slice of the tensor. Right now we just have some heuristics to deduce the warp distribution. So warps are involved there; it's just implicit right now. This is making it explicit and controllable. No threads though. Even for NVIDIA I think we have warps involved and distribute to different warps?
But this is a leaking abstraction; the way those ops are distributed across warps is not meant to be exposed at the language level. The ops should be tile-level ops.
return tensor_descriptor(handle, shape, strides, type)

def _validate_warp_bases(warp_bases, block_shape, num_warps):
We should define this validation as a static method in the C++ op definition and expose it to Python via binding so that we can share the same logic in C++ op verifier and Python.
    : std::nullopt,
    numDims);

// When partial TDM copy is active, the per-warp block shape differs from
Hmm. I'm not sure this is the natural way to implement it. If this is given by the developer, we should be able to replace the logic in distributeTDMWarps with what's provided, and then rely on free variables to handle masking etc., like gather/scatter?
- warp_bases now directly defines the TDM LinearLayout's "warp" sublayout; warpsPerCTA is derived from it.
- Redundant-warp predication uses getFreeVariableMasks()["warp"], matching gather/scatter.
- Structural validation is a static validateWarpBases method exposed via pybind, shared by the Gluon front-end and MLIR verifier.
- Partitioned encodings with warp_bases must fit in a single TDM instruction (verifier-enforced).
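The free-variable idea in that summary can be illustrated with a small Python sketch. The helper names here are hypothetical; the real implementation relies on LinearLayout's getFreeVariableMasks():

```python
def warp_free_mask(warp_bases):
    # Hypothetical helper mirroring getFreeVariableMasks()["warp"]:
    # a warpId bit whose basis row is all zeros contributes no offset,
    # so flipping it yields a redundant duplicate warp.
    mask = 0
    for bit, row in enumerate(warp_bases):
        if all(v == 0 for v in row):
            mask |= 1 << bit
    return mask

def is_active(warp_id, free_mask):
    # Redundant-warp predication: a warp issues the copy only when all
    # of its free bits are zero, matching the gather/scatter approach.
    return (warp_id & free_mask) == 0
```

For example, if the two high warpId bits are all-zero rows, only warps 0 and 1 issue the copy.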
# Conflicts:
#	third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp
#	third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp
#	third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h
Previous implementation was also misleading the concept, now
warp_bases (List[List[int]], optional): Per-bit warp-to-offset mapping for partial TDM copy.
    Each entry maps one bit of warpId to an element offset in the tensor coordinate space.
    A zero basis means that bit contributes no offset (duplicate warp, gets pred=0).
I think we need a better description. A few things I think could help:
- Make it clear it is a performance hint and doesn't affect the semantics of the op (what gets written in the buffer is unchanged).
- Explain that it is used to decide which warps will participate in issuing the op.
- "A zero basis means that bit contributes no offset (duplicate warp, gets pred=0)": this sentence doesn't make sense to me. "Duplicate warp" doesn't mean anything here. It seems like it is referencing the linear layout, but this is more confusing than helpful, since from the language point of view this is completely unrelated to the layout of this op.
ArrayRef<unsigned> warpsPerCTA,
const LinearLayout &cgaLayout);
const LinearLayout &cgaLayout,
ArrayRef<int64_t> warpBases = {});
why are bases int64? I don't think the name bases is very clear here. Each element is either 0 or 1 right? Is this supposed to be in log base?
warp_bases (List[List[int]], optional): Per-bit warp-to-offset mapping for partial TDM copy.
    Each entry maps one bit of warpId to an element offset in the tensor coordinate space.
do we really need this kind of granularity? Can this just be a hint to decide which warp will emit the message?
`warp_bases` is an optional attribute for partial TDM copy. It
directly defines the `"warp"` sublayout of the TDM LinearLayout:
each row maps one bit of warpId to an element offset in tensor
coordinates. A `[0, ..., 0]` basis means that bit of warpId
contributes no offset, so that warp is a redundant duplicate;
`getFreeVariableMasks()["warp"]` identifies those bits and the
lowering masks them with `(warpId & freeMask) == 0` so duplicate
warps issue a no-op TDM copy. The attribute is stored as a
flattened row-major array of shape `(log2(num_warps), ndim)`.

Structural constraints (enforced by the verifier, and equivalent
to "warp_bases describes a valid LinearLayout that tiles
block_shape with identical rectangular per-warp tiles"):
each row has at most one non-zero dim; along each dim the
non-zero entries must be `m_d, 2*m_d, 4*m_d, ...` in some order,
where `m_d = block_shape[d] / 2^k_d` and `k_d` is the number of
non-zero entries along dim `d` (so `2^k_d` divides `block_shape[d]`).

Additionally, with a PartitionedSharedEncoding the verifier requires
warpsPerCTA[partitionDim] (= 2^k along that dim) to be at least
numLogicalPieces so the copy fits in a single TDM instruction; the
multi-instruction slicing path does not support warp_bases because
the bases are stated in full-block coordinates.
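As a rough illustration, those structural constraints can be checked with a short Python sketch. The function name and signature are illustrative; the authoritative check is the C++ verifier:

```python
import math

def validate_warp_bases(warp_bases, block_shape, num_warps):
    """Sketch of the structural check described above."""
    ndim = len(block_shape)
    # One row per warpId bit.
    if len(warp_bases) != int(math.log2(num_warps)):
        return False
    per_dim = {d: [] for d in range(ndim)}
    for row in warp_bases:
        nonzero = [(d, v) for d, v in enumerate(row) if v != 0]
        if len(nonzero) > 1:  # each row has at most one non-zero dim
            return False
        if nonzero:
            d, v = nonzero[0]
            per_dim[d].append(v)
    for d, vals in per_dim.items():
        k = len(vals)  # k_d: non-zero entries along dim d
        if k == 0:
            continue
        if block_shape[d] % (1 << k) != 0:  # 2^k_d must divide block_shape[d]
            return False
        m = block_shape[d] // (1 << k)  # m_d: per-warp tile size along d
        # Entries along dim d must be m_d, 2*m_d, 4*m_d, ... in some order.
        if sorted(vals) != [m << i for i in range(k)]:
            return False
    return True
```

With block_shape = [4, 2] and num_warps = 8, the bases [[0, 1], [1, 0], [2, 0]] tile the block into 1x1 per-warp tiles and pass; replacing the last row with [3, 0] breaks the power-of-two ladder and is rejected.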
I think I'm starting to understand, thanks to discussion with @antiagainst, but this seems overcomplicated for what I understand is needed.
Can we just decide on which warps will participate rather than representing this as such fine-grained control?
If we instead just added a bool list telling which warps should be used, wouldn't that solve the problem?
Yeah, that's a fair point @ThomasRaoux. I've discussed with @antiagainst, and here's the revised design that only uses a bool list. Please note, the new design extends the idea to support loading from different source tensors with a single partial/combined TDM copy via per-warp pred.

Core idea

The only IR/API change is allowing

No
Thanks for iterating on this; I think this is getting better. A few questions:

Where would this happen? Exposing warpId in TTGIR is not a good direction IMO.

I agree this is a good restriction. So what I thought could work, and would be simpler, is just doing (I haven't thought about good naming): based on that, codegen computes the num_warp to use for the copy. Being able to do the optimization discussed in 2. would be good, but it seems hard to represent in TTGIR; I wonder if this can just be a pattern match after lowering to LLVM.
Thanks for the comments,
You do need to know the number of warps statically though, so a fully dynamic value is not possible?
That's right! I only thought about what the instruction can do, but we need to figure out the layout.
Now I believe everything is clear; I'll come back with the implementation.
antiagainst
left a comment
Much nicer; thanks Jungwook for revising the design and impl!
@gluon.jit
def gemm_tdm_specialized_pipelined_warp_pipelined_kernel(a_ptr, b_ptr, c_ptr, #
I'd call it tdm_predicated to avoid confusion, as this is different from warp specialization.
Now this is ready for review, @ThomasRaoux.

Tensor-desc lowering only builds a base TDM descriptor from tensor metadata: base pointer, shape, strides, padding, etc. The final hardware descriptor is completed later during each async_tdm_copy lowering, where op-local fields like pred, LDS address, barrier, destination layout/partitioning, and tile_dim* are known. warp_used_hint is ignored by the base descriptor and used only when completing the per-instruction hardware descriptor fields, especially tile_dim* and pred. I updated the related comments to make this distinction clearer.
antiagainst
left a comment
Thanks @jungpark-mlir for iterating on it! This looks good to me now. Please wait for Thomas to take another look.
@jungpark-mlir please also update the pull request message to reflect latest design and impl. |
ThomasRaoux
left a comment
That mostly looks good to me, but I think it would be great if we could simplify the conditions we support for this hint a bit; it's really hard to wrap my head around them.
// Axis-aligned check. Anchor at i0 = lsb(hint) and OR the shifted
// warp indices: `support` is the bits that vary across the active
// set. Legal iff popcount(support) == log2(K) -- pigeonhole forces
// the K shifted indices to hit every subset of `support`, i.e. the
// active set is selectable by a single mask check.
unsigned i0 = llvm::countr_zero(hint);
uint32_t support = 0;
for (uint32_t mask = hint; mask != 0; mask &= mask - 1) {
  unsigned w = llvm::countr_zero(mask);
  support |= static_cast<uint32_t>(w ^ i0);
}
unsigned logK = llvm::Log2_32(K);
unsigned spanned = static_cast<unsigned>(llvm::popcount(support));
if (spanned != logK)
  return op.emitOpError("warp_used_hint = ")
         << llvm::formatv("{0:x}", hint) << " is not axis-aligned: K = " << K
         << " active warps span " << spanned
         << " warpId bit positions, but an axis-aligned hint "
         << "spans exactly log2(K) = " << logK;
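The same check can be translated into a standalone Python sketch, handy for experimenting with candidate masks (the function name is illustrative):

```python
def is_axis_aligned(hint: int) -> bool:
    """Sketch of the verifier's axis-aligned check: XOR every active
    warp index with the lowest one and OR the results; the hint is
    legal iff the varying bits number exactly log2(K),
    where K = popcount(hint)."""
    k = bin(hint).count("1")
    if hint == 0 or k & (k - 1):  # K must be a power of two
        return False
    i0 = (hint & -hint).bit_length() - 1  # lowest set-bit position
    support = 0
    m = hint
    while m:
        w = (m & -m).bit_length() - 1  # next active warp index
        support |= w ^ i0
        m &= m - 1  # clear lowest set bit
    return bin(support).count("1") == k.bit_length() - 1
```

This accepts the regular masks from the discussion (0b00001111, 0b01010101, and the nested 16-warp mask 0x0505) and rejects hand-picked ones like 0b01101001.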
The set of restrictions seems very hard to follow. Is there any way to make this simpler or provide a better explanation? This seems super specific to a set of kernels you have in mind and will surely be really hard to understand for someone who doesn't understand both the compiler and the low-level ISA (and tbh I read this a bunch of times and can't understand at all what it means). Maybe it is that the terms are confusing?
@ThomasRaoux Right, that's a fair point. I've just mechanically treated the rule and agree it's confusing.
The main reason could be that I tried to describe the rule in terms of individual warpId bits, as the verifier does, but users mainly think in terms of the hint mask shape. Here's a new version of the description that explains the allowed bit patterns. Please take a look; I'll revise the comments in the code based on this if it makes better sense.
The hint must select a regular pattern of 1 bits (active warps) and 0 bits
(inactive warps). A valid mask is built from power-of-two-sized groups of 1s
placed within a power-of-two-sized block at an aligned offset. The same 1/0
pattern may appear once or repeat uniformly across the mask. Currently,
warp_used_hint is not supported for non-power-of-two numWarps like 12.

For example, with num_warps = 8, valid masks include:

- `0b00001111`: one contiguous group of four 1s.
- `0b11110000`: one contiguous group of four 1s at a different offset.
- `0b01010101`: single-1 groups repeated every 2 bits.
- `0b10101010`: the same single-1 group pattern at a different offset.
- `0b00110011`: groups of two 1s repeated every 4 bits.

Irregular masks that hand-pick unrelated 1 bits, such as 0b01101001,
0b00011011, or 0b01111000, are rejected. In all valid cases,
K = popcount(warp_used_hint) is a power of two.

For num_warps > 8, the same rule applies, and patterns may also be nested:
a smaller valid pattern can itself be repeated at a larger power-of-two stride.
For example, with num_warps = 16, `0x0505` = `0000_0101_0000_0101`:

- inner: single-1 groups every 2 bits inside an 8-bit half.
- outer: that 8-bit pattern repeated to fill the 16-bit mask.
- selects warps {0, 2, 8, 10}.
Perhaps you want to describe it as a linear vector subspace. These are classified by masks 0 <= mask < num_warps. The ones in a mask represent the bits that may vary to describe the subspace. For example, with 8 warps, the subspace spanned by 0b101 is warps 0, 1, 4, 5.
All your examples above for 8 warps that have a 1 at the LSB are linear subspaces, so I gather that, if you just care about the linear ones, this would be a good representation.
Would these be enough? (The num_warps > 8 ones don't fit this pattern, but yeah.)
If you also care about the affine subspaces, these are represented by two masks m, a such that m & a == 0. For example, the mask 0b101 admits a = 0x0 and a = 0x2, the second one representing the affine subspace 2, 3, 6, 7.
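A tiny Python sketch of the subspace enumeration described above, under the stated convention that set bits in the mask are the free bits and the offset fixes the rest:

```python
def subspace(mask, offset=0):
    # Enumerate the (affine) subspace of warp ids: each set bit of
    # `mask` may independently be 0 or 1; `offset` fixes the others.
    assert mask & offset == 0
    warps = {offset}
    for i in range(mask.bit_length()):
        if mask >> i & 1:
            warps |= {w | (1 << i) for w in warps}
    return sorted(warps)
```

With 8 warps, mask 0b101 spans the linear subspace {0, 1, 4, 5}, and shifting by a = 0b010 gives the affine subspace {2, 3, 6, 7}, matching the examples in the comment.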
Agreed, the current rule restricts the pattern under LL. Once we confirm the user API, I'll revise the documentation for the verifier in this way, i.e., a mask on the warpId; for example, 0b101 means fix the second bit as zero and allow all other bits to be free to get the allowed warpIds, or even simpler for the verifier.
WARP_BASES = [(0, 1), (1, 0), (2, 0)]

# 4-warp partial TDM copy: warps 0-3 issue, warps 4-7 are no-ops.
tdm_warp_used_hint = 0b00001111
For my understanding, how is a user expected to choose this mask? What hardware factors influence this, and how useful is it to tune this with different masks?
The actual HW instruction for a TDM copy uses a per-warp descriptor. Today we partition the tensor evenly across all warps in the workgroup, but we want to enable other configurations. Examples include a partial copy (only 4 of 8 warps participate) and split copies where half the warps load from one tensor and the other half from another.
One use case is warp-pipelining, where the upper-half and lower-half warp groups run in a pipelined fashion. In that setup we want only the upper half (the leading group) to issue the copy in the current stage, instead of having the lower half issue a copy that belongs to a later stage.
The follow-up PR will combine multiple predicated copies into a single instruction when possible. The combined copy can carry a different descriptor per warp, for example loading A and B in one instruction.
The mask is determined by the use case rather than chosen for performance. For a plain TDM copy, the attribute should be omitted so that all warps participate. In the long run, end users should not need to set this hint manually. Ideally a future compiler pass would derive it when beneficial.
I'm leaning towards this shouldn't be exposed in gluon. If the compiler is able to merge the loads of A and B, then it could also evenly divide the merged load across warps. I don't see the point of the user deciding the exact warp mapping, unless there is some performance reason to.
For the warp-pipelined case, it sounds like there's only one correct way to handle it and the compiler could set the mask from the warp-pipelining pass. I'm also curious if only one warp group is able to handle the load, then how are load and async_copy handled? Do they have layouts that only use 1 warp group?
What this PR does is
- user to define warp mapping by predication
- compiler to get the warp number actively participating the copy
- compiler to rewrite the layout from the original layout -> performed by active warps only
Technically, the user doesn't need to know the optimal mapping, and the compiler can figure that out.
But I'm still not sure we know the optimal number of warps for any given TDM copy use case.
Another option might be to still use the explicit hint but just specify the number of warps that participate in the copy, instead of asking the user to specify exactly which warps perform it. What do you think?
But I'm still not sure we know the optimal number of warps for any given tdm copy use case.
What are the factors that make fewer warps better? Assuming you're doing a single load, that can't be merged with anything else and isn't warp-pipelined. Or is it more that you want the user to be able to control if loads are fused into a single instruction?
I understand what this PR does mechanically, but want to understand how a user is expected to use this. Currently it sounds like it is more of an implementation detail.
Also you never answered about async_copy in warp-pipelined kernels.
I'm also curious if only one warp group is able to handle the load, then how are load and async_copy handled? Do they have layouts that only use 1 warp group?
Assuming you're doing a single load, that can't be merged with anything else and isn't warp-pipelined.
One case is: even though we don't explicitly use warp-pipelining, when we have more than 1 warp per SIMD (e.g. numWarps=8), execution of two warps is serialized on a SIMD. We might expect the same benefit as in the warp-pipeline case from performing the TDM copy on the first warp of each SIMD.
I understand what this PR does mechanically, but want to understand how a user is expected to use this. Currently it sounds like it is more of an implementation detail.
Yeah, that's fair. I'm open to revising the user interface; let's discuss in more detail.
Or is it more that you want the user to be able to control if loads are fused into a single instruction?
And yes, this is currently the main reason.
Sorry, I meant async_copy.global_load_to_shared so a non-tdm async copy. In that case you have to pass a tensor of pointers, so the layout of the tensor should control where the copy is issued from. How does that work in a warp-pipelined kernel?
But yeah, for TDM, if there are more warps than there are SIMD units, we could clamp the number of warps used to the SIMD unit count in the lowering. If that's always the best thing to do, then there's no need to give the user control over it.
I wonder: is the TDM unit per-warp, so the code must be careful to balance the load between warps? That could be a good justification for hints, as only the user knows if, say, one load has structurally higher latency.
Oh, OK. I hadn't really considered doing the same using async_copy.
I'll come back with a better answer for the last question.
# Conflicts:
#	python/triton/experimental/gluon/language/amd/gfx1250/tdm.py
#	third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp
#	third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h
#	third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp
ThomasRaoux
left a comment
A couple more nits, LG otherwise.
)

expected = a_cpu + b_cpu
torch.testing.assert_close(c.cpu(), expected, atol=1e-3, rtol=1e-3)
nit: it's a bit weird to have precision tolerance there, maybe use integers for a more robust test
Right, moved to int
This PR adds an optional `warp_used_hint` attribute to `AsyncTDMCopyGlobalToLocalOp` that enables partial TDM copy: only the selected subset of warps perform useful TDM loads while the rest get `pred=0` in their descriptor (hardware virtually no-op; the instruction is still issued but no data moved).

The attribute is an `i32` bitmask: bit `n` selects warp `n`. The hint is a performance hint only; it does not change the logical copy or the data written to shared memory. For example, with `num_warps=8`, `warp_used_hint = 0b00001111` means warps 0-3 perform the copy and warps 4-7 are predicated off. The verifier requires the active warps to follow a regular axis-aligned bit pattern so lowering can derive a power-of-two active warp count and reuse the existing LinearLayout/free-variable machinery for offset and predicate generation.

During lowering, the tensor descriptor is first represented as a base TDM descriptor containing tensor metadata (base pointer, shape, strides, padding). The final per-instruction hardware descriptor is completed later when lowering each `async_tdm_copy`, where op-local fields such as `pred`, LDS address, barrier, destination layout/partitioning, and `tile_dim*` are known. `warp_used_hint` is ignored by the base descriptor and used only when completing those per-instruction hardware descriptor fields, especially `tile_dim*` and `pred`.

For a hint with `K = popcount(warp_used_hint)` active warps, `fillTDMDescriptor` re-encodes per-warp tile dimensions as `block / K` so the selected warps still cover the same user-visible block in one TDM instruction. This is useful when `num_warps` exceeds what is needed for the copy.

The PR includes an example, verifier tests, lowering lit tests, and Python coverage.
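A minimal sketch of the `block / K` tile re-encoding, assuming for illustration that the split is applied along the leading dimension (the actual dimension choice in fillTDMDescriptor may differ; the helper name is hypothetical):

```python
def per_warp_tile(block_shape, warp_used_hint):
    # With K = popcount(hint) active warps, shrink the per-warp tile
    # by K so the active warps together still cover the whole block.
    # Assumption: the split happens along the leading dimension.
    k = bin(warp_used_hint).count("1")
    assert k and k & (k - 1) == 0, "active warp count must be a power of two"
    assert block_shape[0] % k == 0, "block must split evenly across warps"
    return [block_shape[0] // k] + list(block_shape[1:])
```

For example, a 128x64 block with `warp_used_hint = 0b00001111` (4 active warps) yields a 32x64 per-warp tile.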