
Add CLC scheduler heuristic #2455

Merged
drisspg merged 1 commit into main from drisspg/stack/33 on Apr 16, 2026

Conversation

drisspg (Collaborator) commented Apr 13, 2026

CLC Scheduler Heuristic

Human note: I had Pi summarize all the runs. We pair-wrote compile and run scripts around the benchmark helper from Inductor / transformer-nuggets that flushes L2 and returns sample statistics, then bootstrapped a paired 95% confidence interval around speedup_on_vs_off.
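The analysis step can be sketched as follows. `paired_speedup_ci` is a hypothetical helper illustrating the paired bootstrap, not the actual transformer-nuggets API; the inputs are per-sample latencies from paired CLC-off / CLC-on runs of the same case.

```python
import random
import statistics

def paired_speedup_ci(off_us, on_us, n_boot=2000, seed=0):
    """Bootstrap a paired 95% CI for speedup_on_vs_off (= off / on latency).

    Hypothetical helper sketching the analysis described above, not the
    actual transformer-nuggets benchmark API.
    """
    rng = random.Random(seed)
    n = len(off_us)
    boots = []
    for _ in range(n_boot):
        idx = [rng.randrange(n) for _ in range(n)]      # resample pairs together
        off_m = statistics.fmean(off_us[i] for i in idx)
        on_m = statistics.fmean(on_us[i] for i in idx)
        boots.append(off_m / on_m)
    boots.sort()
    lo = boots[int(0.025 * n_boot)]
    hi = boots[int(0.975 * n_boot) - 1]
    point = statistics.fmean(off_us) / statistics.fmean(on_us)
    return point, lo, hi, (lo > 1.0 or hi < 1.0)        # ci95_excludes_1x
```

The significance flag is simply whether the bootstrap interval excludes 1.0x in either direction.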

Shipped heuristic

Disable CLC (use STATIC scheduling) when either:

  1. varlen MHA — varlen with q_heads == kv_heads
  2. dense noncausal — non-varlen, non-causal, non-local-windowed

Otherwise, keep the CLC path available.
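In code, the shipped gate amounts to a predicate like the following; the argument names are illustrative, not the actual dispatch signature.

```python
def use_clc(is_varlen: bool, is_causal: bool, is_local_windowed: bool,
            q_heads: int, kv_heads: int) -> bool:
    """Shipped heuristic: fall back to STATIC scheduling where CLC is a
    measured net loss. Argument names are illustrative only."""
    if is_varlen and q_heads == kv_heads:
        return False                      # 1. varlen MHA
    if not is_varlen and not is_causal and not is_local_windowed:
        return False                      # 2. dense noncausal
    return True                           # otherwise keep the CLC path
```

For example, varlen GQA (`q_heads > kv_heads`) and dense causal both keep CLC available.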

Decision metric

All comparisons use:

  • speedup_on_vs_off
  • paired 95% confidence interval
  • significance flag ci95_excludes_1x

Value Meaning
speedup_on_vs_off > 1.0 CLC is faster
speedup_on_vs_off < 1.0 CLC is slower
ci95_excludes_1x = True difference is statistically stable

Aggregate summary by workload

Workload Cases Significant Wins Losses
Dense 750 693 324 369
Varlen 7200 5353 3674 1679
Block-sparse 1872 878 427 451

Dense

Aggregate result

Cases Significant Wins Losses Mean Δ%
750 693 324 369 +5.74

Dense is mixed overall, but the shipped heuristic explicitly gates the dense noncausal bucket.

Dense by causal

Setting Cases Significant Wins Losses Mean Δ%
Causal 375 357 257 100 +14.02
Noncausal 375 336 67 269 -2.53

Dense causal is a strong net win (+14.02%), while dense noncausal is a net regression (-2.53%). The mixed aggregate (+5.74%) is the average of a very positive causal story and a clearly negative noncausal one.

Dense by head mode

Mode Cases Significant Wins Losses Mean Δ%
mha 150 149 67 82 +3.75
gqa2 150 135 56 79 +4.32
gqa4 150 132 59 73 +5.43
gqa8 150 137 65 72 +6.54
mqa 150 140 77 63 +8.69

All dense head modes are net positive in mean Δ%, but all have significant losses. MQA is the strongest positive; MHA is weakest but still positive because dense causal MHA wins dominate.

Representative dense slowdowns with raw latency / TFLOPS

These rows were rerun in isolation in the nightly env using the same nuggets stats path as the main profile flow.

Case off us on us off TFLOPS on TFLOPS Δ%
mha_noncausal_q16384_k8192_h128 1341.76 1506.55 1638.9 1459.6 -10.94
mha_noncausal_q4096_k8192_h128 681.61 692.48 1613.1 1587.8 -1.57
mha_noncausal_q8192_k16384_h96 1239.74 1256.44 1330.3 1312.7 -1.33
mha_noncausal_h128_16k 2800.41 3058.16 1570.5 1438.1 -8.43
mha_noncausal_h128_8k 1349.75 1492.85 1629.2 1473.0 -9.59

Takeaway: the isolated nightly reruns still show real dense noncausal MHA regressions, but the exact magnitude varies by shape. The strongest quoted regressions remain in the long h128 cases.

Representative dense wins

These rows were also rerun in isolation in the nightly env using the nuggets stats path.

Case off us on us off TFLOPS on TFLOPS Δ%
mqa_causal_q16384_k1024_h64 84.93 46.13 809.2 1489.8 +84.11
mqa_causal_q16384_k1024_h96 91.12 52.72 1131.2 1955.1 +72.83
gqa8_causal_q16384_k1024_h64 84.97 49.93 808.7 1376.4 +70.19
mqa_causal_q8192_k1024_h64 89.07 52.18 771.5 1316.9 +70.69

Dense tree fit

To avoid learning a recompilation-dependent rule, this tree was refit without sequence-length features.

Features used:

  • is_mha
  • q_per_kv
  • d
  • is_causal

Strict label:

  • class: 1 = enable CLC
  • class: 0 = do not enable CLC from this strict label
  • enable iff ci95_excludes_1x and speedup_on_vs_off > 1.0

Depth-2 tree:

|--- is_causal <= 0.50
|   |--- d <= 80.00
|   |   |--- class: 0
|   |--- d >  80.00
|   |   |--- class: 0
|--- is_causal >  0.50
|   |--- q_per_kv <= 6.00
|   |   |--- class: 1
|   |--- q_per_kv >  6.00
|   |   |--- class: 1

Training accuracy:

  • depth 1: 75.33%
  • depth 2: 75.33%
  • depth 3: 78.13%

Takeaway: once sequence length is removed, the dense tree collapses to a coarse causal vs noncausal split. That supports the shipped dense-noncausal gate, but is still too coarse to justify anything broader in this PR.
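The tree fits in this document follow the standard sklearn recipe. The sketch below substitutes synthetic data for the real sweep table (whose labels come from the strict rule: enable iff ci95_excludes_1x and speedup_on_vs_off > 1.0), with a toy label that mirrors the causal/noncausal split, so only the fitting mechanics are real.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_text

# Synthetic stand-in for the dense sweep table; real labels come from
# the strict rule, not from a feature as they do here.
features = ["is_mha", "q_per_kv", "d", "is_causal"]
rng = np.random.default_rng(0)
X = np.column_stack([
    rng.integers(0, 2, 200),           # is_mha
    rng.choice([1, 2, 4, 8], 200),     # q_per_kv
    rng.choice([64, 96, 128], 200),    # d (head dim)
    rng.integers(0, 2, 200),           # is_causal
])
y = X[:, 3]  # toy label mirroring the dense causal/noncausal split

clf = DecisionTreeClassifier(max_depth=2).fit(X, y)
print(export_text(clf, feature_names=features))  # same format as the trees above
```

`export_text` produces the `|---` dumps quoted in this document; training accuracy is `clf.score(X, y)`.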

Varlen

Aggregate result

Cases Significant Wins Losses Mean Δ%
7200 5353 3674 1679 +2.58

Varlen is net positive overall, but the effect depends strongly on head mode, causal setting, and sequence pattern.

Varlen by head mode

Mode Cases Significant Wins Losses Mean Δ%
mha 1440 611 237 374 -0.20
gqa2 1440 1216 892 324 +3.70
gqa4 1440 1196 869 327 +3.33
gqa8 1440 1172 843 329 +2.97
mqa 1440 1158 833 325 +3.10

MHA is the only varlen head-mode bucket that trends negative. This is the primary evidence for the shipped heuristic.

Varlen MHA by causal

Setting Cases Significant Wins Losses Mean Δ%
Causal 720 282 107 175 -0.22
Noncausal 720 329 130 199 -0.17

Varlen MHA leans negative in both causal and noncausal settings. The heuristic (disable CLC for all varlen MHA) is correct regardless of causal flag.

Varlen by causal

Setting Cases Significant Wins Losses Mean Δ%
Noncausal 3600 2717 2148 569 +3.99
Causal 3600 2636 1526 1110 +1.17

Varlen by pattern

Pattern Significant Wins Losses Mean Δ%
uniform 823 694 129 +6.01
staircase 852 691 161 +4.35
longtail 866 652 214 +4.41
bimodal 889 676 213 +3.26
spiky 953 504 449 -0.76
loss_shape 970 457 513 -1.79

Representative varlen wins with raw latency / TFLOPS

These rows were rerun in isolation in the nightly env using the nuggets stats path.

Case off us on us off TFLOPS on TFLOPS Δ%
varlen_uniform_gqa4_causal_h128_b32_t32k_kv2x 346.06 312.62 1191.5 1318.9 +10.70
varlen_uniform_gqa4_noncausal_h128_b32_t32k_kv1x 258.22 211.36 1064.5 1300.5 +22.17
varlen_longtail_gqa8_noncausal_h128_b32_t32k_kv1x 318.87 272.45 1132.0 1324.9 +17.04
varlen_uniform_mqa_noncausal_h128_b32_t32k_kv1x 254.61 206.76 1079.6 1329.4 +23.14

Varlen tree fit

To avoid learning a recompilation-dependent rule, this tree was refit without sequence-length features.

Features used:

  • is_mha
  • q_per_kv
  • d
  • is_causal
  • is_uniform
  • is_loss_shape
  • is_spiky

Strict label:

  • class: 1 = enable CLC
  • class: 0 = do not enable CLC from this strict label
  • enable iff ci95_excludes_1x and speedup_on_vs_off > 1.0

Depth-2 tree:

|--- is_mha <= 0.50
|   |--- is_causal <= 0.50
|   |   |--- class: 1
|   |--- is_causal >  0.50
|   |   |--- class: 0
|--- is_mha >  0.50
|   |--- d <= 80.00
|   |   |--- class: 0
|   |--- d >  80.00
|   |   |--- class: 0

Training accuracy:

  • depth 1: 64.44%
  • depth 2: 65.03%
  • depth 3: 67.50%

Takeaway: once sequence length is removed, the tree cleanly recovers the main heuristic signal: varlen MHA falls on the negative side, while non-MHA varlen is the positive side, especially for noncausal workloads. This is the best match to the shipped heuristic.

Block-sparse

Aggregate result

Cases Significant Wins Losses Mean Δ%
1872 878 427 451 -0.02

Block-sparse is close to neutral overall, but structured by mask family, head mode, and sparsity statistics.

Block-sparse by mask

Mask Cases Significant Wins Losses Mean Δ%
block_causal 468 386 129 257 -1.80
block_diagonal 468 120 73 47 +0.18
sliding_window 936 372 225 147 +0.77

Block-sparse by head mode

Mode Cases Significant Wins Losses Mean Δ%
mha 624 300 196 104 +0.60
gqa4 624 281 122 159 -0.11
mqa 624 297 109 188 -0.56

Representative block-sparse wins with raw latency / TFLOPS

These rows were rerun in isolation in the nightly env using the block-sparse nuggets stats path.

Case off us on us off TFLOPS on TFLOPS Δ%
block_causal_gqa4_h128_q1024_k1024_b64_sq256_tm128_tn128_nt384 477.26 369.09 719.9 930.9 +29.31
block_causal_mha_h128_q1024_k1024_b64_sq256_tm128_tn128_nt384 499.56 372.56 687.8 922.3 +34.09
block_causal_mqa_h128_q1024_k1024_b64_sq256_tm128_tn128_nt384 475.59 368.04 722.5 933.6 +29.22
block_causal_mha_h128_q2048_k2048_b32_sq256_tm128_tn128_nt384 653.48 526.20 946.4 1175.4 +24.19

Representative block-sparse losses with raw latency / TFLOPS

These rows were rerun in isolation in the nightly env using the block-sparse nuggets stats path.

Case off us on us off TFLOPS on TFLOPS Δ%
block_causal_mqa_h64_q256_k1024_b64_sq256_tm128_tn128_nt384 110.63 129.66 621.2 530.0 -14.68
block_causal_mha_h64_q8192_k32768_b2_sq256_tm128_tn128_nt384 2205.07 2588.20 876.5 746.8 -14.80
block_causal_mqa_h64_q8192_k32768_b2_sq256_tm128_tn128_nt384 2211.35 2587.19 874.0 747.0 -14.53
block_causal_gqa4_h64_q8192_k32768_b2_sq256_tm128_tn128_nt384 2209.40 2590.48 874.8 746.1 -14.71

Block-sparse tree fit

To avoid learning a recompilation-dependent rule, this tree was refit without sequence-length features.

Features used:

  • is_block_causal
  • is_sliding_window
  • is_block_diagonal
  • is_mha
  • q_per_kv
  • d
  • is_causal
  • is_w128
  • is_w1024

Strict label:

  • class: 1 = enable CLC
  • class: 0 = do not enable CLC from this strict label
  • enable iff ci95_excludes_1x and speedup_on_vs_off > 1.0

Depth-2 tree:

|--- q_per_kv <= 2.50
|   |--- is_block_causal <= 0.50
|   |   |--- class: 0
|   |--- is_block_causal >  0.50
|   |   |--- class: 0
|--- q_per_kv >  2.50
|   |--- d <= 96.00
|   |   |--- class: 0
|   |--- d >  96.00
|   |   |--- class: 0

Training accuracy:

  • depth 1: 77.19%
  • depth 2: 77.19%
  • depth 3: 78.37%

What this means in the context of the masks we actually swept:

  • once sequence-length and raw sparsity-count features are removed, the block-sparse tree no longer finds a strong simple positive rule
  • that is consistent with the aggregate data: block-sparse is mixed and mask-dependent
  • block_causal is still the clearest negative mask overall (129 wins / 257 losses, mean -1.80%)
  • sliding_window remains the clearest positive mask overall (225 wins / 147 losses, mean +0.77%)
  • block_diagonal remains mild / mixed (73 wins / 47 losses, mean +0.18%)

Takeaway: block-sparse still looks too structured and mask-specific to reduce to one simple shipped heuristic.

The useful reviewer takeaway is not the exact mask_blocks threshold. It is:

  • block-sparse behavior differs materially by mask family
  • block_causal is where regressions cluster
  • sliding_window is where wins are more common
  • that is too structured and mask-specific to collapse into one global block-sparse heuristic in this PR

Net heuristic impact

With the shipped heuristic (CLC disabled for varlen MHA and dense noncausal):

Bucket Cases Significant Wins Losses Mean Δ%
CLC enabled 6135 5099 3694 1405 +3.93
CLC disabled (varlen MHA) 1440 611 237 374 -0.20
CLC disabled (dense noncausal) 375 336 67 269 -2.53

Geometric mean speedup across the 6135 CLC-enabled cases: 1.035x.
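The geometric mean is the appropriate average for multiplicative per-case speedups, since ratios compose by multiplication; a minimal sketch:

```python
import math

def geomean_speedup(speedups):
    """Geometric mean of per-case speedup ratios (e.g. the 1.035x above
    is the geomean over the CLC-enabled cases)."""
    return math.exp(sum(math.log(s) for s in speedups) / len(speedups))
```

Note that a 2x win and a 0.5x loss average to exactly 1.0x here, where an arithmetic mean would misleadingly report 1.25x.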

TFLOPS note

All TFLOPS values in this document are effective attention throughput (algorithmic_flops / wall_clock_time), not hardware tensor-core peak utilization. Small/fast cases with nontrivial FLOP counts can produce values above the GPU's rated peak — this reflects skipped work (causal masking, sparsity) or measurement overhead, not actual hardware utilization.
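A sketch of that convention, assuming the standard 4·q·k·h·d forward attention FLOP count (QK^T plus PV matmuls, 2 FLOPs per MAC) with causal masking halving the scored area; the helper name and signature are illustrative, not the exact benchmark script.

```python
def effective_tflops(q_len, kv_len, n_heads, head_dim, wall_us, causal=False):
    """Effective attention throughput: algorithmic_flops / wall_clock_time.

    Sketch of the convention used in this document, under the assumed
    4*q*k*h*d FLOP count; not hardware tensor-core utilization.
    """
    flops = 4.0 * q_len * kv_len * n_heads * head_dim
    if causal:
        flops /= 2.0                     # roughly half the score matrix is masked
    return flops / (wall_us * 1e-6) / 1e12
```

Because the numerator counts only algorithmic FLOPs, skipped work (masking, sparsity) or timing overhead on tiny cases can push the reported value past the hardware peak.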

Raw data and verification

CSV links for the full sweeps and isolated reruns

All major quoted wins and losses in this document were rerun in isolation in the nightly env to verify that the reported direction and approximate magnitude still hold outside the full sweep.

CSV gist bundle:

Conclusion

Action Rationale
Disable CLC for varlen MHA only varlen head-mode bucket that leans negative (237 wins / 374 losses, mean -0.20%); negative in both causal and noncausal
Disable CLC for dense noncausal 67 wins / 269 losses, mean -2.53%; clear regression signal
Keep CLC for varlen GQA/MQA strong net-positive behavior across grouped/shared-KV modes
Keep CLC for dense causal strongly positive: 257 wins / 100 losses, mean +14.02%
Do not add a global block-sparse disable block-sparse is structured and mixed; significant wins remain

HUMAN: I want to basically just autotune over this flag for block sparsity.

drisspg added a commit that referenced this pull request Apr 13, 2026
stack-info: PR: #2455, branch: drisspg/stack/33
drisspg (Collaborator, Author) commented Apr 13, 2026

Let me add the actual data for provenance: how I found this heuristic and the testing I did.

@drisspg marked this pull request as draft and then ready for review, April 15, 2026
drisspg added commits that referenced this pull request Apr 15, 2026
stack-info: PR: #2455, branch: drisspg/stack/33
Made-with: Cursor
tridao (Member) commented Apr 15, 2026

Are there scripts to reproduce the numbers, so that next time we improve CLC we can adjust the heuristic?

@drisspg marked this pull request as draft and then ready for review, April 16, 2026
drisspg added a commit that referenced this pull request Apr 16, 2026
stack-info: PR: #2455, branch: drisspg/stack/33
Made-with: Cursor
drisspg (Collaborator, Author) commented Apr 16, 2026

Added a script to run: python benchmarks/clc_bench.py --config benchmarks/configs/clc.yaml

@drisspg drisspg merged commit d7f60e6 into main Apr 16, 2026

Edenzzzz commented Apr 16, 2026

@drisspg It seems curious that simply switching to the CLC scheduler can bump TFLOPS from 1131 to 1955 for mqa_causal_q16384_k1024_h96, which is mostly compute-bound, while CLC is mostly about removing block-launch overheads / faster work stealing.
Is it possible to upstream the SM worktile visualizer script in #2218 (store to JSON instead of printf?) to see what's happening?


drisspg commented Apr 16, 2026

@Edenzzzz okay, good callout. I saw that measurement earlier and thought I had patched it. It was basically doing upper-left causal attention instead of lower-right. We do see a nice 1.6x speedup, but the TFLOPS are actually ~98.364.

Let me put up a PR. I can also add the visualizer script.

@Edenzzzz Edenzzzz mentioned this pull request Apr 17, 2026