[AMD][gfx1250] Support warp usage hints in TDM copy #10056

Merged

antiagainst merged 50 commits into triton-lang:main from jungpark-mlir:tdm-special on May 13, 2026

Conversation

Contributor

@jungpark-mlir commented Apr 16, 2026

This PR adds an optional warp_used_hint attribute to AsyncTDMCopyGlobalToLocalOp that enables partial TDM copies: only the selected subset of warps performs useful TDM loads, while the rest get pred=0 in their descriptors (a hardware no-op: the instruction is still issued, but no data moves).

The attribute is an i32 bitmask: bit n selects warp n. It is a performance hint only; it does not change the logical copy or the data written to shared memory. For example, with num_warps=8, warp_used_hint = 0b00001111 means warps 0-3 perform the copy and warps 4-7 are predicated off. The verifier requires the active warps to follow a regular, axis-aligned bit pattern so the lowering can derive a power-of-two active warp count and reuse the existing LinearLayout/free-variable machinery for offset and predicate generation.
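
A minimal usage sketch based on the call syntax quoted later in this thread (the placeholder arguments desc/off0/off1/smem/bar and the exact Gluon signature are assumptions):

# Sketch only; argument names follow the examples in this thread and the
# exact gfx1250 tdm.async_load signature may differ.
from triton.experimental.gluon.language.amd.gfx1250 import tdm

# Default: hint omitted, all num_warps warps split the copy evenly.
tdm.async_load(desc, [off0, off1], smem, pred=1, mbarrier=bar)

# Partial copy: warps 0-3 issue the load; warps 4-7 still execute the
# TDM instruction but with descriptor pred = 0 (hardware no-op).
tdm.async_load(desc, [off0, off1], smem, pred=1, mbarrier=bar,
               warp_used_hint=0b00001111)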

During lowering, the tensor descriptor is first represented as a base TDM descriptor containing only tensor metadata (base pointer, shape, strides, padding). The final per-instruction hardware descriptor is completed later when lowering each async_tdm_copy, where op-local fields such as pred, the LDS address, the barrier, the destination layout/partitioning, and tile_dim* are known. warp_used_hint is not consulted for the base descriptor; it is used only when completing those per-instruction fields, especially tile_dim* and pred.

For a hint with K = popcount(warp_used_hint) active warps, fillTDMDescriptor re-encodes the per-warp tile dimensions as block / K, so the selected warps still cover the same user-visible block in one TDM instruction. This is useful when num_warps exceeds what is needed for the copy.
The PR includes an example, verifier tests, lowering lit tests, and Python coverage.

Add an optional warp_bases attribute to AsyncTDMCopyGlobalToLocalOp that
enables TDM warp specialization: only a subset of warps (activeWarps)
perform TDM copies while the remaining warps get pred=0 (hardware no-op).

- MLIR op definition: add warp_bases as OptionalAttr<DenseI64ArrayAttr>
- Verifier: validate power-of-two, contiguous prefix, greedy distribution
- Pybind: pass warp_bases from Python to MLIR
- Python API: add warp_bases param to async_load with validation
- Lowering: re-encode per-warp tile dims in TDM descriptor for
  activeWarps; emit layout_pred (warpId < activeWarps) ANDed with
  user pred
- Example: add gemm_tdm_specialized_pipelined_warp_pipelined_kernel
  with --4warp-tdm CLI flag
…or warp specialization

Replace the layout_pred approach (ANDing pred with warpId < activeWarps in
fillTDMDescriptor) with a conditional branch in emitTDMIntrinsic that skips
the entire TDM emission for inactive warps. This avoids computing dead
descriptor values and ensures tensorcnt is not incremented for inactive warps.

Also swap warp pipeline stage priorities in the specialized GEMM example
(compute stage gets higher priority).
…cation for warp specialization"

This reverts commit 17fffd9.
Add verifier negative tests (wrong size, non-contiguous prefix, greedy
mismatch) and lowering tests (predication logic, partitioned layout
instruction count) for the warp_bases attribute.

Rename "warp specialization" to "partial TDM copy" in all TDM
warp_bases-related comments and docs to better describe the mechanism.
activeWarps=0 exclusively means "warp_bases absent, all warps active."
When warp_bases is present, activeWarps is at least 1 (2^0 for all-zero
rows). This distinction matters for understanding the conditional logic
in fillTDMDescriptor and emitTDMLoadStore.
Comment on lines +103 to +104
# Partial TDM copy variant: only a subset of warps issue TDM copies.
# Duplicate warps get pred=0 (hardware no-op), freeing TDM bandwidth.
Collaborator

Just passing by as I haven't looked at the PR in detail:
I don't think we want to expose the concept of warps at the Gluon level. I think there is a layering problem.

Contributor Author

@jungpark-mlir commented Apr 16, 2026

Yes, that was the main concept we tried hard not to break.
This change does not expose anything about warps, except for warp_bases, which is an element of the LinearLayout.
It only adds the ability to declare that the regions pointed to by warps 0-3 and warps 4-7 overlap, for example.
When regions are duplicated, the tdm_copy for those warps is automatically disabled.

Does this still sound unacceptable?

Member

In my mental model this should not break block programming--all warps are still collectively programmed and they go through uniform control flow paths.

As Jungwook pointed out, this just exposes control over which warp is responsible for which elements in the tensor--we already have such controls for threads, warps, etc. in blocked and linear layouts. It gives the ability to declare that two warps cover the same elements, so one warp can effectively mask off its corresponding TDM load given the duplicated load. (The masking is achieved with the predicate, so that wave still sees and executes the TDM instruction per se; this is not warp specialization.)

Collaborator

OK, maybe I need to spend more time understanding this. The TDM copy goes from global to shared memory, so there shouldn't be any warp concept in the linear layout?

Member

For AMD, TDM is a warp-level instruction. A single Gluon gfx1250.tdm.async_load op is under the hood performed by all the warps collectively, each warp taking a slice of the tensor. Right now we just have some heuristics to deduce the warp distribution. So warps are involved there; it's just implicit right now. This makes it explicit and controllable. No threads, though. Even for NVIDIA I think we have warps involved, distributing work to different warps?

Collaborator

But this is a leaking abstraction: the way those ops are distributed across warps is not meant to be exposed at the language level; the ops should be tile-level ops.

return tensor_descriptor(handle, shape, strides, type)


def _validate_warp_bases(warp_bases, block_shape, num_warps):
Member

We should define this validation as a static method in the C++ op definition and expose it to Python via a binding, so that we can share the same logic between the C++ op verifier and Python.

: std::nullopt,
numDims);

// When partial TDM copy is active, the per-warp block shape differs from
Member

Hmm, I'm not sure this is the natural way to implement it. If this is given by the developer, we should be able to replace the logic in distributeTDMWarps with the provided bases, and then rely on free variables to handle masking etc., like gather/scatter?

- warp_bases now directly defines the TDM LinearLayout's "warp"
  sublayout; warpsPerCTA is derived from it.
- Redundant-warp predication uses getFreeVariableMasks()["warp"],
  matching gather/scatter.
- Structural validation is a static validateWarpBases method exposed
  via pybind, shared by the Gluon front-end and MLIR verifier.
- Partitioned encodings with warp_bases must fit in a single TDM
  instruction (verifier-enforced).
# Conflicts:
#	third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp
#	third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp
#	third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h
@jungpark-mlir
Contributor Author

The previous implementation was also muddling the concept; now warp_bases is directly used for the tensor's linear layout, and I hope everything makes better sense.

@jungpark-mlir jungpark-mlir marked this pull request as ready for review April 21, 2026 15:53
@jungpark-mlir jungpark-mlir changed the title [WIP][AMD][TDM] Add partial TDM copy support via warp_bases [AMD][TDM] Add partial TDM copy support via warp_bases Apr 21, 2026
Comment on lines +178 to +180
warp_bases (List[List[int]], optional): Per-bit warp-to-offset mapping for partial TDM copy.
Each entry maps one bit of warpId to an element offset in the tensor coordinate space.
A zero basis means that bit contributes no offset (duplicate warp, gets pred=0).
Collaborator

I think we need a better description. A few things I think could help:

  • Make it clear it is a performance hint and doesn't affect the semantics of the op (what gets written in the buffer is unchanged).
  • Explain that it is used to decide which warps will participate in issuing the op.
  • "A zero basis means that bit contributes no offset (duplicate warp, gets pred=0)." This sentence doesn't make sense to me. "Duplicate warp" doesn't mean anything here. It seems to be referencing the linear layout, but this is more confusing than helpful, since from the language point of view this is completely unrelated to the layout of this op.

ArrayRef<unsigned> warpsPerCTA,
const LinearLayout &cgaLayout);
const LinearLayout &cgaLayout,
ArrayRef<int64_t> warpBases = {});
Collaborator

Why are the bases int64? I don't think the name "bases" is very clear here. Each element is either 0 or 1, right? Is this supposed to be in log base?

Comment on lines +178 to +179
warp_bases (List[List[int]], optional): Per-bit warp-to-offset mapping for partial TDM copy.
Each entry maps one bit of warpId to an element offset in the tensor coordinate space.
Collaborator

Do we really need this kind of granularity? Can this just be a hint to decide which warps will emit the message?

Comment on lines +754 to +776
`warp_bases` is an optional attribute for partial TDM copy. It
directly defines the `"warp"` sublayout of the TDM LinearLayout:
each row maps one bit of warpId to an element offset in tensor
coordinates. A `[0, ..., 0]` basis means that bit of warpId
contributes no offset, so that warp is a redundant duplicate;
`getFreeVariableMasks()["warp"]` identifies those bits and the
lowering masks them with `(warpId & freeMask) == 0` so duplicate
warps issue a no-op TDM copy. The attribute is stored as a
flattened row-major array of shape `(log2(num_warps), ndim)`.

Structural constraints (enforced by the verifier, and equivalent
to "warp_bases describes a valid LinearLayout that tiles
block_shape with identical rectangular per-warp tiles"):
each row has at most one non-zero dim; along each dim the
non-zero entries must be `m_d, 2*m_d, 4*m_d, ...` in some order,
where `m_d = block_shape[d] / 2^k_d` and `k_d` is the number of
non-zero entries along dim `d` (so `2^k_d` divides `block_shape[d]`).

Additionally, with a PartitionedSharedEncoding the verifier requires
warpsPerCTA[partitionDim] (= 2^k along that dim) to be at least
numLogicalPieces so the copy fits in a single TDM instruction; the
multi-instruction slicing path does not support warp_bases because
the bases are stated in full-block coordinates.
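
To make the structural constraint concrete, here is a minimal Python restatement of it (a hypothetical checker for illustration, not the actual C++ verifier):

# Hypothetical checker mirroring the constraint described above; not the
# verifier implementation itself.
def check_warp_bases(warp_bases, block_shape):
    # Each warpId bit may offset along at most one dimension.
    if any(sum(v != 0 for v in row) > 1 for row in warp_bases):
        return False
    for d in range(len(block_shape)):
        nonzero = sorted(row[d] for row in warp_bases if row[d] != 0)
        k = len(nonzero)                  # k_d
        if block_shape[d] % (1 << k):
            return False                  # 2^k_d must divide block_shape[d]
        m = block_shape[d] >> k           # m_d, per-warp tile extent along d
        if nonzero != [m << i for i in range(k)]:
            return False                  # entries must be m_d, 2*m_d, 4*m_d, ...
    return True

# block_shape (64, 64), num_warps = 8: bits 0 and 1 tile the two dims;
# bit 2 is an all-zero row, so warps 4-7 duplicate warps 0-3 (pred = 0).
assert check_warp_bases([[32, 0], [0, 32], [0, 0]], [64, 64])
assert not check_warp_bases([[16, 0], [0, 32], [0, 0]], [64, 64])
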
Collaborator

I think I'm starting to understand thanks to discussion with @antiagainst, but this seems overcomplicated for what I understand is needed.
Can we just decide which warps will participate rather than representing this as such fine-grained control?
If we instead just add a bool list telling which warps should be used, wouldn't that solve the problem?

@jungpark-mlir
Contributor Author

Can we just decide on what warps will participate rather than representing this as such a fine grain control?
If we instead just add a bool list telling that some warp should be used wouldn't that solve the problem

Yeah, that's a fair point @ThomasRaoux. I've discussed with @antiagainst, and here's a revised design that only uses a bool list. Please note, the new design extends the idea to support loading from different source tensors with a single merged copy op.

Partial/combined TDM copy via per-warp pred

Core idea

The only IR/API change is allowing pred to be either a scalar (as today) or a tensor<num_warps, i1> (one bit per warp). Partial copy and multi-destination copies all fall out of that plus an optimizer pass that merges compatible copy ops into a single TDM instruction.

No warp_bases.

pred semantics

  • Scalar i32 (default, unchanged). Uniform predicate across all warps. Backward-compatible with every existing kernel; one-line ops like tdm.async_load(d, [o0,o1], s, pred=1) stay identical.
  • tensor<num_warps, i1>. Bit i is warp i's predicate. Warp with bit = 0 issues a TDM instruction that is a hardware no-op (descriptor's pred field = 0). Scalar and tensor forms are interchangeable wherever the user wants one or the other.

1) partial copy

# First 4 of 8 warps do the load; others are HW no-ops.
mask = ttgl.arange(0, num_warps, layout=L) < 4       # 11110000
tdm.async_load(d, [o0, o1], s, pred=mask, mbarrier=b)

One op, one TDM intrinsic. Inactive warps predicate off identically to today's warp_bases case.

2) merging multiple copies

Two async_load ops are compatible when:

  • Their pred masks are disjoint: (mask_a & mask_b) == 0 (no warp in both).
  • They share mbarrier (or both have none).
  • No async_wait/intervening TDM op between them.
  • LDS targets don't overlap.

Compatible ops fuse into one TDM intrinsic via per-wave descriptor select:

desc_a = build_descriptor(src_a, offsets_a, dest_a, pred = mask_a[warpId])
desc_b = build_descriptor(src_b, offsets_b, dest_b, pred = mask_b[warpId])
desc   = select(mask_a[warpId], desc_a, desc_b)      # per-dword SGPR select
tdm_copy desc

src, offsets, and dest don't need to match; they get select-fused along with pred. Same instruction count as today's warp_bases-based partial copy; strictly better than issuing two predicated intrinsics.

How popcount(pred) shapes the copy

For a tensor<num_warps, i1> pred, the number of set bits K = popcount(pred) is what determines tile partitioning:

  • K must be a power of 2 and ≤ num_warps.
  • warpsPerCTA = greedy(src.block_shape, K): exactly the default distribution, but over K warps instead of num_warps. Each active warp loads block_shape / warpsPerCTA elements.
  • The num_warps - K unset bits map to the layout's free variables over kWarp, so inactive warps get descriptor pred = 0 (HW no-op) automatically; no extra masking code.

Scalar pred is equivalent to K = num_warps (full distribution), preserving today's behavior exactly.

After merging N compatible ops into one intrinsic, popcount is applied per contributing op to set each sub-descriptor's warpsPerCTA; the merged intrinsic then select-fuses those per-op descriptors by warpId. So the count-of-ones rule is local to each logical copy, not the merged whole.
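
A small sketch of the popcount-driven distribution described above (the greedy split heuristic here is an illustrative assumption, not the compiler's exact implementation):

# Illustrative greedy power-of-two distribution over the K active warps.
def distribute(block_shape, K):
    assert K & (K - 1) == 0, "K must be a power of two"
    warps_per_cta = [1] * len(block_shape)
    while K > 1:
        # Split the dimension with the largest remaining per-warp extent.
        d = max(range(len(block_shape)),
                key=lambda i: block_shape[i] // warps_per_cta[i])
        warps_per_cta[d] *= 2
        K //= 2
    return warps_per_cta

mask = 0b00001111                                  # 4 of 8 warps active
K = bin(mask).count("1")                           # popcount(pred)
wpc = distribute((128, 64), K)                     # [4, 1]
tile = [b // w for b, w in zip((128, 64), wpc)]    # each active warp: [32, 64]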

@ThomasRaoux
Collaborator

(quoting the revised per-warp pred design above in full)

Thanks for iterating on this, I think this is getting better. A few questions:

  1. tensor<num_warps, i1>: the bool list would have to be static though? How would you figure out how to codegen if it is just a value? The semantics of such a tensor also seem a little unclear to me.
  2. About this:
desc_a = build_descriptor(src_a, offsets_a, dest_a, pred = mask_a[warpId])
desc_b = build_descriptor(src_b, offsets_b, dest_b, pred = mask_b[warpId])
desc   = select(mask_a[warpId], desc_a, desc_b)      # per-dword SGPR select
tdm_copy desc

Where would this happen? Exposing warpId in TTGIR is not a good direction IMO.

  3. About this:

For a tensor<num_warps, i1> pred, the number of set bits K = popcount(pred) is what determines tile partitioning:

  • K must be a power of 2 and ≤ num_warps.

I agree this is a good restriction.

So what I thought could work and would be simpler is just doing (I haven't thought about good naming):

tdm.async_load(d, [o0,o1], s, pred=1, warp_used_hint=0x5)

Based on that, codegen computes the number of warps to use for the copy by doing popcount(warp_used_hint); then the warpId used in codegen can just be a remapping of the physical warpId to the one you want to use.

Being able to do the optimization discussed in 2 would be good, but it seems hard to represent in TTGIR; I wonder if this can just be a pattern match after lowering to LLVM.
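
A sketch of the remapping Thomas describes (hypothetical helper code; the real lowering works on TDM descriptor fields):

# popcount(warp_used_hint) gives the logical warp count; each active
# physical warp is renumbered densely for offset computation.
num_warps, hint = 8, 0x5                 # warps 0 and 2 active
active = [w for w in range(num_warps) if (hint >> w) & 1]
K = len(active)                          # popcount(hint) == 2
logical_id = {phys: i for i, phys in enumerate(active)}
# Physical warp 2 computes its tile offsets as logical warp 1 of K = 2;
# warps not in the map get descriptor pred = 0 (hardware no-op).
print(K, logical_id)                     # 2 {0: 0, 2: 1}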

@jungpark-mlir
Contributor Author

Thanks for the comments,

  1. The HW could work with a dynamic value, but that's a fair question whether we actually need it. Keeping pred as a dynamic scalar together with a static warp_used_hint sounds good to me; any thoughts @antiagainst?
  2. Sorry, I mixed up what happens during the lowering to LLVM; none of the warpId or select logic happens at the TTGIR level. I'm trying to figure out just one more thing to determine whether we can hide everything from the user: I need to confirm we can correctly calculate wait_count after the merge.

@ThomasRaoux
Collaborator

  1. HW could work with dynamic value but that's fair question that we actually need it, keep pred as dynamic scalar together with static warp_used_hint sounds good to me, any idea @antiagainst?

You do need to know the number of warps statically though, so a fully dynamic value is not possible?

@jungpark-mlir
Contributor Author

you do need to know statically the number of warps though, so fully dynamic value is not possible?

That's right! I only thought about what the instruction can do, but we also need to figure out the layout.

@jungpark-mlir
Contributor Author

Now I believe everything is clear; I'll come back with the implementation.

@jungpark-mlir jungpark-mlir changed the title [AMD][TDM] Add partial TDM copy support via warp_bases [AMD][TDM] Add support for partial/merged TDM copy Apr 24, 2026
Member

@antiagainst left a comment

Much nicer; thanks Jungwook for revising the design and impl!

Comment thread third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td Outdated
Comment thread python/triton/experimental/gluon/language/amd/gfx1250/tdm.py
Comment thread third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp Outdated
Comment thread third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp
Comment thread third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp Outdated
Comment thread third_party/amd/python/test/test_tdm_copy.py Outdated
Comment thread third_party/amd/python/test/test_tdm_copy.py Outdated
Comment thread third_party/amd/python/test/test_tdm_copy.py


@gluon.jit
def gemm_tdm_specialized_pipelined_warp_pipelined_kernel(a_ptr, b_ptr, c_ptr, #
Member

I'd call it tdm_predicated to avoid being confusing as this is different from warp specialization.

Comment thread third_party/amd/python/examples/gluon/f16_gemm_warp_pipeline_gfx1250.py Outdated
@jungpark-mlir jungpark-mlir marked this pull request as ready for review May 3, 2026 20:09
@jungpark-mlir
Contributor Author

Now this is ready for review, @ThomasRaoux
I think we found one thing that might have made it confusing previously: our descriptor construction is split into two stages.

Tensor-desc lowering only builds a base TDM descriptor from tensor metadata: base pointer, shape, strides, padding, etc. The final hardware descriptor is completed later during each async_tdm_copy lowering, where op-local fields like pred, LDS address, barrier, destination layout/partitioning, and tile_dim* are known.

warp_used_hint is ignored by the base descriptor and used only when completing the per-instruction hardware descriptor fields, especially tile_dim* and pred.

I updated the related comments to make this distinction clearer.

Member

@antiagainst left a comment

Thanks @jungpark-mlir for iterating on it! This looks good to me now. Please wait for Thomas to take another look.

@antiagainst antiagainst changed the title [AMD][TDM] Add support for partial TDM copy [AMD][gfx1250] Support warp usage hints in TDM copy May 4, 2026
@antiagainst
Member

@jungpark-mlir please also update the pull request message to reflect latest design and impl.

Collaborator

@ThomasRaoux left a comment

That mostly looks good to me, but I think it would be great if we could simplify a bit the conditions we support for this hint; it's really hard to wrap my head around.

Comment on lines +620 to +639
// Axis-aligned check. Anchor at i0 = lsb(hint) and OR the shifted
// warp indices: `support` is the bits that vary across the active
// set. Legal iff popcount(support) == log2(K) -- pigeonhole forces
// the K shifted indices to hit every subset of `support`, i.e. the
// active set is selectable by a single mask check.
unsigned i0 = llvm::countr_zero(hint);
uint32_t support = 0;
for (uint32_t mask = hint; mask != 0; mask &= mask - 1) {
  unsigned w = llvm::countr_zero(mask);
  support |= static_cast<uint32_t>(w ^ i0);
}
unsigned logK = llvm::Log2_32(K);
unsigned spanned = static_cast<unsigned>(llvm::popcount(support));
if (spanned != logK)
  return op.emitOpError("warp_used_hint = ")
         << llvm::formatv("{0:x}", hint) << " is not axis-aligned: K = " << K
         << " active warps span " << spanned
         << " warpId bit positions, but an axis-aligned hint "
         << "spans exactly log2(K) = " << logK;

Collaborator

The set of restrictions seems very hard to follow; is there any way to make this simpler or explain it better? This seems super specific to a set of kernels you have in mind, and it will surely be really hard to understand for someone who doesn't know both the compiler and the low-level ISA (and tbh I read this a bunch of times and can't understand at all what it means). Maybe it's the terms that are confusing?

Contributor Author

@ThomasRaoux Right, that's a fair point; I stated the rule mechanically and agree it's confusing.
The main reason could be that I tried to describe the rule in terms of individual warpId bits, as the verifier does, while users mainly think in terms of the shape of the hint mask. Here's a new version of the description that explains the allowed bit patterns. Please take a look; if this makes better sense, I'll revise the comments in the code based on it.

The hint must select a regular pattern of 1 bits (active warps) and 0 bits
(inactive warps). A valid mask is built from power-of-two-sized groups of 1s
placed within a power-of-two-sized block at an aligned offset. The same 1/0
pattern may appear once or repeat uniformly across the mask.
Currently, warp_used_hint is not supported for a non-power-of-two numWarps such as 12.

For example, with num_warps = 8, valid masks include:

  • 0b00001111: one contiguous group of four 1s.
  • 0b11110000: one contiguous group of four 1s at a different offset.
  • 0b01010101: single-1 groups repeated every 2 bits.
  • 0b10101010: the same single-1 group pattern at a different offset.
  • 0b00110011: groups of two 1s repeated every 4 bits.

Irregular masks that hand-pick unrelated 1 bits, such as 0b01101001,
0b00011011, or 0b01111000, are rejected. In all valid cases,
K = popcount(warp_used_hint) is a power of two.

For num_warps > 8, the same rule applies, and patterns may also be nested:
a smaller valid pattern can itself be repeated at a larger power-of-two stride.
For example, with num_warps = 16:

  • 0x0505 = 0000_0101_0000_0101
    inner: single-1 groups every 2 bits inside an 8-bit half.
    outer: that 8-bit pattern repeated every 16 bits.
    selects warps {0, 2, 8, 10}.
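
For reference, the same rule as a short Python check mirroring the C++ verifier snippet quoted above (a sketch; the real verifier also emits diagnostics):

def is_axis_aligned(hint: int) -> bool:
    K = bin(hint).count("1")
    if hint == 0 or K & (K - 1):
        return False                      # K must be a power of two
    i0 = (hint & -hint).bit_length() - 1  # anchor at the lowest active warp
    support, m = 0, hint
    while m:
        w = (m & -m).bit_length() - 1
        support |= w ^ i0                 # bits that vary across the active set
        m &= m - 1
    # Legal iff the K shifted indices span exactly log2(K) bit positions.
    return bin(support).count("1") == K.bit_length() - 1

for ok in (0b00001111, 0b11110000, 0b01010101, 0b00110011, 0x0505):
    assert is_axis_aligned(ok)
for bad in (0b01101001, 0b00011011, 0b01111000):
    assert not is_axis_aligned(bad)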

Contributor

@lezcano commented May 6, 2026

Perhaps you want to describe it as a linear vector subspace. These are classified by masks 0 <= mask < num_warps. The ones in a mask represent the bits that may vary within the subspace. For example, the subspace spanned by 0b101 with 8 warps is {0, 1, 4, 5}.

All your examples above for 8 warps that have a 1 at the LSB are linear subspaces, so I gather that, if you just care about the linear ones, this would be a good representation.

Would these be enough? (The num_warps > 8 ones don't fit this pattern, but yeah.)

If you also care about the affine subspaces, these are represented by two masks m, a such that m & a == 0. For example, the mask 0b101 admits a = 0x0 and a = 0x2, the second one representing the affine subspace {2, 3, 6, 7}.
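
A tiny sketch of this framing (illustrative only): m marks the warpId bits that may vary and a is an anchor with m & a == 0.

def warps(m, a, num_warps=8):
    assert m & a == 0
    return sorted(a | s for s in range(num_warps) if s & ~m == 0)

print(warps(0b101, 0b000))   # [0, 1, 4, 5]  (linear subspace)
print(warps(0b101, 0b010))   # [2, 3, 6, 7]  (affine coset)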

Contributor Author

Agreed, the current rule restricts the pattern to linear-layout form. Once we confirm the user API, I'll revise the verifier documentation in this way, i.e., as a mask on the warpId: for example, 0b101 means fix the second bit to zero and let all other bits vary to get the allowed warpIds--or even simpler for the verifier.

WARP_BASES = [(0, 1), (1, 0), (2, 0)]

# 4-warp partial TDM copy: warps 0-3 issue, warps 4-7 are no-ops.
tdm_warp_used_hint = 0b00001111
Contributor

For my understanding, how is a user expected to choose this mask? What hardware factors influence this, and how useful is it to tune with different masks?

Contributor Author

The actual HW instruction for a TDM copy uses a per-warp descriptor. Today we partition the tensor evenly across all warps in the workgroup, but we want to enable other configurations--for example, a partial copy (only 4 of 8 warps participate) and split copies where half the warps load from one tensor and the other half from another.

One use case is warp-pipelining, where the upper-half and lower-half warp groups run in a pipelined fashion. In that setup we want only the upper half (the leading group) to issue the copy in the current stage, instead of having the lower half issue a copy that belongs to a later stage.

The follow-up PR will combine multiple predicated copies into a single instruction when possible. The combined copy can carry a different descriptor per warp, for example loading A and B in one instruction.

The mask is determined by the use case rather than chosen for performance. For a plain TDM copy, the attribute should be omitted so that all warps participate. In the long run, end users should not need to set this hint manually. Ideally a future compiler pass would derive it when beneficial.

Contributor

I'm leaning towards this not being exposed in Gluon. If the compiler is able to merge the loads of A and B, then it could also evenly divide the merged load across warps. I don't see the point of the user deciding the exact warp mapping unless there is some performance reason to.

For the warp-pipelined case, it sounds like there's only one correct way to handle it, and the compiler could set the mask from the warp-pipelining pass. I'm also curious: if only one warp group is able to handle the load, then how are load and async_copy handled? Do they have layouts that only use one warp group?

Contributor Author

What this PR does is

  1. the user defines the warp mapping via predication,
  2. the compiler derives the number of warps actively participating in the copy,
  3. the compiler rewrites the layout so the original layout is performed by the active warps only.

Technically, the user doesn't need to know the optimal mapping, and the compiler can figure that out.
But I'm still not sure we know the optimal number of warps for any given TDM copy use case.

Another option might be to still use an explicit hint, but have it specify just the number of warps that participate in the copy instead of asking the user to specify exactly which warps perform it. What do you think?

Contributor

But I'm still not sure we know the optimal number of warps for any given tdm copy use case.

What are the factors that make fewer warps better, assuming you're doing a single load that can't be merged with anything else and isn't warp-pipelined? Or is it more that you want the user to be able to control whether loads are fused into a single instruction?

I understand what this PR does mechanically, but I want to understand how a user is expected to use it. Currently it sounds like it is more of an implementation detail.

Also you never answered about async_copy in warp-pipelined kernels.

I'm also curious if only one warp group is able to handle the load, then how are load and async_copy handled? Do they have layouts that only use 1 warp group?

Contributor Author

Assuming you're doing a single load, that can't be merged with anything else and isn't warp-pipelined.

One case: even when we don't explicitly use warp-pipelining, with more than one warp per SIMD (e.g. numWarps=8) the execution of the two warps is serialized on the SIMD. We might expect the same benefit as in the warp-pipelining case from performing the TDM copy on the first warp of each SIMD.

Contributor Author

I understand what this PR does mechanically, but want to understand how a user is expected to use this. Currently it sounds like it is more of an implementation detail.

Yeah, that's fair. I'm open to revising the user interface; let's discuss in more detail.

Contributor Author

Or is it more that you want the user to be able to control if loads are fused into a single instruction?

And yes, this is currently the main reason.

Contributor

Sorry, I meant async_copy.global_load_to_shared, i.e. a non-TDM async copy. In that case you have to pass a tensor of pointers, so the layout of that tensor should control where the copy is issued from. How does that work in a warp-pipelined kernel?

But yeah, for TDM, if there are more warps than SIMD units we could clamp the number of warps used to the number of SIMD units in the lowering. If that's always the best thing to do, then there's no need to give the user control over it.

I wonder: is the TDM unit per-warp, so the code must be careful to balance the load between warps? That could be a good justification for hints, as only the user knows if, say, one load has structurally higher latency.

Contributor Author

Oh, OK. I hadn't really considered doing the same with async_copy.
I'll come back with a better answer for the last question.

Comment thread third_party/amd/python/examples/gluon/f16_gemm_warp_pipeline_gfx1250.py Outdated
# Conflicts:
#	python/triton/experimental/gluon/language/amd/gfx1250/tdm.py
#	third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp
#	third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h
#	third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp
Collaborator

@ThomasRaoux left a comment

A couple more nits; LG otherwise.

Comment thread third_party/amd/python/test/test_tdm_copy.py Outdated
)

expected = a_cpu + b_cpu
torch.testing.assert_close(c.cpu(), expected, atol=1e-3, rtol=1e-3)
Collaborator

nit: it's a bit weird to have precision tolerances there; maybe use integers for a more robust test

Contributor Author

Right, moved to int

@antiagainst merged commit 11459af into triton-lang:main on May 13, 2026
16 of 18 checks passed