Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
9fdf5c4
Add Pallas/Triton forward kernel for linear softmax cross-entropy loss
captainpete Mar 23, 2026
7d70a11
Add Pallas/Triton backward kernel for linear softmax cross-entropy loss
captainpete Mar 23, 2026
ea65851
Add Pallas/Triton Op wiring for linear softmax cross-entropy loss
captainpete Mar 23, 2026
8763377
Add GPU benchmark harness and update README for linear_softmax_cross_…
captainpete Mar 23, 2026
843b698
Rewrite backward to O(3BVH) single kernel with zero-init aliasing
captainpete Mar 24, 2026
068ad48
Fuse dout scaling into backward kernel
captainpete Mar 24, 2026
7c7c4e8
Doc: add triton to api.py docstring, retire stale backend='triton' go…
captainpete Mar 24, 2026
de4cba8
Add Pallas/Mosaic-GPU SM90 Op for linear softmax cross-entropy loss
captainpete Mar 25, 2026
97dc4e3
Register LCE loss benchmark in CI; add mosaic_gpu to benchmark implem…
captainpete Mar 25, 2026
17f409b
Fix benchmark EXAMPLES: rename 'w' key to 'weights' to match public API
captainpete Mar 25, 2026
eca595f
Switch mosaic_gpu and triton backward to padded-chunk cuBLAS scan
captainpete Mar 25, 2026
d04a00d
Clarify Triton exclusion: forward segfault remains, backward is resolved
captainpete Mar 25, 2026
c1ae60b
Fix stale docstrings and PR.md backend selection order
captainpete Mar 25, 2026
433f9a2
Remove dead backward kernels (Triton atomic_add bwd, SM90 WGMMA bwd)
captainpete Mar 25, 2026
74d0125
Remove mosaic_gpu from default backend chain; make it explicit opt-in
captainpete Mar 25, 2026
ad2e1b0
Switch Triton to heuristics config; drop autotuning
captainpete Mar 25, 2026
4bd67b1
Improve Triton heuristics config; clean up PR.md
captainpete Mar 25, 2026
c88af86
Triton: v-padding, heuristic overhaul, memory story in PR.md
captainpete Mar 27, 2026
90f811c
Another pass on the PR.md
captainpete Mar 27, 2026
9b7f2c8
Remove PR doc
captainpete Mar 27, 2026
6e20df0
Remove uv.lock
captainpete Mar 27, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,11 @@ We currently support the following GPU kernels:

And the following for both GPU and TPU:

* `tokamax.linear_softmax_cross_entropy_loss`
([Memory Efficient Linear Cross Entropy Loss Kernel](https://arxiv.org/abs/2410.10989v2)).
* `tokamax.ragged_dot`
([Mixture of Experts](https://arxiv.org/abs/2211.15841)).

And the following TPU kernels:

* `tokamax.linear_softmax_cross_entropy_loss`
([Memory Efficient Linear Cross Entropy Loss Kernel](https://arxiv.org/abs/2410.10989v2))

## Installation

The latest Tokamax [PyPI release](https://pypi.org/project/tokamax/):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ requires-python = ">=3.11"
dependencies = [
"absl-py>=2.3.0",
"einshape",
"jax>=0.9.2",
"jax[cuda12]>=0.9.2",
"jaxlib>=0.9.2",
"jaxtyping>=0.3",
"pydantic>=2.11.0",
Expand Down
64 changes: 44 additions & 20 deletions tokamax/_src/ops/linear_softmax_cross_entropy_loss/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,22 @@
from tokamax._src.ops.linear_softmax_cross_entropy_loss import base


Implementation: TypeAlias = Literal["mosaic_tpu", "xla"]
Implementation: TypeAlias = Literal["mosaic_gpu", "mosaic_tpu", "triton", "xla"]

IMPLEMENTATIONS = dict(xla=base.LinearSoftmaxCrossEntropyLoss())
_DEFAULT_IMPLEMENTATION = ("xla",)

try:
from tokamax._src.ops.linear_softmax_cross_entropy_loss import pallas_triton # pylint: disable=g-import-not-at-top # pytype: disable=import-error

IMPLEMENTATIONS["triton"] = (
pallas_triton.PallasTritonLinearSoftmaxCrossEntropyLoss()
)

_DEFAULT_IMPLEMENTATION = ("triton",) + _DEFAULT_IMPLEMENTATION
except ImportError:
pass

try:
from tokamax._src.ops.linear_softmax_cross_entropy_loss import pallas_mosaic_tpu # pylint: disable=g-import-not-at-top # pytype: disable=import-error

Expand All @@ -38,6 +49,21 @@
except ImportError:
pass

try:
from tokamax._src.ops.linear_softmax_cross_entropy_loss import pallas_mosaic_gpu # pylint: disable=g-import-not-at-top # pytype: disable=import-error

IMPLEMENTATIONS["mosaic_gpu"] = (
pallas_mosaic_gpu.PallasMosaicGpuLinearSoftmaxCrossEntropyLoss()
)

# mosaic_gpu is NOT added to _DEFAULT_IMPLEMENTATION. Its forward is at XLA
# parity but its backward is ~3× slower (chunked scan over V vs two full-width
# cuBLAS matmuls). The benefit is memory: the (B, V) logit matrix is never
# materialised. Use implementation='mosaic_gpu' explicitly when the logit
# matrix would OOM the device.
except ImportError:
pass


def linear_softmax_cross_entropy_loss(
x: Real[Array, "B H"],
Expand Down Expand Up @@ -72,10 +98,15 @@ def linear_softmax_cross_entropy_loss(
precision: The precision used for jax.lax.dot_general for the linear
projection and gradient calculation.
implementation: By default "None" will be used to pick the best available
backend. Can be set to "xla" or "mosaic_tpu" explicitly. The "mosaic_tpu"
implementation is memory efficient and has almost 0 additional buffer
overhead while the "xla" implementation needs to materialize the full
logits
backend. Can be set to "xla", "mosaic_tpu", "triton", or "mosaic_gpu"
explicitly. The default selection order is mosaic_tpu → triton → xla,
with each backend skipped if unavailable on the current device.
"mosaic_gpu" is available on H100+ (SM90) but is not in the default
chain: its forward is at XLA parity but its backward is ~3× slower due
to chunked-scan accumulation. Use implementation='mosaic_gpu' explicitly
when the (B, V) logit matrix would OOM the device — that is the intended
use case. "mosaic_tpu" and "triton" are memory-efficient and avoid
materialising the full logit matrix.

Returns:
The Cross-Entropy loss
Expand All @@ -91,21 +122,16 @@ def linear_softmax_cross_entropy_loss(
"Customization of precision is currently not supported."
)

if implementation is not None:
if implementation in IMPLEMENTATIONS:
loss = IMPLEMENTATIONS[implementation](
x,
labels,
weights,
reduction=reduction,
)
return loss
else:
raise ValueError(f"Unsupported implementation: {implementation}")
if implementation is None:
implementation = _DEFAULT_IMPLEMENTATION

if not isinstance(implementation, (tuple, list)):
implementation = (implementation,)

# Find out the best impelmentation based on the hardware.
errors = []
for impl in IMPLEMENTATIONS:
for impl in implementation:
if impl not in IMPLEMENTATIONS:
raise ValueError(f"Unsupported implementation: {impl}")
try:
loss = IMPLEMENTATIONS[impl](
x,
Expand All @@ -115,8 +141,6 @@ def linear_softmax_cross_entropy_loss(
)
return loss
except NotImplementedError as e:
if len(implementation) == 1:
raise
errors.append(e)

raise ExceptionGroup("all implementations failed", errors)
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
# Copyright 2026 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Pallas-Mosaic-GPU Op implementation of linear softmax cross-entropy loss.

Forward pass: SM90 WGMMA + TMA kernel (H100+).
Backward pass: chunked scan over V using cuBLAS GEMMs (no atomics, near-XLA speed).
"""

from dataclasses import dataclass
from typing import ClassVar, Literal

import jax
import jax.numpy as jnp
from jax.extend import backend
from jaxtyping import Array, Integer, Real, Scalar
from tokamax._src import gpu_utils
from tokamax._src.ops import op
from tokamax._src.ops.linear_softmax_cross_entropy_loss import base
from tokamax._src.ops.linear_softmax_cross_entropy_loss import (
pallas_mosaic_gpu_common as common,
)
import tokamax._src.ops.linear_softmax_cross_entropy_loss.pallas_mosaic_gpu_kernel_sm90 as kernel_sm90
from typing_extensions import override


Config = common.Config
Key = common.Key


def linear_softmax_cross_entropy_loss_bwd_chunked_scan(
dout,
lse,
x,
labels,
w,
*,
reduction,
chunk_size=4096,
):
"""Chunked-scan backward: padded chunks for full cuBLAS utilisation.

Uses chunk_size-wide GEMMs throughout — the last chunk is zero-padded and
masked so padded positions contribute nothing to either gradient. This gives
square GEMMs for any vocab size (including irregular sizes like V=128256).
Never materialises the full (B, V) logit matrix.
"""
b_dim, h_dim = x.shape
v_dim = w.shape[1]

x_f32 = x.astype(jnp.float32)
w_f32 = w.astype(jnp.float32)
lse_f32 = lse.astype(jnp.float32)
scale = (
dout.astype(jnp.float32) / b_dim
if reduction == "mean"
else dout.astype(jnp.float32)
)

num_chunks = (v_dim + chunk_size - 1) // chunk_size
v_padded = num_chunks * chunk_size

# Pad w to v_padded (last chunk may be partial; extra cols are zero).
w_padded = jnp.pad(w_f32, ((0, 0), (0, v_padded - v_dim))) # (H, v_padded)
# Reshape into (num_chunks, H, chunk_size) for scan.
w_chunks = w_padded.reshape(h_dim, num_chunks, chunk_size).transpose(1, 0, 2)

def scan_fn(x_grad_carry, args):
chunk_idx, w_chunk = args # w_chunk: (H, chunk_size)
v_start = chunk_idx * chunk_size
logit_chunk = x_f32 @ w_chunk # (B, chunk_size)
softmax_chunk = jnp.exp(logit_chunk - lse_f32[:, None])
col_idx = jnp.arange(chunk_size) + v_start
one_hot_chunk = (col_idx[None, :] == labels[:, None]).astype(jnp.float32)
# Zero out padded positions so they don't contribute to either gradient.
valid = (col_idx < v_dim).astype(jnp.float32)[None, :]
s_chunk = scale * (softmax_chunk - one_hot_chunk) * valid
x_grad_carry = x_grad_carry + s_chunk @ w_chunk.T # (B, H)
w_grad_chunk = x_f32.T @ s_chunk # (H, chunk_size)
return x_grad_carry, w_grad_chunk

x_grad, w_grad_chunks = jax.lax.scan(
scan_fn,
jnp.zeros((b_dim, h_dim), dtype=jnp.float32),
(jnp.arange(num_chunks), w_chunks),
)
# w_grad_chunks: (num_chunks, H, chunk_size) → (H, v_padded) → (H, V)
w_grad = w_grad_chunks.transpose(1, 0, 2).reshape(h_dim, v_padded)[:, :v_dim]
return x_grad, w_grad


def _mosaic_vjp(
residuals: base.Residuals,
out: jax.Array,
dout: jax.Array,
x: jax.Array,
labels: jax.Array,
w: jax.Array,
*,
reduction: str = "sum",
return_residuals: bool = False,
):
"""Mosaic GPU backward: chunked scan over V (no atomics, cuBLAS per chunk)."""
del out, return_residuals
(lse,) = residuals
x_grad, w_grad = linear_softmax_cross_entropy_loss_bwd_chunked_scan(
dout,
lse,
x,
labels,
w,
reduction=reduction,
)
labels_grad = jnp.zeros_like(labels)
return (x_grad, labels_grad, w_grad)


@dataclass(frozen=True, kw_only=True)
class PallasMosaicGpuLinearSoftmaxCrossEntropyLoss(
base.LinearSoftmaxCrossEntropyLoss[Config]
):
"""Pallas/Mosaic-GPU SM90 forward + backward for linear softmax CE loss.

Forward: SM90 WGMMA + TMA kernel (H100+).
Backward: chunked scan over V using cuBLAS GEMMs (no atomics, no WGMMA).
"""

config_cls: ClassVar[type[Config]] = Config

def __post_init__(self):
object.__setattr__(self, "vjp", _mosaic_vjp)

@override
def _fwd(
self,
x: Real[Array, "B H"],
labels: Integer[Array, "B"],
w: Real[Array, "H V"],
*,
reduction: Literal["sum", "mean"] = "sum",
config: Config,
return_residuals: bool,
) -> tuple[jax.Array, base.Residuals]:
device_kind = backend.get_default_device().device_kind.lower()
if not (gpu_utils.is_sm90() or gpu_utils.is_sm100()):
raise NotImplementedError(
f"Mosaic GPU kernel requires SM90 or SM100; got {device_kind!r}."
)

loss, lse = kernel_sm90.linear_softmax_cross_entropy_loss_fwd_pallas_mosaic_gpu_sm90(
x,
labels,
w,
tile_m=config.tile_m,
tile_n=config.tile_n,
tile_k=config.tile_k,
num_stages=config.num_stages,
reduction=reduction,
)
return loss, (lse,)

@override
def _get_heuristics_config(self, ba: op.BoundArguments) -> Config:
return common.get_heuristics_config(ba.arguments["x"], ba.arguments["w"])

@override
def _get_autotuning_configs(self, ba: op.BoundArguments) -> set[Config]:
return common.get_autotuning_configs(ba.arguments["x"], ba.arguments["w"])

@override
def _get_autotuning_cache_key(self, ba: op.BoundArguments) -> Key:
return common.get_key(**ba.arguments)

@override
def supported_on(self, device: jax.Device) -> bool:
return gpu_utils.has_mosaic_gpu_support(device)
Loading