[hipBLASLt] [TensileLite] Add tail loop support in subtile path for BF16 for all k sizes #7661
|
This still does not have any prevention for out of array access.
|
| comment="byteHi[ir=%d]: K_pos_hi >= LoopCounterL ?" % ir)) | ||
| for vIdx in aIdxs: | ||
| module.add(VAndB32( | ||
| dst=vgpr(hiClearVgpr), src0=hex(0xFFFF), src1=vgpr(vIdx), |
We can optimize this part.
Instead of applying v_cndmask_b32 to each VGPR, we can select the mask value (2-byte or 4-byte) once and AND it into every VGPR.
(current)
v_cmp_ge_i32 s[70:71], v76, s[sgprLoopCounterL] // byteHi[ir=0]: K_pos_hi >= LoopCounterL ?
v_and_b32 v77, 0xffff, v16 // ValuA[16] & 0xFFFF (hi16 -> 0)
v_cndmask_b32 v16, v16, v77, s[70:71] // zero hi16 ValuA[16] (odd-K boundary VGPR)
v_and_b32 v77, 0xffff, v24 // ValuA[24] & 0xFFFF (hi16 -> 0)
v_cndmask_b32 v24, v24, v77, s[70:71] // zero hi16 ValuA[24] (odd-K boundary VGPR)
....
(new) select mask (2Byte or 4Byte)
v_cmp_ge_i32 s[70:71], v76, s[sgprLoopCounterL] // byteHi[ir=0]: K_pos_hi >= LoopCounterL ?
v_cndmask_b32 v77, 0xffffffff, 0xffff, s[70:71] // mask = out of range ? 0xFFFF : 0xFFFFFFFF
v_and_b32 v16, v77, v16 // ValuA[16] & 0xFFFF if out of range
v_and_b32 v24, v77, v24 // ValuA[24] & 0xFFFF if out of range
...
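The suggestion above can be checked with a small host-side model. This is a minimal sketch, not TensileLite code: `cndmask` only mimics the `v_cndmask_b32` select semantics (pick `src1` when the condition is set), and the function names are hypothetical.

```python
FULL_MASK = 0xFFFFFFFF  # keep all 4 bytes
LOW16_MASK = 0xFFFF     # keep only the low bf16 element


def cndmask(src0, src1, cond):
    """Model of v_cndmask_b32: return src1 when cond is set, else src0."""
    return src1 if cond else src0


def zero_hi16_per_vgpr(vgprs, past_end):
    """Current form: one v_and_b32 plus one v_cndmask_b32 per VGPR."""
    out = []
    for value in vgprs:
        low_only = value & LOW16_MASK
        out.append(cndmask(value, low_only, past_end))
    return out


def zero_hi16_select_mask(vgprs, past_end):
    """Suggested form: select the mask once, then one v_and_b32 per VGPR."""
    mask = cndmask(FULL_MASK, LOW16_MASK, past_end)
    return [value & mask for value in vgprs]


values = [0x12345678, 0xDEADBEEF, 0x0000FFFF, 0xFFFF0000]
for past in (False, True):
    assert zero_hi16_per_vgpr(values, past) == zero_hi16_select_mask(values, past)
```

In this model, N boundary VGPRs cost 2N vector instructions after the compare in the current form, and N + 1 in the suggested form.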
|
We cannot generate MT320x288x64 with TailLoop. We need reg allocation optimizations. |
This seems to be redundant.
| return False | ||
| return True | ||
|
|
||
| def _emitTailSubLaneMaskRefineSubtile(self, kernel, kPosBaseVgpr, mmak, miK, |
Can we generalize this function to support MXFP4, MXFP8 (also fp8?), fp16, and bf16?
| comment="zero ValuB[%u] (per-VGPR byte refine)" % vIdx)) | ||
| self.vgprPool.checkIn(kPosCur) | ||
|
|
||
| if kernel["AssertSummationElementMultiple"] % 2 != 0: |
Please use 4 // bpeA (or bpeB), with a minimum of 1, instead of the hard-coded 2 here; 4 is bpr (bytes per register).
Ideally, we should check this for A and B separately.
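A minimal sketch of the check requested here, assuming integer bpe (so it does not cover sub-byte types such as fp4). The function names are hypothetical and only illustrate replacing the hard-coded 2 with a per-operand value.

```python
BPR = 4  # bytes per 32-bit VGPR


def elements_per_vgpr(bpe):
    """K elements packed into one VGPR; never less than 1."""
    return max(1, BPR // bpe)


def needs_sublane_refine(asem, bpe):
    """True when K_remain can end in the middle of a VGPR for this operand.

    asem stands for AssertSummationElementMultiple.
    """
    return asem % elements_per_vgpr(bpe) != 0


def needs_refine_per_operand(asem, bpe_a, bpe_b):
    """Check A and B separately, since their bpe may differ."""
    return needs_sublane_refine(asem, bpe_a), needs_sublane_refine(asem, bpe_b)
```

For example, with ASEM=2 a bf16 operand (bpe=2) needs no refinement, while an fp8 operand (bpe=1) still does.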
| self.vgprPool.checkIn(kPosCur) | ||
|
|
||
| if kernel["AssertSummationElementMultiple"] % 2 != 0: | ||
| skipLabel = Label(self.labels.getUniqueNamePrefix( |
I think we can combine the even and odd masks, like this
(the odd case needs to be checked first):
v_mov_b32 v78, 0xffffffff
// mod 1
v_mov_b32 v77, 0xffff
v_add_u32 v76, 1, v11 // byteHi[ir=0]: K_pos_hi = kPosBase + 1
v_cmp_ge_i32 s[70:71], v76, s[sgprLoopCounterL] // byteHi[ir=0]: K_pos_hi >= LoopCounterL ?
v_cndmask_b32 v78, v78, v77, s[70:71] // byteHi[ir=0]: hi16 mask = past ? 0xFFFF : 0xFFFFFFFF
// mod 0
v_mov_b32 v77, 0
v_cmp_ge_i32 s[70:71], v11, s[sgprLoopCounterL] // byteRefine[ir=2]: K_pos (= kPosBase) >= LoopCounterL ?
v_cndmask_b32 v78, v78, v77, s[70:71]
v_and_b32 v12, v78, v12
v_and_b32 v20, v78, v20
v_and_b32 v28, v78, v28
...
FP8 case,
(mod= 3->2->1->0)
// mod 3
v_mov_b32 v78, 0xffffffff
v_mov_b32 v77, 0xffffff
v_add_u32 v76, 3, v11
v_cmp_ge_i32 s[70:71], v76, s[sgprLoopCounterL]
v_cndmask_b32 v78, v78, v77, s[70:71]
// mod 2
v_mov_b32 v77, 0xffff
v_add_u32 v76, 2, v11
v_cmp_ge_i32 s[70:71], v76, s[sgprLoopCounterL]
v_cndmask_b32 v78, v78, v77, s[70:71]
// mod 1
v_mov_b32 v77, 0xff
v_add_u32 v76, 1, v11
v_cmp_ge_i32 s[70:71], v76, s[sgprLoopCounterL]
v_cndmask_b32 v78, v78, v77, s[70:71]
// mod 0
v_mov_b32 v77, 0
v_cmp_ge_i32 s[70:71], v11, s[sgprLoopCounterL]
v_cndmask_b32 v78, v78, v77, s[70:71]
v_and_b32 v12, v78, v12
v_and_b32 v20, v78, v20
v_and_b32 v28, v78, v28
....
We can skip the mod > 0 checks if ASEM * bpe is a multiple of 4 bytes.
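The combined chain above generalizes over bpe. The sketch below models it on the host and compares it against a direct formula. It assumes that lower K positions occupy the lower bytes of a VGPR and that bpe is 1 or 2; all names are hypothetical.

```python
BPR = 4  # bytes per 32-bit VGPR


def keep_mask_chain(k_pos_base, loop_counter, bpe):
    """Model of the chain: seed a full mask, then walk mod = n-1 .. 0."""
    n = BPR // bpe  # elements per VGPR: 2 for bf16, 4 for fp8
    mask = 0xFFFFFFFF
    for mod in range(n - 1, -1, -1):
        keep = (1 << (8 * bpe * mod)) - 1       # keeps the elements below `mod`
        if k_pos_base + mod >= loop_counter:    # v_cmp_ge_i32
            mask = keep                         # v_cndmask_b32
    return mask


def keep_mask_direct(k_pos_base, loop_counter, bpe):
    """Direct formula: keep the first `valid` elements of the VGPR."""
    n = BPR // bpe
    valid = min(max(loop_counter - k_pos_base, 0), n)
    return (1 << (8 * bpe * valid)) - 1


# Exhaustive comparison over a small range of positions and counters.
for bpe in (1, 2):
    for base in range(0, 16):
        for counter in range(0, 24):
            assert keep_mask_chain(base, counter, bpe) == keep_mask_direct(base, counter, bpe)
```

Checking the highest mod first matters because every later (lower) mod that is also out of range overwrites the mask with a narrower one, so the last write wins.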
|
|
||
| if not staticSkipPartial: | ||
| partialSkipLabel = Label(self.labels.getUniqueNamePrefix( | ||
| "SubtileTailByteShiftPartialSkip"), |
One quick question about the conditional branch for the even case.
To skip 4 instructions, we add 3 extra instructions.
Does this really perform better?
Don't we get a pipeline stall from the branch?
v_mov_b32 v14, 0xffffffff // byteRefine[A ir=0 mmak=0]: mask seed = full keep
s_and_b32 s63, s[sgprLoopCounterL], 0x1 // LoopCounterL & (elementsPerVgpr-1) = partial-mod residue
s_cmp_eq_u32 s63, 0 // K_remain partial-mod aligned ?
s_cbranch_scc1 label_SubtileTailByteShiftPartialSkip_YZEYGSMMBUPBBI4G // skip mod>0 chain on aligned K_remain
v_mov_b32 v13, 0xffff // byteRefine[A ir=0 mod=1]: keep mask = 0xFFFF
v_add_u32 v12, 1, v11 // byteRefine[A ir=0 mod=1]: K_pos = kPosBase + 1
v_cmp_ge_i32 s[68:69], v12, s[sgprLoopCounterL] // byteRefine[A ir=0 mod=1]: K_pos >= LoopCounterL ?
v_cndmask_b32 v14, v14, v13, s[68:69] // byteRefine[A ir=0 mod=1]: mask = past ? 0xFFFF : prev
label_SubtileTailByteShiftPartialSkip_YZEYGSMMBUPBBI4G: /// K_remain partial-mod aligned: skip mod>0 mask chain
| comment="skip mod>0 chain on aligned K_remain")) | ||
| with self.allocTmpSgpr(laneSGPRCount, | ||
| alignment=laneSGPRCount) as maskInfo: | ||
| maskSgpr = maskInfo.idx |
We can optimize this a bit more.
Since this mask calculation does not need LR data, we can move the wait and barrier after the mask calculation.
Addresses nakajee's open OOR review on PR #7661 (carried forward from
PR #7636 "Comment 5" TODO at the tail GR site):

> This still does not have any prevention for out of array access.
> What we need is set SrdA/B/MXSA/MXSB + 2 to the exact end of array
> (but 4 byte alignment). [...]
> remainK = (k%DepthU); remainKalign = remainK & 0xfffffffe;
> SrdA/B -= (DepthU - remainKalign) * bpe

Before this commit, the subtile K-tail re-issued a DepthU-shaped GR
against an SRD whose NumRecords (Srd<tc>+2) still spanned the full
DepthU bytes after the last live K-element. For K_remain < DepthU, the
last m-row's per-thread `buffer_load_b128` past `K_remain*bpe` could
read past A/B's allocated end-of-array (buffer-OOB does NOT bail on
past-allocation reads; only on past-NumRecords reads). The per-MFMA
lane mask + sub-lane refine zeroed those VGPRs after the load, so
values were correct, but the buffer engine could still touch unmapped
pages and trigger an HSA fault on non-contiguous A/B allocations.

New `_emitTailSrdTightenSubtile` runs at tail entry (right after the
PGR>0 entry-gate `PGRTailEntry<L>` label, before `openLoop`) and
shrinks each `Srd<tc>+2` by
`DepthU*bpe - roundUp(K_remain*bpe, loadBytesGR)`. The
`roundUp(..., loadBytesGR)` (vs nakajee's literal `& 0xfffffffe`) is
what keeps our wide DTL load valid for the trailing odd-K element on
the last m-row: the per-thread load is B128 (16 B) for bf16/fp16, so a
single thread covers up to 8 K elements -- the load must succeed for
the thread that holds `K_remain - 1`. nakajee's literal align-down
assumed the narrow-trailing-element strategy
(`buffer_load_d16_b16 ... lds` + lane-0-only); that path is rejected by
the gfx950 assembler and was previously deleted in `7df7d24` ("remove
dead bf16 narrow-load helper"). The align-up variant preserves
nakajee's intent (clip past-K reads on the last m-row) without needing
the narrow load.

A single runtime `s_cmp_lt_u32 alignedBytes, DepthU*bpe` +
`s_cbranch_scc0 TailSrdTightenSkip<L>` short-circuits the SSub chain
when `alignedBytes >= DepthU*bpe` (K_remain close to DepthU with wide
per-thread loads -- the natural SRD limit already covers every read).
The skip label is the join target.

Gating:
- non-MX (`MXBlock{A,B} == 0`): MX scales have host re-scatter padding
  (`DataInitialization.rearrangePaddedMXScaleLayout`); MX data +
  MXSA/MXSB SRD tightening needs nakajee's swizzleBlock-aware formula
  (`remainK_MX = roundUp(remainK / 256) * 256`,
  `SrdMXSA/B -= (DepthU - remainK_MX) * swizzleSize0 = 32`) and is a
  separate follow-up.
- non-swizzled A/B (`SwizzleTensor{A,B}=False`): subtile mxfp4 swizzled
  A/B need the same swizzleBlock formula; same follow-up.
- symmetric per-tensor `bpe in {1, 2}` and matching `loadWidthGR`;
  bf16 / fp16 / int8 anyk paths are the immediate consumer.

Per-kernel asm delta (bf16 fixture, B128 load, DU=64 ->
depthUBytes=128):
    +5 s_* (lshl/add/and/cmp/cbranch)
    +1 s_sub_u32 (delta)
    +2 s_sub_u32 (Srd{A,B}+2 tighten)
    +1 label
8 instructions + 1 label, all inside the tail entry; no per-iter cost.
Statically gated to no-op on MX / swizzled / non-bf16 paths.

No SRD restore is needed: the tail body is the last GR site for A/B
before the kernel epilogue (epilogue uses SrdC / SrdD).

Stale TODO comment at the tail GR site (which previously said the
tightening was filed as a follow-up) is updated to reflect the new
state (A/B done here; MX + swizzled deferred with a pointer to
nakajee's spec).

Tests
(`Tensile/Tests/unit/test_subtile_tailloop_emit.py::TestTailSrdTightenSubtile`,
10 new):
- Pin the emit-time `s_lshl_b32 + s_add_u32 + s_and_b32` aligned-K
  chain, the runtime no-op `s_cmp_lt_u32 + s_cbranch_scc0` skip, the
  `s_sub_u32 SrdA+2 / SrdB+2` tightening (with delta precompute), and
  the `TailSrdTightenSkip<L>:` join label.
- Strict order: branch < SrdA+2 sub < skip label (so the short-circuit
  actually short-circuits).
- PGR>0 placement: tightening fires AFTER `PGRTailEntry<L>:` so c=0 /
  small-counter / large-counter paths have all converged onto the same
  SRD state, then BEFORE the tail GR.
- Negative pins: MX fp4 emits no tightening (MXBlock>0 gate);
  NoTailLoop=True emits no tightening (scaffold early-returns).

Validation:
- Full unit suite: 887 passed, 5 skipped, 2 xfailed (was 887 / same).
- gfx950 yaml gauntlet on MI355X:
    subtile_bf16.yaml             : 7082 / 7082 PASS
    subtile_bf16_tail.yaml        :  450 /  450 PASS
    subtile_bf16_anyk_k8.yaml     :  183 /  183 PASS
    subtile_bf16_anyk_k2.yaml     :  183 /  183 PASS
    subtile_bf16_anyk_odd.yaml    :  117 /  117 PASS
    subtile_bf16_anyk_largemt.yaml:    4 /    4 PASS
    subtile_mxfp4.yaml            : 2691 / 2691 PASS
    subtile_mxfp4_tail.yaml       :   68 /   68 PASS
    subtile_mxfp4_tail_smoke.yaml : PASS
  Total ~10,778 problem runs, 0 failures.

Co-authored-by: Cursor <cursoragent@cursor.com>
| No SRD restore is needed: the tail body is the last GR site for | ||
| A/B before the kernel epilogue (epilogue uses SrdC / SrdD). | ||
| """ | ||
| module = Module("tailSrdTightenSubtile") |
This seems incorrect.
We need this logic even for MX or swizzle.
My expectation for MX padding is K=256 unit.
If depthU is 512 or larger, we need to adjust to nearest multiple of 256.
Same for swizzle case. Not sure the padding logic for swizzleA/B, but we will need to adjust Srd+2 in case depthU is larger than padding unit.
This is the original plan in my text:
(2-1-2) SrdA/B/MXSA/MXSB+2 update (not implemented yet)
remainK = (k%DepthU)
remainKalign = remainK & 0xfffffffe
mod2K = remainK % 2
SrdA/B -= (DepthU - remainKalign) * bpe
# -> swizzleA/B case, we need to consider swizzle block
remainK_MX = roundUp(remainK / 256) * 256
SrdMXSA/B -= (DepthU - remainK_MX) * swizzleSize0 # swizzleSize0 is 32
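The plan above can be transcribed into a small host-side sketch to make the arithmetic concrete. This only restates the reviewer's formulas. The function names and the default `mx_pad_k=256` / `swizzle_size0=32` parameters are illustrative assumptions taken from the comment, and `roundUp(remainK / 256) * 256` is read as rounding `remainK` up to a multiple of 256.

```python
def round_up(value, unit):
    """Round value up to the next multiple of unit."""
    return -(-value // unit) * unit


def srd_ab_shrink_bytes(k, depth_u, bpe):
    """SrdA/B+2 shrink per the plan: clip to the even-aligned K remainder."""
    remain_k = k % depth_u
    remain_k_align = remain_k & 0xFFFFFFFE  # round down to an even element count
    return (depth_u - remain_k_align) * bpe


def srd_mx_shrink_bytes(k, depth_u, mx_pad_k=256, swizzle_size0=32):
    """SrdMXSA/B+2 shrink per the plan: clip to the padded MX remainder."""
    remain_k = k % depth_u
    remain_k_mx = round_up(remain_k, mx_pad_k)
    return (depth_u - remain_k_mx) * swizzle_size0


# bf16, DepthU=64, K=71: remainK=7 is aligned down to 6 elements.
print(srd_ab_shrink_bytes(71, 64, 2))
# DepthU=512, K=612: remainK=100 is padded up to 256.
print(srd_mx_shrink_bytes(612, 512))
```

With DepthU equal to the padding unit (256), the MX shrink is zero for any nonzero remainder, which is why the adjustment only matters for DepthU of 512 or larger.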
| kernel["ProblemType"]["IndicesSummation"][unrollIdx]] | ||
| skipLabel = Label("TailSrdTightenSkip%s" % loopChar, "") | ||
|
|
||
| module.addComment2( |
BF16 asm code
s_lshl_b32 s68, s[sgprLoopCounterL], 0x1 // K_remain * bpe (bpe=2)
s_add_u32 s68, s68, 15 // + (loadBytes-1) for roundUp
s_and_b32 s68, s68, 0xfffffff0 // alignedBytes = roundUp(K_remain*bpe, 16)
s_cmp_lt_u32 s68, 128 // alignedBytes < DepthU*bpe?
s_cbranch_scc0 label_TailSrdTightenSkipL // natural SRD already tight; skip SRD tighten
s_sub_u32 s69, 128, s68 // delta = DepthU*bpe - alignedBytes
s_sub_u32 s[sgprSrdA+2], s[sgprSrdA+2], s69 // Srd A+2 -= delta (clip K past K_remain on last m-row)
s_sub_u32 s[sgprSrdB+2], s[sgprSrdB+2], s69 // Srd B+2 -= delta (clip K past K_remain on last m-row)
label_TailSrdTightenSkipL:
(1) alignedBytes should be based on 4 for BufferLoad (not 16 (= loadBytes)).
(2) The `alignedBytes < DepthU*bpe?` check is not necessary.
sgprLoopCounterL is K % DepthU.
sgprLoopCounterL * bpe is always smaller than DepthU*bpe.
The current version works so far.
The only concern is that the extra 2-byte access could cause a page fault.
The ideal solution is to round down instead of rounding up and to load the last element with a b16 GR.
I have not been able to trigger the page fault intentionally.
Addresses nakajee's two cleanups on PR #7661 for the SRD+2 K-tail
tightener introduced in ee559a4:

> (1) alignedBytes should be based on 4 for BufferLoad
> (not 16(=loadBytes))
> (2) alignedBytes < DepthU*bpe? is not necessary.
> sgprLoopCounterL is K%DepthU.
> sgprLoopCounterL * bpe is always smaller than DepthU*bpe

Change (1) -- align-UP to bpr=4 instead of loadBytesGR=16. bpr is the
finest granularity BufferLoad's NumRecords actually needs (per-thread
loads are hardware-multiple-of-bpr). Tighter clip on the SRD limit
gives the smallest possible over-read footprint past K_remain.

Change (2) -- under align-UP to bpr=4,
`delta = DepthU*bpe - alignedBytes` is provably >= 0 for any K_remain
in [0, DepthU-1], provided DepthU*bpe is itself a multiple of 4 (true
for the DU=64 fixture with bpe in {1, 2}):

    K_remain*bpe <= (DepthU-1)*bpe < DepthU*bpe
    alignedBytes  = roundUp(K_remain*bpe, 4) <= DepthU*bpe

because rounding a value up to the next multiple of 4 cannot pass a
larger value that is already a multiple of 4. Worst case bpe=2,
K_remain=DU-1 (odd):

    alignedBytes = (2*(DU-1) + 3) & ~3 = (2DU + 1) & ~3 = 2DU

so alignedBytes == DepthU*bpe and delta = 0. OK.

So the original `s_cmp_lt_u32 alignedBytes, DepthU*bpe` +
`s_cbranch_scc0 TailSrdTightenSkip<L>` short-circuit + skip label are
dead code under bpr=4 alignment. When delta=0 the two s_sub_u32 on
SrdA+2/SrdB+2 become harmless no-ops, so the runtime skip is pure
overhead. Drop the cmp, the cbranch, and the skip label.

Why align-UP (not nakajee's literal align-DOWN
`remainK & 0xfffffffe`): the align-DOWN strategy assumed a separate
narrow trailing-element load (`buffer_load_d16_b16 ... lds`,
lane-0-only) for the odd-K boundary element. That assembler form is
rejected on gfx950 and the dead helper was previously removed in
7df7d24. Our wide DTL + per-lane mask + sub-lane refine path correctly
handles the trailing element ONLY when it is INCLUDED in the wide load
-- which requires align-UP.
bpr=4 is the finest align-UP granularity; the trade-off vs nakajee's
literal clip is at most one DWORD (= 4 B = 2 bf16) of slack on the
NumRecords limit. Tensilelite K-pads to MIK boundary, so DWORD-level
past-K reads stay within-page (no fault risk).

Per-kernel asm delta (bf16 fixture, B128 load, DU=64 ->
depthUBytes=128):
    before: 8 instructions + 1 label
            (lshl/add/and/cmp/cbranch/sub/sub/sub + skip label)
    after:  6 instructions, no labels (lshl/add/and/sub/sub/sub)

Generated asm (bf16, DepthU=64):

    /* Tighten Srd<tc>+2 to K_remain*bpe rounded up to bpr=4
     * (nakajee #PR-7661 OOR review follow-up) */
    s_lshl_b32 s68, s[sgprLoopCounterL], 0x1     // K_remain * bpe (bpe=2)
    s_add_u32 s68, s68, 3                        // + (bpr-1) for roundUp
    s_and_b32 s68, s68, 0xfffffffc               // alignedBytes = roundUp(K_remain*bpe, 4)
    s_sub_u32 s69, 128, s68                      // delta = DepthU*bpe - alignedBytes (>= 0)
    s_sub_u32 s[sgprSrdA+2], s[sgprSrdA+2], s69  // Srd A+2 -= delta
    s_sub_u32 s[sgprSrdB+2], s[sgprSrdB+2], s69  // Srd B+2 -= delta

Tests
(`Tensile/Tests/unit/test_subtile_tailloop_emit.py::TestTailSrdTightenSubtile`):
- Updated the alignedBytes-chain pin from `+15 / & 0xfffffff0`
  (loadBytes=16) to `+3 / & 0xfffffffc` (bpr=4).
- Replaced `test_emits_srd_tighten_skip_branch` /
  `test_skip_label_emitted` /
  `test_ssub_chain_lives_inside_skip_region` with:
  * `test_no_runtime_skip_branch` -- negative pin asserting no
    `s_cmp_lt_u32 ... alignedBytes < DepthU` in the tighten region and
    no `TailSrdTightenSkip` references anywhere.
  * `test_ssub_chain_ordering` -- positive pin asserting
    delta < SrdA+2 sub < SrdB+2 sub emit order so the chain's dataflow
    (delta is the src of both Srd subs) stays correct.

Validation:
- Full unit suite: 886 passed, 5 skipped, 2 xfailed (was 887 / now 886
  because one of the three replaced tests collapsed into a single
  ordering pin; net -1).
- gfx950 yaml gauntlet on MI355X (full re-run, 0 failures):
    subtile_bf16.yaml             : 7082 / 7082 PASS
    subtile_bf16_tail.yaml        :  450 /  450 PASS
    subtile_bf16_anyk_k2.yaml     :  183 /  183 PASS
    subtile_bf16_anyk_odd.yaml    :  117 /  117 PASS
    subtile_bf16_anyk_k8.yaml     :  183 /  183 PASS
    subtile_bf16_anyk_largemt.yaml:    4 /    4 PASS
    subtile_mxfp4.yaml            : 2691 / 2691 PASS
    subtile_mxfp4_tail.yaml       :   68 /   68 PASS
    subtile_mxfp4_tail_smoke.yaml :    2 /    2 PASS
  Total ~10,780 problem runs, 0 failures.

Co-authored-by: Cursor <cursoragent@cursor.com>
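The claim that `delta` never goes negative under bpr=4 alignment can be brute-forced on the host. The sketch below mirrors the `s_lshl_b32 / s_add_u32 / s_and_b32 / s_sub_u32` chain from the commit message. The function names are hypothetical, and the DepthU values swept are an assumption (any DepthU for which DepthU*bpe is a multiple of 4).

```python
def aligned_bytes(k_remain, bpe, bpr=4):
    """roundUp(K_remain * bpe, bpr), as in the lshl/add/and chain."""
    shift = bpe.bit_length() - 1            # bpe is a power of two (1 or 2)
    return ((k_remain << shift) + (bpr - 1)) & ~(bpr - 1)


def srd_delta(k_remain, depth_u, bpe):
    """delta = DepthU*bpe - alignedBytes, subtracted from Srd<tc>+2."""
    return depth_u * bpe - aligned_bytes(k_remain, bpe)


# Exhaustive check over every K_remain in [0, DepthU-1].
for bpe in (1, 2):
    for depth_u in (32, 64, 128, 256):
        deltas = [srd_delta(k, depth_u, bpe) for k in range(depth_u)]
        assert min(deltas) >= 0

# bf16, DepthU=64: the odd worst case K_remain=63 gives delta 0.
print(srd_delta(63, 64, 2))
```

A delta of zero turns both `s_sub_u32` instructions into no-ops, which is consistent with dropping the runtime skip branch.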
Implements nakajee's MX-scale SRD+2 follow-up on PR #7661:

> We need Srd+2 update for MX or swizzle.
> My expectation for MX padding is K=256 unit.
> If depthU is 512 or larger, we need to adjust to nearest multiple
> of 256.
>
> (2-1-2) SrdA/B/MXSA/MXSB+2 update (not implemented yet)
> remainK_MX = roundUp(remainK / 256) * 256
> SrdMXSA/B -= (DepthU - remainK_MX) * swizzleSize0

The new helper `_emitTailSrdTightenSubtileMX` runs at tail entry right
after `_emitTailSrdTightenSubtile` (which only handles non-MX A/B),
shrinking `SrdMXSA+2` and `SrdMXSB+2` by
`(DepthU - roundUp(LoopCounterL, 256)) * bytesPerKElement_MX`.

Per nakajee's spec the 256 is the host MX-scale K-padding granularity
(`DataInitialization.cpp::rearrangePaddedMXScaleLayout` pads K-blocks
to a multiple of 8 mxBlocks = 256 K-elements; M to multiple of 32).
For DepthU <= 256 the host padding alone covers any K_remain on the
last m-row, so the helper is a static no-op (emits nothing) -- this
preserves the existing subtile_mxfp4.yaml / subtile_mxfp4_tail.yaml
DU=256 path bit-for-bit. The tightening only fires when DepthU > 256
(the DU=512 problems in `subtile_mxfp4_tail.yaml` Section 3).

`bytesPerKElement_MX` is derived from the same TileInfo accessors the
PGR>0 large-counter SRD-advance path uses (`_tailSrdAdvanceBytes` ≡
`lrSubtileSize * lrGlobalSubtileGrid[1]`), so the byte convention is
consistent with the rest of the MX SRD-handling. For all current
MXSA_B{4,8}/MXSB_B{4,8} layouts (mmaTileSize=64 with instM=16,
instKScale=4, bpe=1, lrSubtileShape=(2,2)) the value is exactly
1 byte/K-element, so the K-element delta lands in SrdMXS<tc>+2
directly without an extra SLShift.

A note on nakajee's literal "swizzleSize0 = 32" factor: that
multiplier would over-shrink NumRecords by ~16x for the existing
gfx950 MX scale layout (where setupNewTile sets
`extra_bytes = swizzleBlockSize * (DepthU // swizzleSize1) = 256 * (DepthU/256) = DepthU`
bytes of K-direction budget per DU).
The spec's intent is to clip SrdMXS<tc>+2 down to remainK_MX worth of
valid K-elements; the right per-K-element byte multiplier on our
layout is 1, not 32. We derive that from the live MX TileInfo rather
than hard-coding the constant, so any future layout change propagates
through the existing `_tailSrdAdvanceBytes`-equivalent helper.

Generated asm (DepthU=512 mxfp4, MXBlockA=MXBlockB=32):

    /* Tighten SrdMXS<tc>+2 for K_remain (MX K-pad=256, DepthU=512;
     * nakajee #PR-7661 OOR review MX follow-up) */
    s_add_u32 s88, s[sgprLoopCounterL], 255            // K_remain + (MX_pad_K - 1)
    s_and_b32 s88, s88, 0xffffff00                     // remainK_MX = roundUp(., 256)
    s_sub_u32 s89, 512, s88                            // delta_K = DepthU - remainK_MX
    s_sub_u32 s[sgprSrdMXSA+2], s[sgprSrdMXSA+2], s89  // -= delta
    s_sub_u32 s[sgprSrdMXSB+2], s[sgprSrdMXSB+2], s89  // -= delta

The runtime no-op cases (e.g. DU=512 K_remain in (256, 511] ->
remainK_MX=512=DepthU -> delta=0) collapse to harmless SSubs, matching
the cleanup-2 pattern from the bf16 path.

Gating:
- At least one MX side present (MXBlockA > 0 or MXBlockB > 0);
- DepthU > MX_PAD_K (256); otherwise static no-op.
- bytesPerKElement_MX is integer and matches across all live MX
  operands (always true for current layouts; returns no-op if a future
  layout violates this).

The bf16/fp16/int8 A/B path is unchanged (it still uses the bpr=4
align-UP variant from the previous commit; this commit adds a separate
emit slot for the MX scale operands).

Tests
(`Tensile/Tests/unit/test_subtile_tailloop_emit.py::TestTailSrdTightenSubtileMX`,
4 new):
- `test_mx_tighten_static_noop_when_depthU_eq_padK`: DU=256 fp4
  fixture emits no MX banner, no SrdMXSA+2 / SrdMXSB+2 sub.
- `test_mx_tighten_emits_when_depthU_gt_padK`: DU=512 fp4 fixture
  emits the full chain (banner, roundUp precompute,
  delta = 512 - remainK_MX, per-operand sub).
- `test_mx_tighten_skipped_for_non_mx`: bf16 fixture emits no MX
  helper output (MX gate is per-helper).
- `test_mx_tighten_fires_for_pgr2_at_depthU_512`: PGR=2 fp4 DU=512
  still emits the helper, and it lands AFTER `PGRTailEntryL:` so all
  three PGR>0 entry paths converge first.

Validation:
- Full unit suite: 890 passed, 5 skipped, 2 xfailed (was 886; +4 from
  the new MX tests).
- gfx950 yaml gauntlet on MI355X (full re-run, 0 failures):
    subtile_mxfp4.yaml            : 2691 / 2691 PASS
    subtile_mxfp4_tail.yaml       :   68 /   68 PASS
      (incl. DU=512 Section 3 K_rem in
       {32, 96, 128, 224, 256, 384, 640, 1056})
    subtile_mxfp4_tail_smoke.yaml :    2 /    2 PASS
  bf16 yamls unaffected (helper gated to MX) but re-run for safety in
  the previous commit's validation.

Co-authored-by: Cursor <cursoragent@cursor.com>
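A host-side sketch of the MX tightening arithmetic described in this commit message, for illustration only. The function name, the `None` return for the static no-op, and the default `bytes_per_k=1` are assumptions that follow the message's description of the current layouts, not the helper's real signature.

```python
MX_PAD_K = 256  # host MX-scale K-padding granularity, per the commit message


def mx_tighten_delta(loop_counter, depth_u, bytes_per_k=1):
    """Byte delta subtracted from SrdMXS<tc>+2, or None for the static no-op."""
    if depth_u <= MX_PAD_K:
        return None  # host padding already covers any K_remain
    # s_add_u32 + s_and_b32: round K_remain up to a multiple of 256.
    remain_k_mx = (loop_counter + MX_PAD_K - 1) & ~(MX_PAD_K - 1)
    # s_sub_u32: delta in K elements, scaled to bytes.
    return (depth_u - remain_k_mx) * bytes_per_k


print(mx_tighten_delta(100, 256))  # static no-op at DepthU=256
print(mx_tighten_delta(100, 512))  # K_remain padded to 256, half of DepthU clipped
print(mx_tighten_delta(300, 512))  # K_remain padded to 512, runtime no-op
```

The third case is the runtime no-op the message mentions: the subtraction still executes but with a delta of zero.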
| A/B before the kernel epilogue (epilogue uses SrdC / SrdD). | ||
| """ | ||
| module = Module("tailSrdTightenSubtile") | ||
| if kernel["ProblemType"].get("MXBlockA", 0) > 0 \ |
We do not need this early exit.
We need to apply the SrdA/B+2 change in the MXFP4/FP8 case as well.
| if kernel["ProblemType"].get("MXBlockA", 0) > 0 \ | ||
| or kernel["ProblemType"].get("MXBlockB", 0) > 0: | ||
| return module | ||
| if kernel["ProblemType"].get("SwizzleTensorA", False) \ |
We do not need the early exit for SwizzleA/B.
In the SwizzleA/B case, we need to consider the block size (16x16 bytes in the fp16/bf16 case?).
Kremain = K%DepthU
KremainUp = round up Kremain to nearest swizzle block.
diff = (DepthU - KremainUp)
Srd2 -= diff * bpe
sebvince left a comment
Left some comments. My major concern is the duplicated logic introduced by this PR:
- MFMA grid / tile-id math now lives in 4 places (_emitTailLoopScaffoldSubtile, emitMfmaCode, InstructionEmitter.emit_mfma, allocVgprTileRegistersForMmak). Any future layout change has to land in all four
- Two parallel TileInfo alloc APIs (initVgprTileSlots+allocForMmak+freeForMmak alongside _legacy).
- Tail bypasses the scheduler/emitter that the mainloop just got moved onto. Two GR/LR pipelines for the same operation. Any future GR/LR/scale fix has to be ported twice. Future NLL + Tailloop will be more difficult too.
- 1300 lines of subtile-specific code in KernelWriter.py
| return module | ||
|
|
||
|
|
||
| def emitSubtileDsReadForMmak(tc, writer, kernel, mmak): |
To avoid duplication with existing emitSubtileDsRead, we could leave this logic in the caller.
| from rocisa.instruction import ( | ||
| BufferLoadB128, | ||
| SAddCU32, SAddU32, SMovB32, SMovB64, SMulI32, SNop, SXorB32, | ||
| SAddCU32, SAddU32, SAndB32, SBranch, SMovB32, SMovB64, SMulI32, SNop, |
nit : No other changes in this file. Not sure we need the extra imports
| for _ in range(numMMATiles): | ||
| self.vgprTiles.append(RegisterTileInfo(writer.vgprPool, RegisterType.Vgpr)) | ||
|
|
||
| def allocVgprTileRegistersForMmak(self, writer, kernel, mmak): |
This code seems redundant with the existing vgprTiles allocation in the logicalScheduler (allocVgprTiles). Not sure we want to maintain 2 allocations in 2 different places.
`SAndB32` and `SBranch` were added to the SubtileGREmit import list
during earlier byte-mask refactor iterations but the chain those
helpers belonged to landed in `KernelWriter._emitTailSubLaneMask*`
instead of in this file. The names are unreferenced in
SubtileGREmit.py (grep returns only the import line), so the imports
are dead.
Per sebvince PR #7661 review (comment id 3287447646: "No other changes
in this file. Not sure we need the extra imports").
No emit / behaviour change.
Co-authored-by: Cursor <cursoragent@cursor.com>
`emitSubtileDsReadForMmak` was a wrapper around `emitSingleDsRead`
that derived `(sId1, du) = divmod(mmak, subtileShape[1])` and looped
over `sId0` in `localSubtileGrid[0]`. It duplicated the body of the
existing full-grid `emitSubtileDsRead` (same per-(sId0, sId1, du)
emit, only the outer iteration shape differed) and had a single
call site (the tail-loop scaffold in `_emitTailLoopScaffoldSubtile`).
Inline the loop at the tail scaffold call site:
- removes the helper from `SubtileLREmit.py`,
- drops the `emitSubtileDsReadForMmak` re-export from
`Components/Subtile/Kernel.py`,
- replaces the two `emitSubtileDsReadForMmak('A'|'B', ...)` calls
in `_emitTailLoopScaffoldSubtile` with one `for tc, ti in ...`
inlined loop using `emitSingleDsRead` directly (the same helper
`emitSubtileDsRead` and `InstructionEmitter.emit_mfma` already
use).
Per sebvince PR #7661 review (comment id 3287442970: "To avoid
duplication with existing emitSubtileDsRead, we could leave this
logic in the caller").
Emit is byte-identical to the prior wrapper output (the inlined
form moves the `mfmaId = getSubtileShapeLinearId(du, 0)` hoist
above the `sId0` loop -- the original recomputed it per iteration
but the value is `sId0`-independent, so the emit sequence does not
change). All tail-loop emit pinning tests (`test_subtile_tailloop_emit`
+ `test_subtile_anyk_emit`) pass unchanged: 80/80.
Unit suite (host-only, excl. GPU roundtrip): 850 passed / 5 skipped
/ 2 xfailed (baseline 850 host-side + 40 GPU = 890).
Co-authored-by: Cursor <cursoragent@cursor.com>
| - KernelLanguage: ["Assembly"] | ||
| ForkParameters: | ||
| - MatrixInstruction: | ||
| - [16, 16, 32, 1, 1, 5, 18, 4, 1] # MT 320x288 (4x1 WG) |
We might want to add more large MTs like 320x320 to make sure none of them are rejected (320x320 uses almost all available VGPRs). Another problem: if we reject a solution because of VGPR usage overflow, this test will silently fail (errorCode 0).
| - [16, 16, 32, 1, 1, 5, 18, 4, 1] # MT 320x288 (4x1 WG) | ||
| - PrefetchGlobalRead: [0] | ||
| - PrefetchLocalRead: [0] | ||
| - DepthU: [64] |
Try multiple DU too (to check it works with multi-partitions)
| ForkParameters: | ||
| - MatrixInstruction: | ||
| - [16, 16, 32, 1, 1, 5, 18, 4, 1] # MT 320x288 (4x1 WG) | ||
| - PrefetchGlobalRead: [0] |
I would test all 3 PGRs as it might affect GR_inc logic in tail loop (and test larger K to avoid NGLL skip)
…review

Adds extended coverage for the bf16 any-K tail on the large-MT path,
addressing the test-related items in sebvince's PR #7661 review:

* comment 3289001186 -- large MT 320x320 (10x10 wavetile, 2x2 WG) is
  now exercised end-to-end so a future regression that VGPR-rejects
  this shape can't silently drop it from the build set.
* comment 3289012025 -- PrefetchGlobalRead is swept across {0, 1, 2}
  on both MT 320x320 and MT 320x288 with larger K (2/4/8 full
  main-loop iterations) so the tail-loop GR_inc path is exercised
  under each prefetch depth without the NGLL fast-path swallowing the
  run.

The third comment (3289008557, "Try multiple DU too (to check it works
with multi-partitions)") is intentionally not covered: this PR stays
single-partition per nakajee's design preference and DepthU is fixed
at 64 to keep the K-tail emit shape consistent with the unit pins.

New files (no existing test/yaml is modified):

* Tensile/Tests/common/gemm/gfx950/subtile_bf16_anyk_largemt_sebvince.yaml
  -- 36-problem sweep, 4 sections (3 ASEMs × MT 320x320 plus the
  PGR-sweep mirror on MT 320x288). PrintSolutionRejectionReason is on
  so silent rejection surfaces in the test log.
* Tensile/Tests/unit/test_solution_subtile_anyk_largemt.py -- 22
  Solution-level pins that catch a regression dropping the large-MT
  shapes at the Solution.py gate: (Valid != False) under all PGRs and
  ASEMs, DTL is preserved at ASEM=1, NoTailLoop stays False.

Co-authored-by: Cursor <cursoragent@cursor.com>
The byte-refine K-tail mask precompute previously emitted one
bpe-parametric `byteRefine[<op> ir=N mmak=M]` chain per
(operand, mmak, ir): seed `0xFFFFFFFF` mov + per-mod (mod chain length
2 for bf16, 4 for fp8)
`v_add + v_cmp_ge_i32 + v_cndmask_b32 vSeed` + mod=0 close. For bf16
byte-refine that is 7 instructions per (mmak, ir) chain, multiplied by
`numMmaks * vgprPerInUnroll` chains (e.g. 2*4 = 8 chains -> 56 instr
on the DU=64 path, 4*4 = 16 chains -> 112 instr on the DU=128 path).

Sebvince's #7683 design factors out the per-lane invariants once
(`emit_mask_k_init`) and turns the per-subIterK chain into a pure
3-state diff/boundary cmp+cndmask (`emit_mask_k`). Nakajee's MT320x320
side-by-side review of #7661 vs #7683 flagged this as "much simpler"
and asked to verify equivalence before adopting.

Equivalence verification (sebvince's BF16 chain vs ours, with
`numMIInUnroll = 8`, `kStride = 2`, `vgprPerInUnroll = 4`):

    laneK_0          := tidInK * numMIInUnroll   (== our kPosBase)
    diff             := LoopCounterL - laneK_0   (signed)
    effective_diff_n := diff - n*MIK             (n = mmak / subIterK)
    d                := LoopCounterL % numMIInUnroll
                        (== effective_diff_n in the boundary range)

For each (lane, mmak=n, vgpr i):

    sFull = diff > n*MIK + numMIInUnroll - 1
        <=> effective_diff_n >= numMIInUnroll
        <=> ALL of this lane's K is in range  -> mask = -1
    sZero = diff <= n*MIK
        <=> effective_diff_n <= 0
        <=> NONE of this lane's K is in range -> mask = 0
    otherwise -> mask = boundary[i] where
        d <= i*2   -> 0
        d == i*2+1 -> 0x0000FFFF (low bf16 in, high past)
        d >= i*2+2 -> -1

Truth-table check (tidInK=0, mmak=0, K_remain in {1..8}):

    K_rem=1: i=0 mask=0x0000FFFF (low K=0 in, high K=1 past) ; i=1..3 mask=0
    K_rem=2: i=0 -1 ; i=1..3 mask=0
    K_rem=3: i=0 -1 ; i=1 mask=0x0000FFFF ; i=2,3 mask=0
    K_rem=4: i=0,1 -1 ; i=2,3 mask=0
    K_rem=5: i=0,1 -1 ; i=2 mask=0x0000FFFF ; i=3 mask=0
    K_rem=6: i=0,1,2 -1 ; i=3 mask=0
    K_rem=7: i=0,1,2 -1 ; i=3 mask=0x0000FFFF
    K_rem=8: i=0..3 -1 (sFull fires for tidInK=0)

Our existing chain (bpe-parametric mod
chain) produces the same masks bit-for-bit for every entry above (and
for tidInK > 0 the diff/sFull/sZero arithmetic and our
`K_pos vs LoopCounterL` cmps algebraically coincide). Verified by hand
on the full K_remain 1..8 x i 0..3 truth table.

Caveats (all non-blocking for our gauntlet):

* Sebvince's chain does NOT have our `(ASEM * bpe) % bpr == 0` static
  skip; the boundary[i] init runs unconditionally. For ASEM=2/4
  (currently fast-path with static skip dropping the mod>0 chain) this
  emits a few more instructions, but the mask bit-pattern stays
  identical (d is even -> halfKeep branch never fires; only the
  full/zero outcomes are reachable).
* Sebvince's chain assumes `bpeA == bpeB` (one shared boundary[i] per
  vgpr position, applied to both operands' tiles). All current
  gauntlet configs are symmetric bpe
  (`_subtileTailByteShiftApplies`'s integer-bpe gate plus
  `_emitTailSrdTightenSubtile`'s explicit `bpeA == bpeB` check). For
  asymmetric bpe (no current YAML) the dispatcher falls back to the
  legacy chain.
* Storage: init holds `1 (diff) + vgprPerInUnroll (boundary)` = 5
  persistent VGPRs for BF16 (numMIInUnroll=8 -> vgprPerInUnroll=4), on
  top of our existing `numMmaks * vgprPerInUnroll` precomputed masks.
  On MT320x288 byte-refine (`subtile_bf16_anyk_largemt`, baseline
  vgpr_count=249) the +5 fits well within the 256-VGPR budget;
  verified by re-running the yaml under the new emit (still PASS).

This commit lands the adoption:

* `_emitTailSubLaneMaskInitSebvince(kPosBaseVgpr, numMIInUnroll, bpe, vgprPerInUnroll)`
  -> emits `diff = sgprLoopCounterL - kPosBase` (signed v_sub_i32) +
  `d = sgprLoopCounterL & 7` + per-i
  `(d<hi) ? halfKeep : full ; (d<lo) ? 0 : prev` boundary cndmasks.
  Returns the persistent diff and boundary[i] vgprs.
* `_emitTailSubLaneMaskChainIntoVgprSebvince(diffVgpr, boundaryMaskVgpr, ...)`
  -> per-(operand, mmak, ir): two cmps + two cndmasks. VOPC
  inline-range staging via `_subtileCmpSrc1FitsInline` for
  `mmak*MIK + ... > 64` (BF16 DU=128, mmak=2,3).
* `_emitTailSubLaneMaskPrecomputeSubtile` dispatches:
  bpeA == bpeB == 2 && `SubtileTailMaskSebvinceForm=True` (default)
  -> sebvince form; otherwise -> legacy chain. Legacy chain retained
  for fp8 / int8 byte-refine (asymmetric bpe gates deferred) and as a
  reversibility escape hatch.

Tests in `test_subtile_anyk_emit.py`:

* `_emit_anyk_tail_asm` gains a `sebvinceForm` kwarg (default True) so
  individual tests can opt back into the legacy chain for regression
  coverage.
* `TestAnyKEmit_K4::test_k4_sebvince_form_emits_init_and_per_mmak_chain`
  (replaces `test_k4_byte_refine_mod0_only`): pins sebvince init
  marker + per-(mmak, ir) sFull cndmask count + absence of the legacy
  mod>0 / mod=0 byteRefine chain seed.
* `TestAnyKEmit_K4::test_k4_legacy_form_emits_mod0_only`: new
  regression pin for `SubtileTailMaskSebvinceForm=False` (legacy chain
  still emits the mod=0-only chain on ASEM=4).
* `TestAnyKEmit_K2::test_k2_emits_sebvince_form_chain` (replaces
  `test_k2_emits_per_operand_byte_refine`): pins cmp count growth vs
  K%32 baseline + sebvince diff init + sFull/sZero pairing +
  `d = LoopCounterL % 8` init.
* `TestAnyKEmit_K1::test_k1_no_narrow_load_and_emits_sebvince_chain`
  (replaces `test_k1_no_narrow_load_and_emits_partial_chain`): pins
  absence of narrow d16 load + sebvince init + boundary cndmask +
  per-(mmak, ir) sFull/sZero cndmasks; the chain sits in the
  precompute prefix (before per-mmak ds_read wait).
* `TestAnyKEmit_Precompute::test_precompute_block_before_per_mmak_loop`:
  rewritten to pin sebvince diff init + per-mmak sFull cmps in the
  precompute section; apply section has no chain primitives.
* `TestAnyKEmit_Precompute::test_precompute_hoisted_above_dtl_wait_and_barrier`:
  rewritten to pin sebvince `diff` + per-(mmak, ir) sFull/sZero
  markers above the DTL wait.

Asm excerpt from `subtile_bf16_anyk_odd.yaml` MT128x128 DU=64
(`Cijk_Alik_Bljk_BSS_BH ...
MT128x128x64`), the new precompute prefix (lines 1414-1453): ``` v_sub_i32 v12, s[sgprLoopCounterL], v11 // diff = LoopCounterL - kPosBase v_mov_b32 v13, 0x0000FFFF // halfKeep v_and_b32 v14, 7, s[sgprLoopCounterL] // d = LoopCounterL % 8 v_cmp_lt_i32 s[68:69], v14, 2 // boundary[0]: d < 2 ? v_cndmask_b32 v15, -1, v13, s[68:69] // boundary[0] = (d<2) ? halfKeep : full v_cmp_lt_i32 s[68:69], v14, 1 // boundary[0]: d < 1 ? v_cndmask_b32 v15, v15, 0, s[68:69] // boundary[0] = (d<1) ? 0 : prev ... boundary[1..3] (3 vgprs, same shape) ... v_cmp_gt_i32 s[68:69], v12, 7 // mmak=0 ir=0: sFull = diff > 7 v_cndmask_b32 v13, v15, -1, s[68:69] // mmak=0 ir=0 = sFull ? full : boundary[0] v_cmp_le_i32 s[68:69], v12, 0 // mmak=0 ir=0: sZero = diff <= 0 v_cndmask_b32 v13, v13, 0, s[68:69] // mmak=0 ir=0 = sZero ? 0 : prev ... mmak=0 ir=1..3 + mmak=1 ir=0..3 (8 per-(mmak, ir) chains) ... v_cmp_gt_i32 s[68:69], v12, 39 // mmak=1 ir=0: sFull = diff > 39 v_cmp_le_i32 s[68:69], v12, 32 // mmak=1 ir=0: sZero = diff <= 32 ... ``` The per-mmak apply step keeps the existing `apply precomputed mask to ValuA/B[idx]` v_and_b32 emit (one v_and per VGPR per operand, mask source-vgpr indexed into our per-(mmak, ir) precomputed pool). Unit suite: 880 passed / 5 skipped / 2 xfailed (baseline 877; +3 from Item 1 + Item 2 new pins). Yaml smokes (all PASS): - subtile_bf16 - subtile_bf16_tail - subtile_bf16_anyk_k2 - subtile_bf16_anyk_odd - subtile_bf16_anyk_k8 - subtile_bf16_anyk_largemt (MT320x288 byte-refine; +5 vgprs over baseline, well under 256-VGPR budget) The mxfp4 / mxfp4_tail / mxfp4_tail_smoke yamls go through the coarse cmp path (MX scales gate `_subtileTailByteShiftApplies` to False), so they are unchanged by this commit; verified by running them in-place and inspecting the post-tail emit. 
Idea ported from sebvince #7683 commits a0fee26 (`Precompute mask for all subIterK`), e9f5f55b (`Simplify emit_mask_k`), d999a288 (`Simplify emit_mask_k_init`), and 4a4d6a596 (`Remove hardcoded values`). Co-authored-by: Cursor <cursoragent@cursor.com>
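
For readers who want to check the mask algebra without reading asm, here is a minimal Python model of the init + per-(mmak, ir) chain described in this commit. It is a sketch under stated assumptions, not the emitter itself: the function names are hypothetical, the `lo = 2*i + 1` / `hi = 2*i + 2` thresholds generalise the `boundary[0]` case shown in the excerpt (the excerpt elides `boundary[1..3]`), each VGPR is assumed to hold two bf16 elements, and the lane's K base plus the slot base is assumed to be a multiple of 8.

```python
FULL = 0xFFFFFFFF
HALF_KEEP = 0x0000FFFF  # keep lo16 (even K element), clear hi16


def boundary_masks(loop_counter, vgprs_per_group=4):
    """Per-VGPR boundary masks derived from d = LoopCounterL % 8."""
    d = loop_counter & 7
    masks = []
    for i in range(vgprs_per_group):
        # Assumed generalisation of the i=0 thresholds (d < 2, d < 1).
        lo, hi = 2 * i + 1, 2 * i + 2
        m = HALF_KEEP if d < hi else FULL   # (d<hi) ? halfKeep : full
        m = 0 if d < lo else m              # (d<lo) ? 0 : prev
        masks.append(m)
    return masks


def chain_mask(diff, base, boundary):
    """sFull / sZero selection for one slot; base is 0 or 32 in the excerpt."""
    m = FULL if diff > base + 7 else boundary   # sFull ? full : boundary[i]
    return 0 if diff <= base else m             # sZero ? 0 : prev


def reference_mask(remaining, i):
    """Element-wise ground truth: keep a 16-bit half only if its K index is valid."""
    lo = 0x0000FFFF if 2 * i < remaining else 0
    hi = 0xFFFF0000 if 2 * i + 1 < remaining else 0
    return lo | hi
```

Under those assumptions the chain and the element-wise reference agree for every remaining-K count, which is the same property the hand-checked truth table in the commit message is about.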
Two preparatory changes for any-K (K%8 / K%2 / K%1) tail-loop work on subtile bf16 paths.

1. SolutionStructs gate (`Solution.py`). The MX-aware "force NonDTLTailLoop on bf16 sub-dword + DTL" guard in `_state` now also requires `not state["UseSubtileImpl"]`. The subtile path is structurally DTL-only and masks its own tail at sub-dword granularity (per-MFMA cndmask + the upcoming sub-lane byte refinement), so flipping it to NonDTL just disables the correctness machinery the kernel already has. The existing comment was extended to record that, on both A and B sides.

2. Tail-loop test fixture (`_subtile_tailloop_fixtures.py`). `setdefault_tail_scaffold_kernel_keys` now takes an `asem` keyword (default 32, preserves all existing aligned-K callers) and *overwrites* `AssertSummationElementMultiple` rather than using `setdefault`. `_create_kernel` already pre-bakes ASEM=32 into the dict, so a `setdefault` call by the new any-K tests would silently keep the old value and the K%8/K%2/K%1 cases would be impossible to express. The docstring was tightened to call this out.

Co-authored-by: Cursor <cursoragent@cursor.com>
(cherry picked from commit fba822a)
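
The fixture pitfall in item 2 is plain `dict.setdefault` behaviour and can be reproduced in a few lines. The helper names below are hypothetical stand-ins for the real fixture functions; only the `AssertSummationElementMultiple` key comes from the commit message.

```python
def make_kernel():
    # Stand-in for a fixture that pre-bakes ASEM=32 into the kernel dict.
    return {"AssertSummationElementMultiple": 32}


def scaffold_with_setdefault(kernel, asem=32):
    # setdefault only writes when the key is missing, so the pre-baked 32 wins.
    kernel.setdefault("AssertSummationElementMultiple", asem)
    return kernel


def scaffold_with_overwrite(kernel, asem=32):
    # Plain assignment always applies the caller's value.
    kernel["AssertSummationElementMultiple"] = asem
    return kernel
```

With `asem=2`, the `setdefault` variant still reports 32 while the overwrite variant reports 2; callers that rely on the default of 32 see no difference between the two.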
Locks in the new sub-lane byte-mask refinement and the SolutionStructs
gate change with three yaml regressions and two unit-test files
covering the K%8 / K%2 / K%1 cases on gfx950 subtile bf16:
- `subtile_bf16_anyk_k8.yaml` -- ASEM=8, K_remain a multiple of 8
but below `numMIInUnroll`. Exercises the coarse cndmask only.
- `subtile_bf16_anyk_k2.yaml` -- ASEM=2, even K_remain across the
boundary lane group. Exercises the per-VGPR step of the byte
refinement.
- `subtile_bf16_anyk_odd.yaml` -- ASEM=1, odd K_remain. Exercises
the runtime `LoopCounterL & 1` gate and the hi16 clear path.
- `test_solution_subtile_anyk.py` -- unit tests on Solution.py
`_state` to verify the new `UseSubtileImpl` clause does *not*
flip subtile bf16 to NonDTL even when the legacy aemA/aemB
sub-dword condition would have, while leaving non-subtile DTL
behaviour unchanged.
- `test_subtile_anyk_emit.py` -- emit-time tests on
`_emitTailByteShiftMaskSubtile` covering: gate off when
ASEM >= numMIInUnroll, gate off when MXBlockA/B set, even-K
path emits per-VGPR cndmask only, odd-K path additionally
emits the runtime `K_odd` branch + hi16 clear.
No production-code changes in this commit -- tests only.
Co-authored-by: Cursor <cursoragent@cursor.com>
(cherry picked from commit cc37383)
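
The gate that `test_solution_subtile_anyk.py` pins can be modelled roughly as below. This is a simplified sketch, not the code in `Solution.py`: `resolve_tail_gate` is a hypothetical name, the per-operand A/B flags are collapsed into one `NonDTLTailLoop` entry, and the gfx950 narrowing is applied first so that the gate sees the post-narrowing `UseSubtileImpl` value.

```python
def resolve_tail_gate(state, is_gfx950, aem, bpe):
    """Return a copy of `state` with the subtile narrowing and NonDTL gate applied."""
    out = dict(state)
    # Narrow first, so the gate below sees the post-narrowing value.
    out["UseSubtileImpl"] = bool(out.get("UseSubtileImpl", False)) and is_gfx950
    # Legacy condition: the asserted K multiple does not fill whole dwords.
    sub_dword = (aem * bpe) % 4 != 0
    # New clause: subtile kernels keep DTL on and mask their own tail.
    out["NonDTLTailLoop"] = sub_dword and not out["UseSubtileImpl"]
    return out
```

In this model a gfx950 subtile bf16 kernel at ASEM=1 keeps `NonDTLTailLoop` False, while the same request on a non-gfx950 target loses `UseSubtileImpl` and falls back to `NonDTLTailLoop=True`.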
Port the narrow MUBUF read bindings sebvince introduced in his tailloop_page branch (sha 4d99c9b) so the tail-loop scaffold can emit `buffer_load_ushort ... lds` for the trailing odd-K bf16/fp16 element.

The bindings are 80 LOC of straight-line C++ port:

rocisa/include/instruction/mem.hpp:
  struct BufferLoadB16 / BufferLoadU16, each MUBUFReadInstruction with the same 6-arg ctor signature as the sibling BufferLoadD16B16 / BufferLoadB32; instType set to INST_B16 / INST_U16 (both enums already exist in enum.hpp).

rocisa/src/instruction/mem.cpp:
  nanobind bindings exposing the same `dst, vaddr, saddr, soffset, mubuf, comment` keyword interface as the sibling MUBUF readers.

llvm-mc 20 confirms the gfx950 assembler accepts the BufferLoadU16+lds=True lowering (`buffer_load_ushort ... lds`) with either an sgpr soffset or a literal `0`. The previously rejected forms (`buffer_load_short_d16(_hi) ... lds`, removed in 7df7d24) stay rejected; this commit does not revive them. BufferLoadB16's gfx9 lowering (`buffer_load_short`) is also still rejected by gfx950, so the trailing-element consumer landing in a follow-up commit must use BufferLoadU16 specifically.

Co-authored-by: Cursor <cursoragent@cursor.com>
(cherry picked from commit 7ecc92b)
looks like something went wrong with the merge (e85adb5)
…GPR overflow

The May 26 develop merge (commit 7c0d7aa "[Hipblaslt] [Subtiling] Add non-uniform partition size to Logical Scheduler", PR #7558) replaced the Subtile mainLoop's free-list / fixed-point VGPR tile allocator with a deterministic double-buffer allocator (set = (mt_iter * num_k_groups + k // gran.k) % 2). The new allocator's set-2 fixed overhead, stacked on top of the K-tail scaffold's persistent state (per-mmak A/B vgprTile slice + per-(operand, mmak, ir) byte-mask precompute + mask init invariants), pushed the wave-64 VGPR footprint past 256 at codegen for the MT 320x288 WG=(4,1) WT=(5,18) DU=64 PGR=0 with-K-tail case that `subtile_bf16_anyk_largemt.yaml` was originally pinned for.

The PGR=0 candidate now fails KernelWriter with `Generating kernel source resulted in error 4` (= self.states.overflowedResources = 4 from KernelWriterAssembly, raised when vgprBudgetPerThread // numVgpr < 1). Because the yaml's only fork was `PrefetchGlobalRead: [0]` and the sibling `StreamK: [0, 3]` already gets `StreamK=0` rejected by the pre-existing `UseSubtileImpl=1 supports StreamK only` gate, the `StreamK=3` PGR=0 candidate was the only survivor of SolutionStructs; its KernelWriter failure dropped `newLibrary.solutions` to 0 and the yaml raised `RuntimeError: No valid solutions found` from `ClientWriter.writeBenchmarkFiles`.

Reproduction: pytest .../test_config.py -k subtile_bf16_anyk_largemt → exit=1, RuntimeError at ClientWriter.py:756. Reverting 7c0d7aa on top of e85adb5 (the merge tip) makes the test pass at exit=0.

`subtile_bf16_anyk_largemt_extended.yaml` covers the same MT 320x288 shape at ASEM=2 with PrefetchGlobalRead: [0, 1, 2] (section 4) and still passes because its PGR=1 and PGR=2 variants build (2 / 3 survive KernelWriter), so the fix is structural to the yaml's candidate enumeration, not the gate.

Fix: change `PrefetchGlobalRead: [0]` to `[0, 1]` in all three problem sections of `subtile_bf16_anyk_largemt.yaml`. The PGR=0 candidate is still enumerated and still hits the regressed codepath (so the `Failed to generate assembly source code` warning continues to surface in CI logs), but the PGR=1 candidate uses the `MFMASchedulerConfig.get_partition_candidates(tiA, tiB)` branch instead of the pgr==0 `candidates = [(M, N)]` branch in `Components/Subtile/Kernel.py:mainLoop`, fits the VGPR budget, and keeps `newLibrary.solutions` non-empty. The yaml docstring now documents the regression and the rationale.

This is the smallest possible workaround on this branch -- the proper fix is upstream in `Components/Subtile/LogicalScheduler.py` (have the new allocator share state with the tail scaffold, or restore the tighter packing from the pre-merge allocator). Leaving that to a follow-up so this PR's CI can clear.

No code changes; the Solution.py gate from c659ffd is unchanged and the 27 regression pins in test_solution_subtile_anyk.py still pass (84 total, full unit suite green). All five `subtile_bf16_anyk*` yamls now pass.

Co-authored-by: Cursor <cursoragent@cursor.com>
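
As a reading aid for the double-buffer formula quoted in this commit message, the set index can be written out as a small function. The function name and the sample values are illustrative assumptions; only the expression itself, `(mt_iter * num_k_groups + k // gran.k) % 2`, is taken from the message.

```python
def buffer_set(mt_iter, k, num_k_groups, gran_k):
    """Return 0 or 1: which of the two VGPR tile sets an (mt_iter, k) step uses."""
    # k // gran_k is the K-group index within the current macro-tile iteration.
    return (mt_iter * num_k_groups + k // gran_k) % 2


def buffer_schedule(num_mt_iters, depth_k, num_k_groups, gran_k):
    """List the set index for every (mt_iter, k-group) step in issue order."""
    return [
        buffer_set(mt, k, num_k_groups, gran_k)
        for mt in range(num_mt_iters)
        for k in range(0, depth_k, gran_k)
    ]
```

The index depends only on the step position, so both sets stay reserved for the whole loop regardless of what else is live, which is consistent with the fixed overhead the message describes.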
Review feedback on `SubtileTailSrdTighten.py` flagged that the MX
scale tightener's per-K-element byte stride was derived implicitly
from subtile internals (`lrSubtileSize * lrGlobalSubtileGrid[1] /
DepthU`) rather than the M/N-direction swizzle size that
`KernelWriterAssembly.computeLoadSrd()` uses to set Srd+2 in the
first place. The two derivations evaluate to the same number for
all gauntlet configs (mxBlock=32 -> 1 byte/K-element), but the
implicit form (a) hides the swizzle-block intent of the formula,
(b) silently couples the SRD-tighten code to two subtile-state
fields that aren't part of the SRD/byte contract, and (c) would
not generalize correctly to mxBlock=16 (which yields 2 bytes/K-
element by the explicit formula but might happen to give a
different value via the legacy fields). Reviewer also asked that
the same swizzle-size table apply across all three tighteners so
the non-swizzled bf16/fp16/int8 path is the swizzleSize0=1 special
case of one shared formula.
Changes
=======
* New module-level helper `_swizzleSize0ForMN(kernel, tc)` mirrors
the swizzle branch in `KernelWriterAssembly.computeLoadSrd()`:
| path | swizzleSize0 |
|---------------------------------------------------|--------------|
| MX scale + MXScaleFormat in | 32 |
| {HostPreSwizzle, InMemorySwizzle} | |
| A/B + SwizzleTensor{A,B}=True + UseSubtileImpl=1 | 16 |
| otherwise (plain bf16/fp16/int8 A/B, | 1 |
| non-swizzled MX scale, non-subtile) | |
The shared SRD-tighten formula is `delta_bytes = delta_K *
swizzleSize0 * bpe_K / blockGran`, where `blockGran` is `mxBlock`
for MX scale and 1 otherwise, and `bpe_K` is the per-K-element
data bpe (only the non-swizzled bf16/fp16/int8 path uses bpe_K >
1; swizzled paths absorb the K-direction bpe into
`swizzleBlockSize`). Each tightener's docstring carries a
swizzle-size cross-reference back to the helper so a future
reader sees the unified picture.
* `emitTailSrdTightenSubtileMX` now derives `bytesPerKElement_MX`
explicitly as `swizzleSize0 / mxBlock` per operand. The legacy
`lrSubtileSize * lrGlobalSubtileGrid[1] / DepthU` derivation is
retained as a debug-time `assert` cross-check (only fires when
subtile state is populated, which is always the case in real
kernels); divergence between the two would indicate either a
subtile-state bug or a `computeLoadSrd()` swizzle-branch drift,
both of which should be caught loudly rather than mis-clipping
Srd+2 silently. Gauntlet configs (mxfp4 / mxfp8, mxBlock=32)
produce the same bytesPerK=1 result; the emit shape is byte-for-
byte identical (no extra shift since bpe==1 skips it).
* `emitTailSrdTightenSubtile` (bf16/fp16/int8) and
`emitTailSrdTightenSubtileMXData` (MX data non-swizzled) gain
swizzle-size cross-reference notes in their docstrings calling
out which `swizzleSize0` case they correspond to (both 1, since
both gate out the swizzled-A/B path). No behavioural change in
these two; they already used the right per-K stride (bpe).
Test fixture: `_create_kernel` in
`test_subtile_tailloop_emit.py` now sets `UseSubtileImpl=True`
and `MXScaleFormat="HostPreSwizzle"` (for fp4) /
`"NoSwizzle"` (otherwise), matching what
`subtile_mxfp4*.yaml`'s `MXScaleFormat: 1` resolves to post-
Solution. Without this, the new helper would default the
MXScaleFormat to NoSwizzle (= swizzleSize0=1 for MXSA/MXSB),
which makes `swizzleSize0 % mxBlock = 1 % 32 = 1 != 0`, causing
the MX scale tightener to bail out and the pinned emit to
disappear -- a fixture-only artefact, not a real-config issue.
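
The swizzle table and the shared formula above can be sketched in Python as follows. This is an illustration rather than the helper that landed: the function names are hypothetical, `tc` is assumed to be one of "A", "B", "MXSA", "MXSB", and the kernel keys (`MXScaleFormat`, `SwizzleTensorA`/`SwizzleTensorB`, `UseSubtileImpl`) are taken from the names used in this message.

```python
def swizzle_size0(kernel, tc):
    """Mirror of the {32, 16, 1} table for one tensor character."""
    if tc in ("MXSA", "MXSB"):
        fmt = kernel.get("MXScaleFormat", "NoSwizzle")
        return 32 if fmt in ("HostPreSwizzle", "InMemorySwizzle") else 1
    if kernel.get("SwizzleTensor" + tc, False) and kernel.get("UseSubtileImpl", False):
        return 16
    return 1


def mx_bytes_per_k_element(swizzle0, mx_block):
    """swizzleSize0 / mxBlock, or None when the MX scale tightener would bail out."""
    if swizzle0 % mx_block != 0:
        return None
    return swizzle0 // mx_block


def delta_bytes(delta_k, swizzle0, bpe_k, block_gran):
    """delta_bytes = delta_K * swizzleSize0 * bpe_K / blockGran."""
    return delta_k * swizzle0 * bpe_k // block_gran
```

For a HostPreSwizzle MX scale with mxBlock=32 this gives 1 byte per K-element, and with mxBlock=16 it gives 2; a NoSwizzle scale with mxBlock=32 returns `None`, which corresponds to the fixture bail-out described above.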
New tests
=========
`TestSwizzleSize0ForMN` (8 cases): pins the {32, 16, 1} swizzle
table for every combination the helper covers --
HostPreSwizzle / InMemorySwizzle / NoSwizzle / Auto MX scale,
SwizzleTensor{A,B} with and without UseSubtileImpl, plain A/B,
asymmetric per-operand swizzle, and the default MXScaleFormat=
"NoSwizzle" fallback when the key is missing.
`TestTailSrdTightenMXBytesPerKElement::test_mx_tighten_no_shift_for_mxblock_32`:
pins that the gauntlet mxfp4 emit (mxBlock=32 -> bytesPerK=1) skips
the `delta_bytes = delta_K * bytesPerKElement_MX` shift instruction,
matching the pre-refactor behaviour for mxBlock=32 configs. A future
mxBlock=16 config (untested today) would emit the shift -- the
absence pin catches an accidental re-introduction on the path that
should stay shift-free.
Deferred
========
* `MX_PAD_K = 256` rename. The 256 is specific to the MX scale's
host pre-pad granularity (`rearrangePaddedMXScaleLayout` pads
to 8 mxBlocks = 256 K-elements); the subtile-swizzled A/B
path's K-direction swizzle granularity is 32, not 256, so a
generic `SWIZZLE_K_PAD` rename would be misleading. Reviewer
explicitly marked this as deferrable ("we can generalize it
later if we do not have time now"). Module docstring now
documents why the name stays MX-scoped over generic.
* Generalizing `emitTailSrdTightenSubtileMXData` to handle the
swizzleSize0=16 PreShuffled-MX path. That path is gated out
today and no gauntlet config exercises it; the deferred clip
would need its own delta_bytes formula and emit-shape pin to
guard against over/under-tightening. Docstring now flags the
deferral explicitly with the swizzleSize0=16 cross-reference.
Validation
==========
* `test_subtile_tailloop_emit.py`: 65 pre-existing pass + 9 new
pass = 74 total.
* Subtile unit suite (anyk_emit, solution_subtile_anyk,
tailloop_emit, solution_subtile_tailloop,
subtile_tail_per_mmak_lr): 170 passed.
* Yaml fan-out (`subtile_bf16_anyk*` + `subtile_mxfp4*` on
gfx950 MI355): 8 passed (5 bf16-anyk + 3 mxfp4 variants),
including the `largemt.yaml` PGR=1 fix from the previous
commit and the MX `tail.yaml` / `tail_smoke.yaml` which
exercise the actual `emitTailSrdTightenSubtileMX` path at
DepthU=512.
Co-authored-by: Cursor <cursoragent@cursor.com>
Summary

Extend the subtile-path bf16 tail loop in tensilelite from supporting only K values that are a multiple of 32 (`AssertSummationElementMultiple = 32`) to supporting any K for bf16: K_remain ∈ {0..31}, including odd K. MX datatypes (mxfp4 / mxfp8) intentionally retain the K%32 constraint and are left unchanged.

This PR stacks on `users/bnemanich/subtile-tailloop-k32-rebased` (the K%32 subtile-path PR). Please merge that one first.

What changed

Solution.py -- relax bf16 ASEM floor for subtile

- `Tensile/SolutionStructs/Solution.py`: the two NonDTL-tail-loop gates that previously forced `NonDTLTailLoop{A,B} = True` whenever `aem * bpe % 4 != 0` now skip that flip when the kernel is a gfx950 subtile kernel, so subtile bf16 keeps DTL on at low ASEM (1, 2, 4) instead of falling back to the legacy non-subtile path. The `UseSubtileImpl &= isgfx950` narrowing is hoisted above this block so the gate sees the post-narrowing value (a non-gfx950 kernel with explicit `UseSubtileImpl=True` now correctly takes the legacy path with `NonDTLTailLoop=True`).

KernelWriter.py -- sub-lane K-tail mask refinement

- `_subtileTailByteShiftApplies(kernel, numMIInUnroll)`: returns `True` only when the data path is a 2-byte-per-element subtile non-MX path with `ASEM < numMIInUnroll`. Statically `False` for ASEM ≥ 32, MX, fp4, int8, etc.
- `_emitTailSubLaneMaskRefineSubtile(...)`: invoked once per `mmak` slice, after the parent's coarse per-lane cndmask and before the MFMA grid for that slice. Two steps:
  1. Per-VGPR step: compare the VGPR's K position against `LoopCounterL` and zero the VGPR with `v_cndmask_b32` if past the boundary. This handles K_remain that is even but not a multiple of `numMIInUnroll`.
  2. Odd-K step: a runtime gate (`s_and ..., 1` on `LoopCounterL`); when K_remain is odd, the trailing element's hi 16 bits are cleared via `v_and ..., 0xffff` + a final `v_cndmask_b32 ... hi16` on the boundary VGPR.
- VGPR bookkeeping: a `dict[vIdx -> ir]` with a defensive assert that pins the contract that each VGPR maps to a single K-slot under the current tile layout.
- The `_emitTailLoopScaffoldSubtile` docstring and the call-site comment in `kernelBodySubtile` are updated to drop stale "K%32 tail loop" wording.

Removed code

- The `tailLoopGRNarrowBF16` global-read helper and its re-export are removed. A narrow trailing buffer-load was prototyped earlier but is not legal on gfx950 (`buffer_load_*_d16 ... lds`) and is not needed for correctness: `HasPartialOOB=True` clipping plus the in-VGPR hi16 clear cover the trailing element.

Tests

New unit tests (`Tensile/Tests/unit/`)

- `test_solution_subtile_anyk.py` -- pins the gate behavior in `Solution.py`:
  - `NoTailLoop` stays False on TN and NT layouts.
  - A non-gfx950 kernel with explicit `UseSubtileImpl=True` correctly falls back to legacy NonDTL with `NonDTLTailLoop=True` (regression pin for the gate-order fix).
- `test_subtile_anyk_emit.py` -- pins the emission shape:
  - Even-K path: `v_cndmask_b32 v<dst>, v<dst>, 0, s[...]` (zero is `src1`, mask is `src2`) at the expected K-position offsets across mmak slices.
  - Odd-K path: `SubtileTailByteHi16Skip` runtime gate, `v_and ..., 0xffff`, and a final `v_cndmask_b32 ... hi16`.
  - `_subtileTailByteShiftApplies` returns False for MX and ASEM ≥ numMIInUnroll (8-row parametrize covers bf16/fp16-True, mxfp8/mxfp4-False, asem-≥-numMIInUnroll-False, int8-False).

New regression yamls (`Tensile/Tests/common/gemm/gfx950/`)

- `subtile_bf16_anyk_k8.yaml` -- K_remain a multiple of 8 but not 32; exercises the existing per-lane cndmask path under relaxed ASEM.
- `subtile_bf16_anyk_k2.yaml` -- K_remain even but not a multiple of 8; exercises the per-VGPR refinement.
- `subtile_bf16_anyk_odd.yaml` -- odd K_remain; exercises both refinement steps.

Each yaml covers PGR ∈ {0, 1, 2}, StreamK ∈ {0, 3}, multiple wave groups (1×1, 2×2, 4×1), batch and asymmetric M/N.

Validation

- Unit tests (`Tensile/Tests/unit/`).
- `subtile_bf16.yaml` (K%32 regression): 6287 / 6287 PASS
- `subtile_bf16_tail.yaml` (K%32 tail regression): 450 / 450 PASS
- `subtile_bf16_anyk_k8.yaml`: 183 / 183 PASS
- `subtile_bf16_anyk_k2.yaml`: 183 / 183 PASS
- `subtile_bf16_anyk_odd.yaml`: 81 / 81 PASS
- `subtile_mxfp4.yaml` (MX regression): 2691 / 2691 PASS
- `subtile_mxfp4_tail.yaml` (MX tail regression): 50 / 50 PASS

Out of scope

- MX datatypes keep the `K_remain % 32 == 0` constraint. Extending any-K to MX would require a separate scale-padding strategy and is not part of this PR.
- … `mmak` slice on the tail iteration only.

Risk

- Where the refinement does not apply, the static gate `_subtileTailByteShiftApplies` prevents any prelude or extra instruction from being emitted.
- The defensive assert in `KernelWriter.py` makes this contract explicit; if a future change to `numReadsIterCoalesced` or tile mapping reuses VGPRs across K-slots, the assert will fire and surface the regression rather than silently miscompute.

Reviewer checklist

- `Tensile/SolutionStructs/Solution.py` gate ordering -- confirm hoisting `UseSubtileImpl &= isgfx950` does not affect other downstream branches.
- `_emitTailSubLaneMaskRefineSubtile` invocation point in `_emitTailLoopScaffoldSubtile` -- confirm placement is between the parent's coarse-cndmask dedup and the MFMA grid for that `mmak`.
- Confirm the `dict[vIdx -> ir]` with `assert seen[vIdx] == ir` is appropriate.
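
For reviewers, the static gate described in the summary can be approximated by the predicate below. It is a hedged sketch rather than the code in `KernelWriter.py`: the function name is a snake_case stand-in, `BytesPerElement` is a hypothetical key used here to represent the operand width, and the other keys (`UseSubtileImpl`, `MXBlockA`, `MXBlockB`, `AssertSummationElementMultiple`) are taken from names that appear in this PR.

```python
def subtile_tail_byte_shift_applies(kernel, num_mi_in_unroll):
    """True only for a 2-byte, non-MX subtile path with ASEM < numMIInUnroll."""
    if not kernel.get("UseSubtileImpl", False):
        return False
    # MX scales keep the K%32 constraint, so the refinement is gated off.
    if kernel.get("MXBlockA", 0) or kernel.get("MXBlockB", 0):
        return False
    # Hypothetical key: bf16/fp16 are 2 bytes per element, int8 is 1.
    if kernel.get("BytesPerElement") != 2:
        return False
    return kernel["AssertSummationElementMultiple"] < num_mi_in_unroll
```

Because the predicate depends only on kernel parameters, it can be evaluated at codegen time, which is what allows the aligned-K and MX paths to emit no extra instructions.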