[Cute,Sm100,Bwd] refine bwd swizzle for deterministic #2390
    num_n_blocks = (
        num_m_blocks
        * params.tile_shape_mn[0]
        * params.cluster_shape_m
Does this affect any of the 2cta bwd code?
This change is meant to get the right head swizzle heuristic for 2cta bwd by accounting for num_m_blocks being defined with respect to the tiler divided by the cluster shape.
num_n_blocks here is only used to derive nheads_in_l2, so it doesn't affect correctness.
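To make the scaling concrete, here is a minimal sketch of how num_n_blocks might be recovered from num_m_blocks. The function name, the `tile_n` argument, and the ceil-division are my assumptions for illustration; the diff excerpt above is truncated, so this is not the exact expression in the PR.

```python
def estimate_num_n_blocks(num_m_blocks: int, tile_m: int,
                          cluster_shape_m: int, tile_n: int) -> int:
    """Hypothetical helper: estimate the number of N blocks from the M block count.

    Assumes num_m_blocks counts cluster-sized tiles, so each one spans
    tile_m * cluster_shape_m rows. The ceil-division by tile_n is an
    assumption, not taken from the PR.
    """
    # Undo the cluster division to get back an approximate sequence length.
    seqlen_estimate = num_m_blocks * tile_m * cluster_shape_m
    # Round up so that a partial trailing block is still counted.
    return (seqlen_estimate + tile_n - 1) // tile_n


# Without the cluster factor, a 2-CTA cluster would see half as many N blocks.
print(estimate_num_n_blocks(4, 128, 1, 128))
print(estimate_num_n_blocks(4, 128, 2, 128))
```

Since the result only feeds the L2 heuristic, an off-by-a-factor estimate changes the swizzle choice but not the computed gradients.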
For the tile_shape_mn[0] do we pass the CTA's tile shape or the cluster tile shape? I think we had this discussion and realized we have not been consistent.
In any case if it doesn't affect correctness it's fine w me.
Currently, for the tile scheduler args, we pass cta_tiler[:2] as tile_shape_mn and cluster_shape_mn as a separate parameter. This makes the most sense to me, since cluster shape is in principle separate from the use of 2cta mma.
This PR fixes the varlen swizzle's calculation of num_n_blocks from num_m_blocks for 2cta, and includes dqaccum in the per-head size so that L2 is not spilled. This noticeably improves backward deterministic FLOPS (for non-varlen, the gain is limited to non-causal, but there appears to be no regression for causal; for varlen, perf improves across the board thanks to the fix).
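As a rough illustration of why counting dqaccum matters for the heuristic, here is a sketch of a "how many heads fit in L2" estimate. Everything in it is an assumption for illustration: the function name, the 50 MiB L2 budget, the choice of which tensors count toward one head, the byte widths, and the candidate head counts. It is not the scheduler code from this PR.

```python
def heads_fitting_in_l2(seqlen: int, headdim: int, headdim_v: int,
                        elem_bytes: int = 2, accum_bytes: int = 4,
                        l2_bytes: int = 50 * 1024 * 1024,
                        include_dq_accum: bool = True) -> int:
    """Hypothetical heuristic: largest head group whose working set fits in L2."""
    # Per-head footprint of the low-precision operands (assumed: K and V).
    size_one_head = seqlen * (headdim + headdim_v) * elem_bytes
    if include_dq_accum:
        # Assumed: the dQ accumulator is kept in a wider type (fp32 here).
        size_one_head += seqlen * headdim * accum_bytes
    # Try the largest group first; fall back to a single head.
    for nheads in (16, 8, 4, 2):
        if size_one_head * nheads <= l2_bytes:
            return nheads
    return 1


# Counting the accumulator shrinks the group that is assumed to stay L2-resident.
print(heads_fitting_in_l2(8192, 128, 128, include_dq_accum=False))
print(heads_fitting_in_l2(8192, 128, 128, include_dq_accum=True))
```

Under these assumed sizes, leaving the accumulator out overestimates how many heads can share L2, which is the kind of spill the description refers to.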
Example benchmarks:
non-varlen deterministic, MHA nheads = 16
varlen deterministic, MHA nheads = 16