[Cute,Sm100,Bwd] refine bwd swizzle for deterministic#2390

Merged
jayhshah merged 1 commit into main from jshah/bwd-det-swizzle
Mar 25, 2026

@jayhshah
Collaborator

This PR fixes how the varlen swizzle derives num_n_blocks from num_m_blocks for 2cta, and includes dqaccum in the per-head size used to avoid spilling L2. This noticeably improves deterministic backward FLOPs. For non-varlen, the gain is limited to non-causal, and causal does not appear to regress; for varlen, perf improves across the board thanks to the fix.
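To make the L2 part of the description concrete, here is a minimal sketch of a per-head footprint estimate. It is not the scheduler's code: the function name `heads_fitting_in_l2`, the choice of counting Q and dO as the base footprint, the element sizes, and the 64 MiB L2 budget are all assumptions for illustration. The only point taken from the PR is that adding dqaccum to the per-head size lowers the number of heads the heuristic treats as L2-resident.

```python
def heads_fitting_in_l2(seqlen, headdim, l2_bytes,
                        elem_bytes=2, dqaccum_bytes=4,
                        include_dqaccum=True):
    """Estimate how many heads fit in L2 (hypothetical model, not the real heuristic)."""
    # ASSUMPTION: the base footprint is two 16-bit tensors (e.g. Q and dO) per head.
    bytes_per_head = 2 * seqlen * headdim * elem_bytes
    if include_dqaccum:
        # ASSUMPTION: dqaccum is an fp32 buffer with the same seqlen x headdim shape.
        bytes_per_head += seqlen * headdim * dqaccum_bytes
    # Always schedule at least one head, even if a single head overflows the budget.
    return max(1, l2_bytes // bytes_per_head)


L2_BUDGET = 64 * 2**20  # hypothetical 64 MiB budget

without_dq = heads_fitting_in_l2(8192, 128, L2_BUDGET, include_dqaccum=False)
with_dq = heads_fitting_in_l2(8192, 128, L2_BUDGET, include_dqaccum=True)
print(without_dq, with_dq)
```

Under these assumed sizes, counting dqaccum doubles the per-head footprint, so the swizzle would group half as many heads together.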

Example benchmarks:

non-varlen deterministic, MHA nheads = 16

hdim  causal  batch  seqlen  PR                 MAIN
----------------------------------------------------------------
128   False   4      8192    4.58/1200/53.3%    4.75/1157/51.4%
128   False   2      16384   9.32/1180/52.4%    10.08/1091/48.5%
128   False   1      32768   19.69/1117/49.6%   20.48/1074/47.7%

128   True    4      8192    2.50/1100/48.9%    2.50/1100/48.9%
128   True    2      16384   4.60/1195/53.1%    4.60/1195/53.1%
128   True    1      32768   9.74/1129/50.2%    9.86/1115/49.6%

varlen deterministic, MHA nheads = 16

hdim  causal  batch  seqlen  PR                 MAIN
----------------------------------------------------------------
128   False   4      8192    4.66/1180/52.5%    5.13/1071/47.6%
128   False   2      16384   8.97/1225/54.5%    10.15/1083/48.1%
128   False   1      32768   19.45/1131/50.3%   20.85/1054/46.9%

128   True    4      8192    2.59/1061/47.1%    2.83/973/43.2%
128   True    2      16384   4.72/1166/51.8%    5.06/1086/48.3%
128   True    1      32768   9.72/1131/50.3%    10.09/1089/48.4%
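The tables can be compared row by row with a small script. The sketch below copies the PR and MAIN cells verbatim and reports the ratio of their middle fields. The PR does not label the three slash-separated fields, so reading the middle one as a throughput figure is an assumption; the helper names are made up for this example.

```python
# (mode, causal, seqlen, PR cell, MAIN cell), copied from the two tables above.
ROWS = [
    ("non-varlen", False, 8192, "4.58/1200/53.3%", "4.75/1157/51.4%"),
    ("non-varlen", False, 16384, "9.32/1180/52.4%", "10.08/1091/48.5%"),
    ("non-varlen", False, 32768, "19.69/1117/49.6%", "20.48/1074/47.7%"),
    ("non-varlen", True, 8192, "2.50/1100/48.9%", "2.50/1100/48.9%"),
    ("non-varlen", True, 16384, "4.60/1195/53.1%", "4.60/1195/53.1%"),
    ("non-varlen", True, 32768, "9.74/1129/50.2%", "9.86/1115/49.6%"),
    ("varlen", False, 8192, "4.66/1180/52.5%", "5.13/1071/47.6%"),
    ("varlen", False, 16384, "8.97/1225/54.5%", "10.15/1083/48.1%"),
    ("varlen", False, 32768, "19.45/1131/50.3%", "20.85/1054/46.9%"),
    ("varlen", True, 8192, "2.59/1061/47.1%", "2.83/973/43.2%"),
    ("varlen", True, 16384, "4.72/1166/51.8%", "5.06/1086/48.3%"),
    ("varlen", True, 32768, "9.72/1131/50.3%", "10.09/1089/48.4%"),
]


def middle_field(cell):
    """Return the second slash-separated field of a table cell as a float."""
    # ASSUMPTION: the middle field is the throughput number.
    return float(cell.split("/")[1])


def pr_over_main(pr_cell, main_cell):
    """Ratio of the PR middle field to the MAIN middle field."""
    return middle_field(pr_cell) / middle_field(main_cell)


for mode, causal, seqlen, pr, main in ROWS:
    ratio = pr_over_main(pr, main)
    print(f"{mode:10s} causal={causal!s:5s} seqlen={seqlen:5d} PR/MAIN={ratio:.3f}")
```

A ratio above 1.0 means the PR cell has the larger middle field; the two non-varlen causal rows with identical cells give exactly 1.0.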

num_n_blocks = (
num_m_blocks
* params.tile_shape_mn[0]
* params.cluster_shape_m
Member

does this affect any of the 2cta bwd code?

Collaborator Author

@jayhshah jayhshah Mar 25, 2026

This change is meant to get the right head swizzle heuristic for 2cta bwd by accounting for num_m_blocks being defined with respect to the tiler divided by the cluster shape.

num_n_blocks here is only used to derive nheads_in_l2, so it doesn't affect correctness.
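As a rough illustration of the reply above, the sketch below scales num_m_blocks back up by the cluster shape before converting to n-blocks. The quoted diff is cut off after `params.cluster_shape_m`, so the final division by `tile_shape_mn[1]` is an assumption, as are the function names and the example sizes.

```python
def ceil_div(a, b):
    """Integer ceiling division for positive b."""
    return -(-a // b)


def num_n_blocks_for_swizzle(num_m_blocks, tile_shape_mn, cluster_shape_m):
    """Estimate num_n_blocks from num_m_blocks (hypothetical helper)."""
    # Matches the visible part of the quoted diff: undo the cluster-shape
    # division baked into num_m_blocks to recover the extent along M.
    m_extent = num_m_blocks * tile_shape_mn[0] * cluster_shape_m
    # ASSUMPTION: the truncated remainder of the expression converts this
    # extent into blocks of width tile_shape_mn[1].
    return ceil_div(m_extent, tile_shape_mn[1])


tile_shape_mn = (128, 128)   # hypothetical CTA tile
seqlen = 8192
cluster_shape_m = 2

# num_m_blocks as described in the reply: counted per cluster-sized tile.
num_m_blocks = ceil_div(seqlen, tile_shape_mn[0] * cluster_shape_m)

print(num_n_blocks_for_swizzle(num_m_blocks, tile_shape_mn, cluster_shape_m))
print(num_n_blocks_for_swizzle(num_m_blocks, tile_shape_mn, 1))  # cluster factor omitted
```

With these assumed sizes, omitting the cluster factor halves the estimate. Since the value only feeds nheads_in_l2, that would skew the swizzle heuristic rather than the results.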

Member

For tile_shape_mn[0], do we pass the CTA's tile shape or the cluster tile shape? I think we had this discussion and realized we have not been consistent.
In any case, if it doesn't affect correctness it's fine with me.

Collaborator Author

Currently, for the tile scheduler args, we pass cta_tiler[:2] as tile_shape_mn and cluster_shape_mn as a separate parameter. This makes the most sense to me, since cluster shape is in principle separate from the use of 2cta mma.

@jayhshah jayhshah merged commit 5c7711e into main Mar 25, 2026
2 of 3 checks passed
@jayhshah jayhshah deleted the jshah/bwd-det-swizzle branch March 25, 2026 17:45