Triton and Mosaic for linear_softmax_cross_entropy_loss#801
Open
captainpete wants to merge 21 commits into
Hi there!
This PR adds GPU backends for linear_softmax_cross_entropy_loss, which previously only ran on TPU. They keep to the original motivation of reducing memory footprint at the cost of some speed.
This was a fun one to work on over the last few weeks. I've taken a few passes at this and reached the point where further improvements are not obvious to me, so raising as a PR for your review.
Appreciate your attention on this, please let me know if there's anything to address.
I've added the write-up below, including the key sections of interest in case they answer questions raised.
Overview
XLA's implementation materialises the full (B, V) logit matrix. At LLM scale this is large. During training, this allocation sits alongside activations, weights, and optimiser state.
Both kernels in this PR use the tiled algorithm from Liger et al. (2024), which tiles over (b_tile, v_tile) pairs and keeps logits only in registers; peak logit memory drops from O(B*V) to O(b_block*v_block), a few KB regardless of vocab size.
The trade-off of this approach is speed. XLA's single cuBLAS GEMM is compute-bound and hard to match with a tiled kernel. The result is still much better than the PyTorch baseline from the paper.
These kernels are slower (see Performance) and should be used when the logit matrix is the binding memory constraint, not as a general replacement for XLA.
Also added a benchmark harness registered in benchmark_registry.pbtxt (H100, B200, TPU-v6e, TPU-v7) and updated the README.
Triton (pallas_triton_*)
SM80+ (Ampere and up). Selected automatically on GPU when Triton is available. Forward and backward; float32 accumulation throughout. ~2x XLA forward wall-clock time on LLM-scale shapes.
Mosaic GPU SM90 (pallas_mosaic_gpu_*)
H100+ (SM90). WGMMA + TMA pipelining; two warp groups per CTA. Not selected by default. Forward within ~5% of XLA; backward 4-8x slower (chunked cuBLAS scan over V).
Use explicitly when the logit matrix would OOM and the backward cost is acceptable.
Algorithm
Both kernels tile over (b_tile, v_tile) pairs and compute x[b_tile,:] @ w[:,v_tile] on-chip, accumulating a per-token logsumexp.
The correct-class logit is computed outside the kernel as a cheap O(B*H) XLA einsum (jnp.einsum("bh,hb->b", x, w[:, labels])), avoiding a gather inside the kernel (awkward with TMA).
The backward recomputes logit tiles on the fly rather than storing them (recompute-for-backward, as in FlashAttention).
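The tiling described above can be sketched in plain NumPy. This is an illustration of the online-logsumexp idea only, not the Pallas kernel from this PR: the function name tiled_cross_entropy and its default block sizes are assumptions made for the example, and the Python loops stand in for the kernel grid.

```python
import numpy as np


def tiled_cross_entropy(x, w, labels, b_block=128, v_block=128):
    """Mean CE loss of softmax(x @ w), visiting one logit tile at a time.

    Illustrative sketch: only a (b_block, v_block) tile of logits exists at
    any moment; the full (B, V) matrix is never formed.
    """
    B, _ = x.shape
    V = w.shape[1]
    lse = np.empty(B, dtype=np.float32)
    for b0 in range(0, B, b_block):
        xb = x[b0:b0 + b_block].astype(np.float32)
        m = np.full(xb.shape[0], -np.inf, dtype=np.float32)  # running max
        s = np.zeros(xb.shape[0], dtype=np.float32)  # running sum of exp(logit - m)
        for v0 in range(0, V, v_block):
            tile = xb @ w[:, v0:v0 + v_block].astype(np.float32)
            m_new = np.maximum(m, tile.max(axis=1))
            # Rescale the old sum to the new max, then add this tile's terms.
            s = s * np.exp(m - m_new) + np.exp(tile - m_new[:, None]).sum(axis=1)
            m = m_new
        lse[b0:b0 + b_block] = m + np.log(s)
    # Correct-class logit, using the same einsum as quoted in the text.
    correct = np.einsum(
        "bh,hb->b", x.astype(np.float32), w[:, labels].astype(np.float32)
    )
    return float(np.mean(lse - correct)), lse
```

Calling it with a vocabulary that is not a multiple of v_block works because the last slice is simply shorter; the real kernels handle that edge differently.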
Memory
XLA allocates the full (B, V) logit tensor in HBM (float32 for numerical stability), then reads it again for the logsumexp and CE loss reduction. Both kernels here eliminate this:
Forward: each (b_block, v_block) logit tile lives in registers for the duration of one kernel invocation. No HBM allocation for logits at any point.
The outputs written to HBM are (B, num_v_blocks), a per-token, per-v-chunk logsumexp and correct-logit contribution, O(B*V/v_block) rather than O(B*V).
Backward: logit tiles are recomputed from x and w on the fly, one chunk at a time, and discarded. The peak extra allocation during the backward is one logit chunk (B, chunk_size), which ends up being a few MB.
The residual saved from forward to backward is the per-token log-sum-exp lse, shape (B,).
For reference, these kernels avoid the (B, V) logit tensor. XLA computes in float32 regardless of input dtype (bfloat16 inputs are upcast before the GEMM), so the relevant figure is the float32 size.
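As a rough guide to the sizes involved, the float32 footprint of a dense (B, V) logit tensor is just B * V * 4 bytes. The snippet below works that out for V=128256 (the irregular vocab size mentioned in the implementation notes) and a few token counts; the B values and the helper name logit_bytes are hypothetical, chosen only for illustration, and are not the benchmark shapes.

```python
def logit_bytes(batch_tokens, vocab, bytes_per_elem=4):
    """Size in bytes of a dense (B, V) logit tensor (float32 by default)."""
    return batch_tokens * vocab * bytes_per_elem


if __name__ == "__main__":
    V = 128256  # vocab size taken from the notes; B values below are made up
    for B in (4096, 16384, 65536):
        gib = logit_bytes(B, V) / 2**30
        print(f"B={B:6d}  V={V}: {gib:6.2f} GiB of float32 logits")
```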
During benchmarking, XLA's forward for qwen3-8b hit RESOURCE_EXHAUSTED (48 MB allocation failure) at high memory pressure, where the tiled kernels succeeded.
Performance
Benchmarked on H100 (bfloat16 inputs, mean reduction).
Triton forward numbers below are from an RTX 3090 (same heuristic, so the same pattern is expected on H100, but I didn't have access to the hardware for long enough); H100 Triton numbers TBD.
Median wall-clock time (ms)
H100 numbers (XLA and mosaic_gpu); RTX 3090 numbers (Triton, where available):
Columns: XLA fwd, mosaic_gpu fwd, triton fwd, XLA fwd+vjp, mosaic_gpu fwd+vjp
RTX 3090 Triton forward results (H100 benchmarks pending):
Columns: XLA fwd (3090), triton fwd (3090)
Interpretation
Forward: mosaic_gpu is within ~5% of XLA across all shapes. triton forward runs at ~2x XLA wall-clock time. This is expected and close to the theoretical minimum for the tiling approach: Triton re-reads x once per v-chunk and w once per b-chunk, accumulating B*H*V/128 elements from each, while XLA's cuBLAS reads x and w once in a single compute-bound GEMM. The heuristic balances x/w HBM traffic (b_block = v_block = 128 when B is divisible by 128). Closing the gap further would require v_block > 128, which is blocked by the JAX 0.9.2 Triton compiler limitation (described in the implementation notes).
Backward: mosaic_gpu is 4-8x slower, scaling with ceil(V / 4096) (the number of sequential cuBLAS chunk iterations).
Total FLOP count is identical to XLA; the overhead is that XLA issues two full-width matmuls while the chunked scan runs 32-38 sequential chunk iterations.
For the shapes above on an H100 (80 GB), XLA fits comfortably.
On devices with smaller HBM (A100 40 GB, RTX 3090 24 GB) or at higher batch sizes the logit tensor becomes the constraint; see Memory.
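The forward HBM-traffic argument in the Interpretation can be checked with quick arithmetic. The sketch below counts elements read under the stated model (x re-read once per v-chunk, w re-read once per b-chunk) and compares that with an idealised single GEMM that reads each operand once. The shape (B, H, V) is a hypothetical LLM-scale example, not one of the benchmarked shapes, and the helper names are made up for this illustration.

```python
def forward_traffic_elems(B, H, V, b_block, v_block):
    """Elements read from HBM by a tiled forward: (x reads, w reads)."""
    num_v_chunks = -(-V // v_block)  # ceil(V / v_block)
    num_b_chunks = -(-B // b_block)  # ceil(B / b_block)
    x_reads = B * H * num_v_chunks  # x is re-read once per v-chunk
    w_reads = H * V * num_b_chunks  # w is re-read once per b-chunk
    return x_reads, w_reads


def single_gemm_elems(B, H, V):
    """Idealised single-pass GEMM: x and w are each read once."""
    return B * H + H * V


if __name__ == "__main__":
    B, H, V = 8192, 4096, 128256  # hypothetical shape, for illustration only
    ideal = single_gemm_elems(B, H, V)
    for b_block in (32, 64, 128):
        x_r, w_r = forward_traffic_elems(B, H, V, b_block, v_block=128)
        print(f"b_block={b_block:3d}: x={x_r:.3e} w={w_r:.3e} "
              f"total/ideal={(x_r + w_r) / ideal:.1f}x")
```

With b_block equal to v_block the two terms match whenever both block sizes divide B and V evenly, which is the balance point the heuristic aims for.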
Precision
In practice, LLM training uses bfloat16 inputs and mean reduction, the common case in the first column, where all backends agree to atol=2e-2.
The float32/sum column is the worst case.
The SM90 forward kernel down-casts float32 inputs to bf16 for WGMMA (hardware requirement), introducing quantisation noise of up to ~0.4 per gradient element for unit-variance inputs, uniform across gradient magnitudes.
The backward uses cuBLAS in float32 throughout, so the full tolerance budget comes from the forward's bf16 down-cast.
The initial results led me down a few rabbit holes, but I've confirmed that the bf16 down-cast is what causes the tolerance discrepancy in the sum-reduction case.
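To get a feel for the down-cast itself, bfloat16 rounding can be emulated in NumPy by rounding a float32 to its top 16 bits (round-to-nearest-even). bf16 keeps 7 explicit mantissa bits, so each rounded operand carries a relative error of at most 2^-8. The helper to_bf16 below is a hypothetical illustration, not code from this PR, and it ignores NaN/Inf edge cases.

```python
import numpy as np


def to_bf16(x):
    """Round float32 values to the nearest bfloat16, returned as float32.

    Emulation via bit manipulation; NaN/Inf handling is deliberately omitted.
    """
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    # Round to nearest, ties to even, on the 16 bits that get dropped.
    rounding = ((bits >> np.uint32(16)) & np.uint32(1)) + np.uint32(0x7FFF)
    rounded = (bits + rounding) & np.uint32(0xFFFF0000)
    return rounded.view(np.float32)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal((4, 4096)).astype(np.float32)
    w = rng.standard_normal((4096, 8)).astype(np.float32)
    exact = x @ w
    # bf16 operands, float32 accumulation (as with a float32 accumulator).
    approx = to_bf16(x) @ to_bf16(w)
    print("max abs logit error from bf16 operands:", np.abs(exact - approx).max())
```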
Implementation notes
Triton backend
Matmul accumulates in float32 throughout (Triton handles this natively with a jnp.float32 dot).
This gives good numerical accuracy; gradients match the XLA reference at atol=2e-2.
The backward fuses the gradient scale (dout / B for mean, dout for sum) into the kernel rather than applying it post-hoc, saving one pass over the output tensors.
Tiling heuristic
HBM traffic for the forward pass scales as:
x traffic: B * H * V / v_block (x is re-read once per v-chunk tile)
w traffic: B * H * V / b_block (w is re-read once per b-chunk tile)
Traffic is balanced when b_block = v_block.
At v_block=128 (the maximum safe value), the heuristic targets b_block=128 when B is divisible by 128, which equalises x/w HBM reads and measurably improves performance (~4% on LLM-scale shapes).
Register budget on SM80+ (65536 regs/SM, num_warps=4, 128 threads/CTA):
With b=128, h is capped at 64 to stay within the 50% budget (2 CTAs/SM).
With b <= 64, h=128 is used when H is divisible by 128 for better tensor-core tile efficiency; h_block does not affect HBM traffic.
v_block_size cap at 128
v_block_size=256 crashes the Triton-to-PTX compilation stage in JAX 0.9.2's bundled Triton with a C++ exception (segfault in f.compile()).
JAX's pallas/triton/lowering.py documents this: the power-of-2 tensor-size check (lines 288-301) applies only to load/store ops, and the file explicitly notes that for other ops "the Triton lowering will fail anyway but it will crash with a C++ exception".
With a (32, 256) accumulator tile, the load/store check passes (8192 = 2^13) but the Triton backend then crashes during instruction selection for tl.dot.
I didn't find an upstream issue for this specific case (float32 tl.dot with N=256 on SM80 in JAX's bundled Triton).
The closest related fix is jax-ml/jax#35654, which added an early guard for the same crash pattern in the fp64 MMA path; the fp32/N=256 case is not yet guarded.
The heuristic caps v_block_size at 128 and could be revisited when JAX upgrades the bundled Triton.
Mosaic GPU SM90 backend
Uses plgpu.emit_pipeline_warp_specialized with two warp groups per CTA.
One warp group handles rows [0, tile_m), the other [tile_m, 2*tile_m).
The pipeline loads x and w tiles into SMEM via TMA and issues WGMMA.
Float32 inputs are downcast to bf16 before entering the kernel: SM90 WGMMA has no float32 operand mode, so the kernel feeds it bf16. The accumulator remains float32.
Forward
H100 provides 227 KB shared memory per SM.
The forward kernel at 4 stages and tile_n=128, tile_k=64 uses ~129 KB.
Configs at tile_n=256 or tile_k=128 are reachable by the forward autotuner; the backward is unaffected (it runs in XLA, not inside the SM90 kernel).
The autotuning config generator (get_autotuning_configs) does not currently filter configs by SMEM budget.
Backward
The backward does not use the SM90 WGMMA kernel.
Instead it uses a jax.lax.scan over padded vocabulary chunks, issuing one pair of cuBLAS GEMMs per chunk.
The last chunk is zero-padded so that the padded vocabulary is a multiple of chunk_size (4096) for any vocab size (including irregular sizes like V=128256).
Padded positions are masked by valid = (col_idx < v_dim) and contribute nothing.
This avoids the atomic_add serialisation of a naive in-kernel backward, which ended up adding far too much latency.
Total FLOP count matches XLA; overhead is 32-38 sequential cuBLAS launches vs XLA's 2 full-width matmuls.
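The chunked backward described above can be sketched in NumPy, with a Python loop standing in for jax.lax.scan. This is an illustration under assumptions, not the PR's implementation: the function name chunked_backward and its signature are invented for the example, and it assumes the standard softmax cross-entropy gradient (softmax minus one-hot) recomputed per chunk from the saved lse.

```python
import numpy as np


def chunked_backward(x, w, labels, lse, dout=1.0, chunk_size=4096,
                     reduction="mean"):
    """Gradients (dx, dw) of the CE loss, one vocab chunk at a time.

    Illustrative sketch: logits are recomputed per chunk and discarded, so
    the largest temporary is (B, chunk_size) rather than (B, V).
    """
    B, _ = x.shape
    v_dim = w.shape[1]
    scale = dout / B if reduction == "mean" else dout
    num_chunks = -(-v_dim // chunk_size)  # ceil(V / chunk_size)
    pad = num_chunks * chunk_size - v_dim
    w_pad = np.pad(w, ((0, 0), (0, pad)))  # zero-pad the last chunk
    dx = np.zeros_like(x)
    dw_pad = np.zeros_like(w_pad)
    for c in range(num_chunks):  # stands in for the scan body
        lo = c * chunk_size
        col_idx = lo + np.arange(chunk_size)
        valid = col_idx < v_dim  # mask out padded columns
        w_c = w_pad[:, lo:lo + chunk_size]
        logits = x @ w_c  # recomputed (B, chunk_size) tile
        probs = np.where(valid[None, :], np.exp(logits - lse[:, None]), 0.0)
        onehot = labels[:, None] == col_idx[None, :]
        g = scale * (probs - onehot)  # dL/dlogits for this chunk
        dx += g @ w_c.T  # first GEMM of the pair
        dw_pad[:, lo:lo + chunk_size] = x.T @ g  # second GEMM of the pair
    return dx, dw_pad[:, :v_dim]
```

Padded columns have zero probability and never match a label, so their slice of dw_pad is zero and is dropped by the final slice.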
Files
pallas_triton_kernel.py
pallas_triton_config.py
pallas_triton.py
pallas_triton_kernel_test.py
pallas_triton_test.py
pallas_mosaic_gpu_kernel_sm90.py
pallas_mosaic_gpu_common.py
pallas_mosaic_gpu.py
pallas_mosaic_gpu_kernel_sm90_test.py
pallas_mosaic_gpu_test.py
api.py
benchmarks/linear_softmax_cross_entropy_loss.py
Future work
supported_on permits SM100 for the Mosaic backend (same SM90 kernels), but I haven't tested it.
get_autotuning_configs does not filter configs by SMEM budget. A follow-up could add a smem_bytes check there.
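A minimal sketch of what such an smem_bytes check might look like follows. Everything here is an assumption for illustration: the Config dataclass is a hypothetical stand-in for the real autotuning config type, tile_m=64 is a guessed value, and the estimate only counts the pipelined x and w tiles (bf16, one x tile per warp group plus one w tile per stage). Under those assumptions the 4-stage, tile_n=128, tile_k=64 configuration comes out at 128 KiB, close to the ~129 KB quoted above, but the real kernel's SMEM layout may differ.

```python
from dataclasses import dataclass

H100_SMEM_BYTES = 227 * 1024  # per-SM figure quoted in the Forward notes


@dataclass(frozen=True)
class Config:  # hypothetical stand-in for an autotuning config
    tile_m: int
    tile_n: int
    tile_k: int
    num_stages: int


def smem_bytes(cfg, elem_bytes=2, warp_groups=2):
    """Estimated pipeline SMEM in bytes (assumed layout, bf16 operands)."""
    x_tile = warp_groups * cfg.tile_m * cfg.tile_k * elem_bytes
    w_tile = cfg.tile_k * cfg.tile_n * elem_bytes
    return cfg.num_stages * (x_tile + w_tile)


def filter_by_smem(configs, budget=H100_SMEM_BYTES):
    """Keep only configs whose estimated SMEM fits within the budget."""
    return [c for c in configs if smem_bytes(c) <= budget]


if __name__ == "__main__":
    candidates = [
        Config(tile_m=64, tile_n=128, tile_k=64, num_stages=4),
        Config(tile_m=64, tile_n=256, tile_k=128, num_stages=4),
    ]
    for cfg in candidates:
        print(cfg, f"{smem_bytes(cfg) / 1024:.0f} KiB")
    print("kept:", filter_by_smem(candidates))
```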