diff --git a/README.md b/README.md index 7104991c..8a012ee9 100644 --- a/README.md +++ b/README.md @@ -30,14 +30,11 @@ We currently support the following GPU kernels: And the following for both GPU and TPU: +* `tokamax.linear_softmax_cross_entropy_loss` + ([Memory Efficient Linear Cross Entropy Loss Kernel](https://arxiv.org/abs/2410.10989v2)). * `tokamax.ragged_dot` ([Mixture of Experts](https://arxiv.org/abs/2211.15841)). -And the following TPU kernels: - -* `tokamax.linear_softmax_cross_entropy_loss` - ([Memory Efficient Linear Cross Entropy Loss Kernel](https://arxiv.org/abs/2410.10989v2)) - ## Installation The latest Tokamax [PyPI release](https://pypi.org/project/tokamax/): diff --git a/pyproject.toml b/pyproject.toml index 090dd259..7926e692 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ requires-python = ">=3.11" dependencies = [ "absl-py>=2.3.0", "einshape", - "jax>=0.9.2", + "jax[cuda12]>=0.9.2", "jaxlib>=0.9.2", "jaxtyping>=0.3", "pydantic>=2.11.0", diff --git a/tokamax/_src/ops/linear_softmax_cross_entropy_loss/api.py b/tokamax/_src/ops/linear_softmax_cross_entropy_loss/api.py index d04ccc43..394aa8f0 100644 --- a/tokamax/_src/ops/linear_softmax_cross_entropy_loss/api.py +++ b/tokamax/_src/ops/linear_softmax_cross_entropy_loss/api.py @@ -22,11 +22,22 @@ from tokamax._src.ops.linear_softmax_cross_entropy_loss import base -Implementation: TypeAlias = Literal["mosaic_tpu", "xla"] +Implementation: TypeAlias = Literal["mosaic_gpu", "mosaic_tpu", "triton", "xla"] IMPLEMENTATIONS = dict(xla=base.LinearSoftmaxCrossEntropyLoss()) _DEFAULT_IMPLEMENTATION = ("xla",) +try: + from tokamax._src.ops.linear_softmax_cross_entropy_loss import pallas_triton # pylint: disable=g-import-not-at-top # pytype: disable=import-error + + IMPLEMENTATIONS["triton"] = ( + pallas_triton.PallasTritonLinearSoftmaxCrossEntropyLoss() + ) + + _DEFAULT_IMPLEMENTATION = ("triton",) + _DEFAULT_IMPLEMENTATION +except ImportError: + pass + try: from 
tokamax._src.ops.linear_softmax_cross_entropy_loss import pallas_mosaic_tpu # pylint: disable=g-import-not-at-top # pytype: disable=import-error @@ -38,6 +49,21 @@ except ImportError: pass +try: + from tokamax._src.ops.linear_softmax_cross_entropy_loss import pallas_mosaic_gpu # pylint: disable=g-import-not-at-top # pytype: disable=import-error + + IMPLEMENTATIONS["mosaic_gpu"] = ( + pallas_mosaic_gpu.PallasMosaicGpuLinearSoftmaxCrossEntropyLoss() + ) + + # mosaic_gpu is NOT added to _DEFAULT_IMPLEMENTATION. Its forward is at XLA + # parity but its backward is ~3× slower (chunked scan over V vs two full-width + # cuBLAS matmuls). The benefit is memory: the (B, V) logit matrix is never + # materialised. Use implementation='mosaic_gpu' explicitly when the logit + # matrix would OOM the device. +except ImportError: + pass + def linear_softmax_cross_entropy_loss( x: Real[Array, "B H"], @@ -72,10 +98,15 @@ def linear_softmax_cross_entropy_loss( precision: The precision used for jax.lax.dot_general for the linear projection and gradient calculation. implementation: By default "None" will be used to pick the best available - backend. Can be set to "xla" or "mosaic_tpu" explicitly. The "mosaic_tpu" - implementation is memory efficient and has almost 0 additional buffer - overhead while the "xla" implementation needs to materialize the full - logits + backend. Can be set to "xla", "mosaic_tpu", "triton", or "mosaic_gpu" + explicitly. The default selection order is mosaic_tpu → triton → xla, + with each backend skipped if unavailable on the current device. + "mosaic_gpu" is available on H100+ (SM90) but is not in the default + chain: its forward is at XLA parity but its backward is ~3× slower due + to chunked-scan accumulation. Use implementation='mosaic_gpu' explicitly + when the (B, V) logit matrix would OOM the device — that is the intended + use case. "mosaic_tpu" and "triton" are memory-efficient and avoid + materialising the full logit matrix. 
Returns: The Cross-Entropy loss @@ -91,21 +122,16 @@ def linear_softmax_cross_entropy_loss( "Customization of precision is currently not supported." ) - if implementation is not None: - if implementation in IMPLEMENTATIONS: - loss = IMPLEMENTATIONS[implementation]( - x, - labels, - weights, - reduction=reduction, - ) - return loss - else: - raise ValueError(f"Unsupported implementation: {implementation}") + if implementation is None: + implementation = _DEFAULT_IMPLEMENTATION + + if not isinstance(implementation, (tuple, list)): + implementation = (implementation,) - # Find out the best impelmentation based on the hardware. errors = [] - for impl in IMPLEMENTATIONS: + for impl in implementation: + if impl not in IMPLEMENTATIONS: + raise ValueError(f"Unsupported implementation: {impl}") try: loss = IMPLEMENTATIONS[impl]( x, @@ -115,8 +141,6 @@ def linear_softmax_cross_entropy_loss( ) return loss except NotImplementedError as e: - if len(implementation) == 1: - raise errors.append(e) raise ExceptionGroup("all implementations failed", errors) diff --git a/tokamax/_src/ops/linear_softmax_cross_entropy_loss/pallas_mosaic_gpu.py b/tokamax/_src/ops/linear_softmax_cross_entropy_loss/pallas_mosaic_gpu.py new file mode 100644 index 00000000..ed41ce14 --- /dev/null +++ b/tokamax/_src/ops/linear_softmax_cross_entropy_loss/pallas_mosaic_gpu.py @@ -0,0 +1,187 @@ +# Copyright 2026 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Pallas-Mosaic-GPU Op implementation of linear softmax cross-entropy loss. + +Forward pass: SM90 WGMMA + TMA kernel (H100+). +Backward pass: chunked scan over V using cuBLAS GEMMs (no atomics, ~3× slower than XLA). +""" + +from dataclasses import dataclass +from typing import ClassVar, Literal + +import jax +import jax.numpy as jnp +from jax.extend import backend +from jaxtyping import Array, Integer, Real, Scalar +from tokamax._src import gpu_utils +from tokamax._src.ops import op +from tokamax._src.ops.linear_softmax_cross_entropy_loss import base +from tokamax._src.ops.linear_softmax_cross_entropy_loss import ( + pallas_mosaic_gpu_common as common, +) +import tokamax._src.ops.linear_softmax_cross_entropy_loss.pallas_mosaic_gpu_kernel_sm90 as kernel_sm90 +from typing_extensions import override + + +Config = common.Config +Key = common.Key + + +def linear_softmax_cross_entropy_loss_bwd_chunked_scan( + dout, + lse, + x, + labels, + w, + *, + reduction, + chunk_size=4096, +): + """Chunked-scan backward: padded chunks for full cuBLAS utilisation. + + Uses chunk_size-wide GEMMs throughout — the last chunk is zero-padded and + masked so padded positions contribute nothing to either gradient. This gives + uniform GEMMs for any vocab size (including irregular sizes like V=128256). + Never materialises the full (B, V) logit matrix. + """ + b_dim, h_dim = x.shape + v_dim = w.shape[1] + + x_f32 = x.astype(jnp.float32) + w_f32 = w.astype(jnp.float32) + lse_f32 = lse.astype(jnp.float32) + scale = ( + dout.astype(jnp.float32) / b_dim + if reduction == "mean" + else dout.astype(jnp.float32) + ) + + num_chunks = (v_dim + chunk_size - 1) // chunk_size + v_padded = num_chunks * chunk_size + + # Pad w to v_padded (last chunk may be partial; extra cols are zero). + w_padded = jnp.pad(w_f32, ((0, 0), (0, v_padded - v_dim))) # (H, v_padded) + # Reshape into (num_chunks, H, chunk_size) for scan.
+ w_chunks = w_padded.reshape(h_dim, num_chunks, chunk_size).transpose(1, 0, 2) + + def scan_fn(x_grad_carry, args): + chunk_idx, w_chunk = args # w_chunk: (H, chunk_size) + v_start = chunk_idx * chunk_size + logit_chunk = x_f32 @ w_chunk # (B, chunk_size) + softmax_chunk = jnp.exp(logit_chunk - lse_f32[:, None]) + col_idx = jnp.arange(chunk_size) + v_start + one_hot_chunk = (col_idx[None, :] == labels[:, None]).astype(jnp.float32) + # Zero out padded positions so they don't contribute to either gradient. + valid = (col_idx < v_dim).astype(jnp.float32)[None, :] + s_chunk = scale * (softmax_chunk - one_hot_chunk) * valid + x_grad_carry = x_grad_carry + s_chunk @ w_chunk.T # (B, H) + w_grad_chunk = x_f32.T @ s_chunk # (H, chunk_size) + return x_grad_carry, w_grad_chunk + + x_grad, w_grad_chunks = jax.lax.scan( + scan_fn, + jnp.zeros((b_dim, h_dim), dtype=jnp.float32), + (jnp.arange(num_chunks), w_chunks), + ) + # w_grad_chunks: (num_chunks, H, chunk_size) → (H, v_padded) → (H, V) + w_grad = w_grad_chunks.transpose(1, 0, 2).reshape(h_dim, v_padded)[:, :v_dim] + return x_grad, w_grad + + +def _mosaic_vjp( + residuals: base.Residuals, + out: jax.Array, + dout: jax.Array, + x: jax.Array, + labels: jax.Array, + w: jax.Array, + *, + reduction: str = "sum", + return_residuals: bool = False, +): + """Mosaic GPU backward: chunked scan over V (no atomics, cuBLAS per chunk).""" + del out, return_residuals + (lse,) = residuals + x_grad, w_grad = linear_softmax_cross_entropy_loss_bwd_chunked_scan( + dout, + lse, + x, + labels, + w, + reduction=reduction, + ) + labels_grad = jnp.zeros_like(labels) + return (x_grad, labels_grad, w_grad) + + +@dataclass(frozen=True, kw_only=True) +class PallasMosaicGpuLinearSoftmaxCrossEntropyLoss( + base.LinearSoftmaxCrossEntropyLoss[Config] +): + """Pallas/Mosaic-GPU SM90 forward + backward for linear softmax CE loss. + + Forward: SM90 WGMMA + TMA kernel (H100+). + Backward: chunked scan over V using cuBLAS GEMMs (no atomics, no WGMMA). 
+ """ + + config_cls: ClassVar[type[Config]] = Config + + def __post_init__(self): + object.__setattr__(self, "vjp", _mosaic_vjp) + + @override + def _fwd( + self, + x: Real[Array, "B H"], + labels: Integer[Array, "B"], + w: Real[Array, "H V"], + *, + reduction: Literal["sum", "mean"] = "sum", + config: Config, + return_residuals: bool, + ) -> tuple[jax.Array, base.Residuals]: + device_kind = backend.get_default_device().device_kind.lower() + if not (gpu_utils.is_sm90() or gpu_utils.is_sm100()): + raise NotImplementedError( + f"Mosaic GPU kernel requires SM90 or SM100; got {device_kind!r}." + ) + + loss, lse = kernel_sm90.linear_softmax_cross_entropy_loss_fwd_pallas_mosaic_gpu_sm90( + x, + labels, + w, + tile_m=config.tile_m, + tile_n=config.tile_n, + tile_k=config.tile_k, + num_stages=config.num_stages, + reduction=reduction, + ) + return loss, (lse,) + + @override + def _get_heuristics_config(self, ba: op.BoundArguments) -> Config: + return common.get_heuristics_config(ba.arguments["x"], ba.arguments["w"]) + + @override + def _get_autotuning_configs(self, ba: op.BoundArguments) -> set[Config]: + return common.get_autotuning_configs(ba.arguments["x"], ba.arguments["w"]) + + @override + def _get_autotuning_cache_key(self, ba: op.BoundArguments) -> Key: + return common.get_key(**ba.arguments) + + @override + def supported_on(self, device: jax.Device) -> bool: + return gpu_utils.has_mosaic_gpu_support(device) diff --git a/tokamax/_src/ops/linear_softmax_cross_entropy_loss/pallas_mosaic_gpu_common.py b/tokamax/_src/ops/linear_softmax_cross_entropy_loss/pallas_mosaic_gpu_common.py new file mode 100644 index 00000000..27a28b41 --- /dev/null +++ b/tokamax/_src/ops/linear_softmax_cross_entropy_loss/pallas_mosaic_gpu_common.py @@ -0,0 +1,92 @@ +# Copyright 2026 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Common definitions for Pallas-Mosaic-GPU linear softmax cross-entropy loss.""" + +from typing import Annotated, Any, TypeAlias + +import immutabledict +import jax +import jax.numpy as jnp +import pydantic +from tokamax._src import pydantic as pydantic_lib + + +@pydantic.dataclasses.dataclass(frozen=True, kw_only=True, slots=True) +class Config: + """Tile-size configuration for the Pallas/Mosaic-GPU kernel. + + The matmul is x[B, H] @ w[H, V] tiled as (B=M, H=K, V=N). + + Attributes: + tile_m: Tile size over the batch/token (B) dimension. Each CTA handles + 2 * tile_m rows (two warp groups each covering tile_m rows). B must be + divisible by 2 * tile_m. + tile_n: Tile size over the vocabulary (V) dimension. V must be divisible + by tile_n. + tile_k: Tile size for the inner hidden (H/K) matmul loop. H must be + divisible by tile_k. + num_stages: Maximum number of concurrent pipeline stages for async + TMA prefetch. 
+ """ + + tile_m: Annotated[int, pydantic.Field(ge=128, multiple_of=64)] = 128 + tile_n: Annotated[int, pydantic.Field(ge=64, multiple_of=64)] = 128 + tile_k: Annotated[int, pydantic.Field(ge=16, multiple_of=16)] = 64 + num_stages: pydantic_lib.PowerOfTwo = 4 + + +Key: TypeAlias = immutabledict.immutabledict[str, Any] + + +def get_heuristics_config(x: jax.Array, w: jax.Array) -> Config: + """Returns a reasonable default config for H100 (sm90).""" + del x, w # shapes don't change the default for sm90 + return Config(tile_m=128, tile_n=128, tile_k=64, num_stages=4) + + +def get_autotuning_configs(x: jax.Array, w: jax.Array) -> set[Config]: + """Returns a bounded set of configs to try during autotuning.""" + b_dim, h_dim = x.shape + v_dim = w.shape[1] + + tile_ms = [t for t in (128,) if b_dim % (2 * t) == 0] + tile_ns = [t for t in (64, 128, 256) if v_dim % t == 0] + tile_ks = [t for t in (32, 64, 128) if h_dim % t == 0] + num_stages_opts = [2, 4] + + configs: set[Config] = set() + for tm in tile_ms: + for tn in tile_ns: + for tk in tile_ks: + for ns in num_stages_opts: + configs.add(Config(tile_m=tm, tile_n=tn, tile_k=tk, num_stages=ns)) + return configs + + +def get_key( + x: jax.Array, + labels: jax.Array, + w: jax.Array, + *, + reduction: str, + **_kwargs, +) -> Key: + """Returns the autotuning cache lookup key for the given arguments.""" + return immutabledict.immutabledict( + x=jax.ShapeDtypeStruct(x.shape, x.dtype), + labels=jax.ShapeDtypeStruct(labels.shape, labels.dtype), + w=jax.ShapeDtypeStruct(w.shape, w.dtype), + reduction=reduction, + ) diff --git a/tokamax/_src/ops/linear_softmax_cross_entropy_loss/pallas_mosaic_gpu_kernel_sm90.py b/tokamax/_src/ops/linear_softmax_cross_entropy_loss/pallas_mosaic_gpu_kernel_sm90.py new file mode 100644 index 00000000..3d8e4331 --- /dev/null +++ b/tokamax/_src/ops/linear_softmax_cross_entropy_loss/pallas_mosaic_gpu_kernel_sm90.py @@ -0,0 +1,294 @@ +# Copyright 2026 DeepMind Technologies Limited. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Pallas-Mosaic-GPU SM90 forward kernel for linear softmax CE loss. + +Algorithm (forward): tiles (B, V) with an inner H pipeline, so the +(b_tile, v_tile) logit matrix never appears in HBM. Two warp groups (wg=0,1) +each handle tile_m rows of the 2*tile_m CTA tile; WGMMA + TMA pipelines +compute the matmul x[b_tile,:] @ w[:,v_tile] and the epilogue reduces to +per-token logsumexp. The correct-class logit is computed outside the kernel as +a cheap O(B*H) XLA einsum (gather + dot). + +Algorithm (backward): implemented in pallas_mosaic_gpu.py as a jax.lax.scan +over padded vocabulary chunks, issuing cuBLAS GEMMs per chunk (not WGMMA). +The in-kernel backward (linear_softmax_cross_entropy_loss_bwd_pallas_mosaic_gpu_sm90) +is not defined in this module and is not wired into the Op — it was superseded +by the chunked-scan approach, which avoids atomic_add serialisation across CTAs.
+""" + +import functools +from typing import Literal + +import jax +from jax import lax +from jax.experimental import pallas as pl +import jax.experimental.pallas.mosaic_gpu as plgpu +from jax.extend import backend +import jax.numpy as jnp +from jaxtyping import Array, Integer, Real, Scalar + +_WGMMA = plgpu.Layout.WGMMA +_WGMMA_ROW = plgpu.Layout.WGMMA.reduce(1) + + +def _validate_inputs( + x: jax.Array, + labels: jax.Array, + w: jax.Array, + tile_m: int, + tile_k: int, + tile_n: int, +) -> None: + """Validates inputs and tile-size constraints.""" + b_dim, h_dim = x.shape + v_dim = w.shape[1] + if b_dim % (2 * tile_m) != 0: + raise ValueError( + f"Batch dimension B={b_dim} must be divisible by" + f" 2 * tile_m={2 * tile_m}." + ) + if h_dim % tile_k != 0: + raise ValueError( + f"Hidden dimension H={h_dim} must be divisible by tile_k={tile_k}." + ) + if v_dim % tile_n != 0: + raise ValueError( + f"Vocab dimension V={v_dim} must be divisible by tile_n={tile_n}." + ) + if w.shape[0] != h_dim: + raise ValueError( + f"w hidden dim {w.shape[0]} must match x hidden dim {h_dim}." + ) + if labels.shape[0] != b_dim: + raise ValueError( + f"labels batch size {labels.shape[0]} must match x batch size {b_dim}." + ) + + +def linear_softmax_cross_entropy_loss_fwd_pallas_mosaic_gpu_sm90( + x: Real[Array, "B H"], + labels: Integer[Array, "B"], + w: Real[Array, "H V"], + *, + tile_m: int = 128, + tile_n: int = 128, + tile_k: int = 64, + num_stages: int = 4, + reduction: Literal["sum", "mean"] = "sum", +) -> tuple[Real[Scalar, ""], Real[Array, "B"]]: + """Forward pass for linear softmax cross-entropy loss via Pallas/Mosaic-GPU. + + Uses WGMMA + TMA pipelining on SM90 (H100). Two warp groups each handle + tile_m rows of the current (b_cta, v) tile, accumulating x @ w across the + H dimension before computing per-token logsumexp and correct-class logit. + + Args: + x: Hidden states, shape (B, H). + labels: Integer token indices, shape (B,). + w: LM head weight matrix, shape (H, V). 
+ tile_m: Tile size over B. Each CTA uses 2 * tile_m rows; B must be + divisible by 2 * tile_m. + tile_n: Tile size over V. V must be divisible by tile_n. + tile_k: Tile size for the H contraction loop. H must be divisible by + tile_k. + num_stages: TMA pipeline depth. + reduction: "sum" or "mean" over tokens. + + Returns: + (loss, lse) where lse is the per-token log-sum-exp, shape (B,). + """ + _validate_inputs(x, labels, w, tile_m, tile_k, tile_n) + + # Mosaic GPU wgmma operates in bfloat16 with float32 accumulation. Downcast + # float32 inputs to bfloat16 to halve SMEM usage and use the faster bf16 + # wgmma path (same approach as the attention sm90 kernel). + if x.dtype != jnp.bfloat16: + x = x.astype(jnp.bfloat16) + if w.dtype != jnp.bfloat16: + w = w.astype(jnp.bfloat16) + + b_dim, h_dim = x.shape + v_dim = w.shape[1] + dtype = x.dtype # bfloat16 + elem_bits = jnp.finfo(dtype).bits + + cta_tile_m = 2 * tile_m # two warp groups each covering tile_m rows + b_cta_iters = b_dim // cta_tile_m + v_iters = v_dim // tile_n + k_iters = h_dim // tile_k + + # Swizzle for lhs (x tiles: last dim = tile_k) and rhs (w tiles: last dim = tile_n). + # Rule: swizzle = find_swizzle(last_dim * elem_bits) — see attention common. + lhs_swizzle = plgpu.find_swizzle(tile_k * elem_bits) + lhs_swizzle_elems = 8 * lhs_swizzle // elem_bits + lhs_transforms = ( + plgpu.TilingTransform((8, lhs_swizzle_elems)), + plgpu.SwizzleTransform(lhs_swizzle), + ) + + rhs_swizzle = plgpu.find_swizzle(tile_n * elem_bits) + rhs_swizzle_elems = 8 * rhs_swizzle // elem_bits + rhs_transforms = ( + plgpu.TilingTransform((8, rhs_swizzle_elems)), + plgpu.SwizzleTransform(rhs_swizzle), + ) + + def kernel( + x_gmem, + w_gmem, + tile_lse_gmem, + lse_smem, + ): + """Persistent kernel body. + + Args: + x_gmem: Input activations, shape (B, H). + w_gmem: Weight matrix, shape (H, V). + tile_lse_gmem: Output per-tile logsumexp, shape (v_iters, B). + lse_smem: Scratch SMEM for lse staging, shape (2, tile_m). 
+ """ + + def get_pipeline(pipeline_body, compute_context): + return plgpu.emit_pipeline_warp_specialized( + pipeline_body, + grid=(k_iters,), + memory_registers=40, + in_specs=[ + plgpu.BlockSpec( + (cta_tile_m, tile_k), + lambda k: (0, k), + transforms=lhs_transforms, + memory_space=plgpu.SMEM, + ), + plgpu.BlockSpec( + (tile_k, tile_n), + lambda k: (k, 0), + transforms=rhs_transforms, + memory_space=plgpu.SMEM, + ), + ], + wg_axis="wg", + num_compute_wgs=2, + max_concurrent_steps=num_stages, + compute_context=compute_context, + ) + + ignore = lambda *_, **__: None + + @functools.partial( + pl.run_scoped, + pipeline_allocs=get_pipeline(ignore, ignore).get_allocations( + x_gmem, w_gmem + ), + collective_axes="wg", + ) + def _pipeline_scope(pipeline_allocs): + wg_idx = lax.axis_index("wg") + + @plgpu.nd_loop((b_cta_iters * v_iters,), collective_axes="cluster_grid") + def _bv_loop(loop_info): + (lin_idx,) = loop_info.index + b_cta_idx = lin_idx // v_iters + v_idx = lin_idx % v_iters + + b_cta_start = b_cta_idx * cta_tile_m + v_start = v_idx * tile_n + + # Each wg handles its own tile_m-row slice of the cta_tile_m block. + wg_b_start = b_cta_start + wg_idx * tile_m + b_wg_slice = pl.ds(wg_b_start, tile_m) + + def compute_context(eval_pipeline): + + @functools.partial( + pl.run_scoped, + acc_ref=plgpu.ACC((tile_m, tile_n), jnp.float32), + ) + def _acc_scope(acc_ref): + eval_pipeline(acc_ref) + acc = acc_ref[...].astype(jnp.float32) # (tile_m, tile_n) WGMMA + + # Per-token logsumexp over this V tile. + # - No keepdims: (tile_m, 1) violates WGMMA tile divisibility. + # - jax.nn.logsumexp is off-limits: calls is_finite internally. + # - Use lax.broadcast_in_dim to expand back to (tile_m, tile_n). + amax = jnp.max(acc, axis=-1) # (tile_m,) WGMMA_ROW + amax_bcast = lax.broadcast_in_dim(amax, acc.shape, [0]) + tile_lse_vals = amax + jnp.log( + jnp.sum(jnp.exp(acc - amax_bcast), axis=-1) + ) # (tile_m,) WGMMA_ROW + + # Stage through SMEM then TMA-store to GMEM. 
+ plgpu.wait_smem_to_gmem(0, wait_read_only=True) + lse_smem[wg_idx] = tile_lse_vals + plgpu.commit_smem() + plgpu.copy_smem_to_gmem( + lse_smem.at[wg_idx], + tile_lse_gmem.at[v_idx, b_wg_slice], + ) + + def mma_body(_, x_smem, w_smem, acc_ref): + wg_m_slice = pl.ds(wg_idx * tile_m, tile_m) + # w is (K, N) in SMEM — no transpose needed (cf. v_smem in attention). + plgpu.wgmma(acc_ref, x_smem.at[wg_m_slice], w_smem) + plgpu.wgmma_wait(0) + return acc_ref + + get_pipeline(mma_body, compute_context)( + x_gmem.at[pl.ds(b_cta_start, cta_tile_m), :], + w_gmem.at[:, pl.ds(v_start, tile_n)], + allocations=pipeline_allocs, + ) + + plgpu.wait_smem_to_gmem(0, wait_read_only=True) + + num_sms = backend.get_default_device().core_count + scratch_shapes = [ + plgpu.SMEM((2, tile_m), jnp.float32), # lse staging, one row per wg + ] + + f = plgpu.kernel( + kernel, + out_shape=[ + jax.ShapeDtypeStruct((v_iters, b_dim), jnp.float32), + ], + grid=(num_sms,), + grid_names=("cluster_grid",), + cluster=(1,), + cluster_names=("cluster",), + num_threads=3, + thread_name="wg", + scratch_shapes=scratch_shapes, + ) + + (tile_lse,) = f(x, w) + + # Combine across V tiles; tile_lse is (v_iters, B), reduce over v_iters. + lse = jax.nn.logsumexp(tile_lse, axis=0) # (B,) + + # Correct-class logit: O(B*H) XLA gather+dot, much cheaper than the kernel. + # Using float32 throughout for consistency with the fp32 kernel accumulation. 
+ x_f32 = x.astype(jnp.float32) + w_f32 = w.astype(jnp.float32) + correct_logit = jnp.einsum("bh,hb->b", x_f32, w_f32[:, labels]) # (B,) + per_token_loss = lse - correct_logit + + if reduction == "sum": + loss = jnp.sum(per_token_loss) + else: + loss = jnp.mean(per_token_loss) + + return loss.astype(jnp.float32), lse + diff --git a/tokamax/_src/ops/linear_softmax_cross_entropy_loss/pallas_mosaic_gpu_kernel_sm90_test.py b/tokamax/_src/ops/linear_softmax_cross_entropy_loss/pallas_mosaic_gpu_kernel_sm90_test.py new file mode 100644 index 00000000..0a287ce3 --- /dev/null +++ b/tokamax/_src/ops/linear_softmax_cross_entropy_loss/pallas_mosaic_gpu_kernel_sm90_test.py @@ -0,0 +1,108 @@ +# Copyright 2026 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the SM90 Pallas/Mosaic-GPU forward kernel function. + +Covers a range of tile configurations representative of the autotuning search +space (tile_n in {64, 128, 256}, tile_k in {64, 128}, num_stages in {2, 4}). +This ensures that configurations beyond the default (128/128/64) are correct, +which is important for autotuning to produce meaningful results. 
+""" + +from absl.testing import absltest +from absl.testing import parameterized +import jax +import jax.numpy as jnp + +from tokamax._src import gpu_utils +from tokamax._src.ops.linear_softmax_cross_entropy_loss import ( + pallas_mosaic_gpu_kernel_sm90 as kernel_sm90, +) +from tokamax._src.ops.linear_softmax_cross_entropy_loss import reference +from tokamax._src.ops.linear_softmax_cross_entropy_loss import test_utils + + +# B=512 is divisible by 2*tile_m=256 for all tile_m=128 configs. +# V=512 is divisible by tile_n in {64, 128, 256}. +# H=256 is divisible by tile_k in {64, 128}. +_B, _H, _V = 512, 256, 512 + + +def _skip_if_not_sm90(test_case): + if jax.default_backend() != "gpu": + test_case.skipTest("GPU-only test.") + if not gpu_utils.has_mosaic_gpu_support(): + test_case.skipTest("Mosaic GPU requires SM90+ (H100 or newer).") + + +class PallasMosaicGpuSm90FwdKernelTest(parameterized.TestCase): + """Direct tests of the SM90 forward kernel with various tile configs. + + The forward kernel's only SMEM scratch is the lse staging buffer, so it fits + tile_n=256 and tile_k=128 at num_stages=2 (129 KB and 193 KB respectively).
+ """ + + def setUp(self): + super().setUp() + _skip_if_not_sm90(self) + + @parameterized.named_parameters( + dict( + testcase_name="default", + tile_m=128, tile_n=128, tile_k=64, num_stages=4, + ), + dict( + testcase_name="small_tile_n", + tile_m=128, tile_n=64, tile_k=64, num_stages=2, + ), + dict( + testcase_name="large_tile_n", + tile_m=128, tile_n=256, tile_k=64, num_stages=2, + ), + dict( + testcase_name="large_tile_k", + tile_m=128, tile_n=128, tile_k=128, num_stages=2, + ), + ) + def test_forward_matches_reference( + self, tile_m, tile_n, tile_k, num_stages, + ): + x, labels, w = test_utils.generate_random_data( + jax.random.key(0), _B, _H, _V + ) + + ref_loss, ref_lse = reference.linear_softmax_cross_entropy_loss_fwd_reference( + x, labels, w, reduction="sum" + ) + kernel_loss, kernel_lse = kernel_sm90.linear_softmax_cross_entropy_loss_fwd_pallas_mosaic_gpu_sm90( + x, labels, w, + tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, + num_stages=num_stages, reduction="sum", + ) + + # bf16 WGMMA precision: the forward loss and per-token LSE are insensitive + # to the bf16 quantization (logsumexp is well-conditioned), so 2e-2 holds. + self.assertTrue( + jnp.allclose(ref_loss, kernel_loss.astype(jnp.float32), atol=2e-2, rtol=2e-2), + msg=f"loss: ref={float(ref_loss):.6f} kernel={float(kernel_loss):.6f}", + ) + self.assertTrue( + jnp.allclose(ref_lse, kernel_lse.astype(jnp.float32), atol=2e-2, rtol=2e-2), + msg=f"lse max_diff={float(jnp.max(jnp.abs(ref_lse - kernel_lse))):.6f}", + ) + + + +if __name__ == "__main__": + absltest.main() diff --git a/tokamax/_src/ops/linear_softmax_cross_entropy_loss/pallas_mosaic_gpu_test.py b/tokamax/_src/ops/linear_softmax_cross_entropy_loss/pallas_mosaic_gpu_test.py new file mode 100644 index 00000000..09465f40 --- /dev/null +++ b/tokamax/_src/ops/linear_softmax_cross_entropy_loss/pallas_mosaic_gpu_test.py @@ -0,0 +1,169 @@ +# Copyright 2026 DeepMind Technologies Limited. All Rights Reserved. 
class PallasMosaicGpuLceOpTest(parameterized.TestCase):
  """Compares the Mosaic-GPU LCE op (loss and gradients) with the reference op."""

  def setUp(self):
    super().setUp()
    # Skip (rather than fail) on hosts that cannot run the kernel at all.
    if jax.default_backend() != "gpu":
      self.skipTest("GPU-only test.")
    if not gpu_utils.has_mosaic_gpu_support():
      self.skipTest("Mosaic GPU requires SM90+ (H100 or newer).")

  def _assert_grad_close(self, label, expected, actual, *, atol, rtol):
    """Asserts two gradients agree in float32, reporting the max abs diff."""
    expected_f32 = expected.astype(jnp.float32)
    actual_f32 = actual.astype(jnp.float32)
    worst = float(jnp.max(jnp.abs(expected_f32 - actual_f32)))
    self.assertTrue(
        jnp.allclose(expected_f32, actual_f32, atol=atol, rtol=rtol),
        msg=f"{label} max_diff={worst:.6f}",
    )

  # Positional layout: (testcase_name, b_dim, h_dim, v_dim, reduction[, dtype]).
  @parameterized.named_parameters(
      ("small_sum", 256, 128, 256, "sum"),
      ("small_mean", 256, 128, 256, "mean"),
      ("medium_sum", 256, 256, 512, "sum"),
      ("medium_mean", 256, 256, 512, "mean"),
      ("bfloat16", 256, 128, 256, "sum", jnp.bfloat16),
  )
  def test_value_and_grad_matches_reference(
      self,
      b_dim,
      h_dim,
      v_dim,
      reduction,
      dtype=jnp.float32,
  ):
    x, labels, w = generate_random_data(
        jax.random.key(42), b_dim, h_dim, v_dim, dtype=dtype
    )
    # tile_m=128 so that 2*tile_m=256 divides b_dim=256.
    mosaic_op = PallasMosaicGpuLinearSoftmaxCrossEntropyLoss(
        config=Config(tile_m=128, tile_n=128, tile_k=64, num_stages=4)
    )
    ref_op = LinearSoftmaxCrossEntropyLoss()

    # bfloat16 inputs are compared against a reference fed float32-upcast
    # copies of the same data (the kernel accumulates in float32 internally).
    is_bf16 = dtype == jnp.bfloat16
    if is_bf16:
      x_ref, w_ref = x.astype(jnp.float32), w.astype(jnp.float32)
    else:
      x_ref, w_ref = x, w

    mosaic_loss, mosaic_grads = jax.value_and_grad(mosaic_op, argnums=(0, 2))(
        x, labels, w, reduction=reduction
    )
    mosaic_x_grad, mosaic_w_grad = mosaic_grads

    ref_loss, ref_grads = jax.value_and_grad(ref_op, argnums=(0, 2))(
        x_ref, labels, w_ref, reduction=reduction
    )
    ref_x_grad, ref_w_grad = ref_grads

    # Gradient tolerances (as recorded by the kernel authors):
    #  * bfloat16: reference runs on float32-upcast inputs, errors are modest.
    #  * float32 + "mean": gradients are scaled by dout/B, so absolute errors
    #    are proportionally small.
    #  * float32 + "sum": the SM90 forward down-casts float32 inputs to bf16
    #    for WGMMA, making the saved lse slightly imprecise; the backward then
    #    shows per-element errors up to ~0.35 against a fully float32
    #    reference, hence atol=0.40 (headroom over the observed worst case).
    # The loss scalar is always checked at the tighter 2e-2 level.
    if is_bf16:
      grad_tol = dict(atol=5e-2, rtol=5e-2)
    elif reduction == "sum":
      grad_tol = dict(atol=0.40, rtol=0.05)
    else:  # float32, mean
      grad_tol = dict(atol=2e-2, rtol=2e-2)
    loss_tol = dict(atol=2e-2, rtol=2e-2)

    self.assertTrue(
        jnp.allclose(
            ref_loss.astype(jnp.float32),
            mosaic_loss.astype(jnp.float32),
            **loss_tol,
        ),
        msg=f"loss: ref={float(ref_loss):.6f} mosaic={float(mosaic_loss):.6f}",
    )
    self._assert_grad_close("x_grad", ref_x_grad, mosaic_x_grad, **grad_tol)
    self._assert_grad_close("w_grad", ref_w_grad, mosaic_w_grad, **grad_tol)


if __name__ == "__main__":
  absltest.main()
def linear_softmax_cross_entropy_loss_bwd_chunked_scan(
    dout, lse, x, labels, w,
    *, reduction, chunk_size=4096,
):
  """Computes the x and w gradients of the LCE loss by scanning over V chunks.

  The vocabulary axis of `w` is split into `chunk_size`-wide chunks (the last
  one zero-padded and masked), and each scan step recomputes one
  (B, chunk_size) slab of logits from `x` and the saved `lse`. The full (B, V)
  logit matrix is therefore never materialised.

  Args:
    dout: Scalar cotangent of the loss.
    lse: Per-token log-sum-exp saved by the forward pass, shape (B,).
    x: Hidden states, shape (B, H).
    labels: Integer class index per token, shape (B,).
    w: Projection weights, shape (H, V).
    reduction: "sum" or "mean"; "mean" scales the gradients by 1/B.
    chunk_size: Maximum number of vocabulary columns processed per scan step.
      Must be a positive Python int. It is clamped to V so that a small
      vocabulary is not padded up to a full chunk.

  Returns:
    `(x_grad, w_grad)` as float32 arrays of shape (B, H) and (H, V).

  Raises:
    ValueError: If `reduction` is not "sum"/"mean" or `chunk_size` is not
      positive.
  """
  if reduction not in ("sum", "mean"):
    # The forward pass treats anything other than "sum" as "mean"; refusing
    # unknown values keeps forward and backward from silently disagreeing.
    raise ValueError(f"reduction must be 'sum' or 'mean', got {reduction!r}.")
  if chunk_size <= 0:
    raise ValueError(f"chunk_size must be positive, got {chunk_size}.")

  b_dim, h_dim = x.shape
  v_dim = w.shape[1]
  # Never use a chunk wider than the vocabulary itself (min 1 keeps the
  # arithmetic below well-defined for an empty vocabulary).
  chunk_size = max(1, min(chunk_size, v_dim))

  x_f32 = x.astype(jnp.float32)
  lse_f32 = lse.astype(jnp.float32)
  scale = dout.astype(jnp.float32)
  if reduction == "mean":
    scale = scale / b_dim

  num_chunks = (v_dim + chunk_size - 1) // chunk_size
  v_padded = num_chunks * chunk_size

  # Pad and re-layout w in its ORIGINAL dtype; the float32 upcast happens per
  # chunk inside the scan so no full-size float32 copy of w is ever created.
  w_padded = jnp.pad(w, ((0, 0), (0, v_padded - v_dim)))  # (H, v_padded)
  # (H, v_padded) -> (num_chunks, H, chunk_size) so scan iterates over chunks.
  w_chunks = w_padded.reshape(h_dim, num_chunks, chunk_size).transpose(1, 0, 2)
  local_cols = jnp.arange(chunk_size)

  def scan_fn(x_grad_carry, args):
    chunk_idx, w_chunk = args  # w_chunk: (H, chunk_size)
    w_chunk = w_chunk.astype(jnp.float32)
    col_idx = local_cols + chunk_idx * chunk_size  # global vocabulary ids
    logit_chunk = x_f32 @ w_chunk  # (B, chunk_size)
    softmax_chunk = jnp.exp(logit_chunk - lse_f32[:, None])
    one_hot_chunk = (col_idx[None, :] == labels[:, None]).astype(jnp.float32)
    # Zero out padded positions so they don't contribute to either gradient.
    valid = (col_idx < v_dim).astype(jnp.float32)[None, :]
    s_chunk = scale * (softmax_chunk - one_hot_chunk) * valid
    x_grad_carry = x_grad_carry + s_chunk @ w_chunk.T  # (B, H)
    w_grad_chunk = x_f32.T @ s_chunk  # (H, chunk_size)
    return x_grad_carry, w_grad_chunk

  x_grad, w_grad_chunks = jax.lax.scan(
      scan_fn,
      jnp.zeros((b_dim, h_dim), dtype=jnp.float32),
      (jnp.arange(num_chunks), w_chunks),
  )
  # (num_chunks, H, chunk_size) -> (H, v_padded) -> drop the padding -> (H, V)
  w_grad = w_grad_chunks.transpose(1, 0, 2).reshape(h_dim, v_padded)[:, :v_dim]
  return x_grad, w_grad
num_warps=config.num_warps, + ) + return loss, (lse,) + + @override + def _get_heuristics_config(self, ba: op.BoundArguments) -> Config: + return pallas_triton_config.get_heuristics_config( + ba.arguments["x"], ba.arguments["w"] + ) + + @override + def supported_on(self, device: jax.Device) -> bool: + return gpu_utils.has_triton_support(device) + + +@dataclass(frozen=True, kw_only=True) +class PallasTritonLinearSoftmaxCrossEntropyLossVjp( + base.LinearSoftmaxCrossEntropyLossVjp[Config] +): + """Pallas/Triton GPU VJP for linear softmax cross-entropy loss.""" + + config_cls: ClassVar[type[Config]] = Config + + @override + def _fwd( + self, + residuals: base.Residuals, + out: Real[Array, ""], + dout: Real[Array, ""], + x: Real[Array, "B H"], + labels: Integer[Array, "B"], + w: Real[Array, "H V"], + *, + reduction: Literal["sum", "mean"] = "sum", + config: Config, + return_residuals: bool, + ) -> tuple[tuple[jax.Array, jax.Array, jax.Array], None]: + del out + (lse,) = residuals + + x_grad, w_grad = linear_softmax_cross_entropy_loss_bwd_chunked_scan( + dout, + lse, + x, + labels, + w, + reduction=reduction, + ) + labels_grad = jnp.zeros_like(labels) + return (x_grad, labels_grad, w_grad), None + + @override + def _get_heuristics_config(self, ba: op.BoundArguments) -> Config: + return pallas_triton_config.get_heuristics_config( + ba.arguments["x"], ba.arguments["w"] + ) + + @override + def supported_on(self, device: jax.Device) -> bool: + return gpu_utils.has_triton_support(device) diff --git a/tokamax/_src/ops/linear_softmax_cross_entropy_loss/pallas_triton_config.py b/tokamax/_src/ops/linear_softmax_cross_entropy_loss/pallas_triton_config.py new file mode 100644 index 00000000..5d93a3a0 --- /dev/null +++ b/tokamax/_src/ops/linear_softmax_cross_entropy_loss/pallas_triton_config.py @@ -0,0 +1,110 @@ +# Copyright 2026 DeepMind Technologies Limited. All Rights Reserved. 
@pydantic.dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
class Config:
  """Tile-size configuration for the Pallas/Triton GPU kernel.

  `b_block_size` and `h_block_size` must evenly divide B and H respectively
  (the forward wrapper rejects other values). `v_block_size` need not divide
  V: the forward wrapper zero-pads the vocabulary axis to the next multiple.

  Attributes:
    b_block_size: Tile size over the batch/token (B) dimension.
    h_block_size: Tile size for the inner hidden (H) matmul loop. Each
      iteration loads a (b_block_size, h_block_size) slice of x and a
      (h_block_size, v_block_size) slice of w; total HBM data volume is the
      same regardless of this value. It controls register pressure and the
      matmul tile shape presented to tensor cores.
    v_block_size: Tile size over the vocabulary (V) dimension.
    num_warps: Number of Triton warps per program.
  """

  b_block_size: Annotated[int, pydantic.Field(ge=16, multiple_of=16)] = 32
  h_block_size: Annotated[int, pydantic.Field(ge=16, multiple_of=16)] = 64
  v_block_size: Annotated[int, pydantic.Field(ge=16, multiple_of=16)] = 128
  num_warps: pydantic_lib.PowerOfTwo = 4


def _largest_dividing_block(dim, candidates, fallback):
  """Returns the first entry of `candidates` that divides `dim`, else `fallback`."""
  return next((size for size in candidates if dim % size == 0), fallback)


def get_heuristics_config(
    x: jax.Array,
    w: jax.Array,
) -> Config:
  """Returns a register-safe config based on the input shapes.

  Block sizes are chosen as the largest candidate that divides the matching
  dimension, falling back to smaller 16-multiples (down to 16) before giving
  up. When no candidate divides B (or H) the previous defaults of 32 (or 64)
  are returned and the kernel's own validation reports the mismatch.

  ## v_block_size (fixed at 128)

  v_block_size=256 crashes the Triton-to-PTX compilation stage in JAX 0.9.2's
  bundled Triton: the power-of-2 check in pallas/triton/lowering.py passes
  (total tensor size 8192 is a power of 2) but the Triton compiler then crashes
  with a C++ exception. The check comment explicitly warns: "the Triton lowering
  will fail anyway but it will crash with a C++ exception". The nearest upstream
  fix is jax-ml/jax#35654, which guards the same crash for fp64; the fp32/n=256
  case is not yet guarded. Revisit when JAX upgrades its bundled Triton.

  ## Register budget (SM80+, 65536 regs per SM, num_warps=4, 128 threads)

  With v_block=128, per-thread register cost:
    accumulator: b_block * v_block / 128 = b_block regs/thread.
    w tile:      h_block * v_block / 128 = h_block regs/thread.
    x tile:      b_block * h_block / 128 regs/thread.
    total:       b_block + h_block + b_block * h_block / 128.

  The 50%-budget constraint (256 regs/thread, allows 2 CTAs/SM) limits
  combined (b_block, h_block) choices:
    b=128, h=64:  128 + 64  + 64  = 256 regs (50%)  <- 2 CTAs/SM OK
    b=64,  h=128: 64  + 128 + 64  = 256 regs (50%)  <- 2 CTAs/SM OK
    b=64,  h=64:  64  + 64  + 32  = 160 regs (31%)  <- safe
    b=32,  h=128: 32  + 128 + 32  = 192 regs (37%)  <- safe
    b=128, h=128: 128 + 128 + 128 = 384 regs (75%)  <- 1 CTA/SM, avoided
  Smaller fallback tiles only lower this total.

  ## HBM traffic analysis

  HBM reads scale as (all shapes in elements):
    x traffic: B * H * (V / v_block)  -- x is re-read once per v_block tile.
    w traffic: H * V * (B / b_block)  -- w is re-read once per b_block tile.

  At v_block=128: x traffic = B*H*V/128, w traffic = B*H*V/b_block.
  Traffic is balanced when b_block = v_block = 128. At b_block=64, w traffic
  is 2x x traffic; at b_block=32, 4x. So b_block=128 (when B divisible by 128)
  minimises total HBM reads and measurably outperforms b_block=64 (~4% on
  LLM-scale shapes, bandwidth-bound regime).

  When b_block=128, h_block is capped at 64 to stay within the 50% budget.
  When b_block<=64, h_block=128 (if H divisible by 128) for better tensor-core
  tile efficiency; h_block does not affect HBM traffic.

  Args:
    x: Hidden states, shape (B, H); only its shape is read.
    w: Projection weights; currently unused by the heuristic.

  Returns:
    A `Config` with v_block_size=128 and num_warps=4.
  """
  b_dim, h_dim = x.shape
  b_block_size = _largest_dividing_block(b_dim, (128, 64, 32, 16), fallback=32)
  if b_block_size == 128:
    # b=128,h=128 -> 75% regs -> 1 CTA/SM; cap h at 64.
    h_candidates = (64, 32, 16)
  else:
    h_candidates = (128, 64, 32, 16)
  h_block_size = _largest_dividing_block(h_dim, h_candidates, fallback=64)
  return Config(
      b_block_size=b_block_size,
      h_block_size=h_block_size,
      v_block_size=128,
      num_warps=4,
  )
+# ============================================================================== +"""Pallas-Triton forward kernel for Linear Softmax Cross-Entropy Loss.""" + +from functools import partial +from typing import Literal + +import jax +from jax.experimental import pallas as pl +from jax.experimental.pallas import triton as plgpu +import jax.numpy as jnp +from jaxtyping import Array, Integer, Real, Scalar + +from tokamax._src.pallas import block + + +def _validate_inputs( + x: jax.Array, + labels: jax.Array, + w: jax.Array, + b_block_size: int, + h_block_size: int, + v_block_size: int, +) -> None: + """Validates inputs and block-size constraints.""" + b_dim, h_dim = x.shape + v_dim = w.shape[1] + if b_dim % b_block_size != 0: + raise ValueError( + f"Batch dimension B={b_dim} must be divisible by" + f" b_block_size={b_block_size}." + ) + if w.shape[0] != h_dim: + raise ValueError( + f"w hidden dim {w.shape[0]} must match x hidden dim {h_dim}." + ) + if h_dim % h_block_size != 0: + raise ValueError( + f"Hidden dimension H={h_dim} must be divisible by" + f" h_block_size={h_block_size}." + ) + if labels.shape[0] != b_dim: + raise ValueError( + f"labels batch size {labels.shape[0]} must match x batch size {b_dim}." + ) + + +def _lce_fwd_kernel( + x_ref, + labels_ref, + w_ref, + tile_lse_ref, + correct_logit_ref, + *, + b_block_size: int, + h_block_size: int, + num_h_blocks: int, + v_block_size: int, + v_dim: int, +): + """Per-(b_block, v_block) tile: fused matmul + logsumexp + correct-logit. + + Each program computes one tile of the logit matrix x[b_block, :] @ w[:, v_block] + entirely in registers, never writing logits to HBM. It outputs: + - tile_lse: per-token logsumexp over this V chunk (B, num_v_blocks) + - correct_logit: per-token correct-class logit from this V chunk (B, num_v_blocks) + + These are combined outside the kernel: lse = logsumexp(tile_lse, axis=-1) and + correct_logit = sum(correct_logit, axis=-1), giving the final per-token loss. 
def _lce_fwd_kernel(
    x_ref,
    labels_ref,
    w_ref,
    tile_lse_ref,
    correct_logit_ref,
    *,
    b_block_size: int,
    h_block_size: int,
    num_h_blocks: int,
    v_block_size: int,
    v_dim: int,
):
  """Computes one (b_block, v_block) logit tile and reduces it in place.

  The tile x[b_block, :] @ w[:, v_block] is accumulated in float32 over H
  blocks and is never written out. Two per-token values are stored instead:
    - tile_lse: logsumexp over this V chunk.
    - correct_logit: the label's logit if the label lies in this V chunk,
      otherwise 0.
  The caller merges them across V chunks (logsumexp of tile_lse, sum of
  correct_logit).

  w may be zero-padded up to a multiple of v_block_size. Columns at or beyond
  v_dim are set to -inf before the logsumexp so they contribute nothing.
  """
  first_col = pl.program_id(1) * v_block_size

  def accumulate_h_block(h_idx, partial_logits):
    # Load matching H slices of x and w and add their product in float32.
    x_slice = x_ref.at[:, block.ds(h_idx, h_block_size)].load(
        bounds_check=(False, True)
    )
    w_slice = w_ref.at[block.ds(h_idx, h_block_size), :].load(
        bounds_check=(True, False)
    )
    return partial_logits + pl.dot(
        x_slice.astype(jnp.float32), w_slice.astype(jnp.float32)
    )

  logits_tile = jax.lax.fori_loop(
      0,
      num_h_blocks,
      accumulate_h_block,
      jnp.zeros((b_block_size, v_block_size), dtype=jnp.float32),
  )

  # Global vocabulary index of every column in this tile.
  vocab_ids = jnp.arange(v_block_size) + first_col  # (v_block_size,)
  # Padded columns become -inf; a no-op for tiles fully inside [0, v_dim).
  masked_tile = jnp.where(vocab_ids[None, :] < v_dim, logits_tile, -jnp.inf)

  chunk_lse = jax.nn.logsumexp(masked_tile, axis=-1)  # (b_block_size,)
  tile_lse_ref.store(chunk_lse[:, None])

  # Select the label's logit with a one-hot over tile-local indices. The
  # UNMASKED tile is used on purpose: multiplying 0 by -inf would give NaN.
  # one_hot is all-zero for indices outside [0, v_block_size), so labels that
  # belong to another chunk contribute 0 here.
  local_labels = labels_ref.load().astype(jnp.int32) - first_col
  selector = jax.nn.one_hot(
      local_labels, num_classes=v_block_size, dtype=jnp.float32
  )
  label_logit = jnp.sum(selector * logits_tile, axis=-1)  # (b_block_size,)
  correct_logit_ref.store(label_logit[:, None])


@partial(
    jax.jit,
    static_argnames=(
        "b_block_size",
        "h_block_size",
        "v_block_size",
        "reduction",
        "num_warps",
    ),
)
def linear_softmax_cross_entropy_loss_fwd_pallas_triton(
    x: Real[Array, "B H"],
    labels: Integer[Array, "B"],
    w: Real[Array, "H V"],
    *,
    b_block_size: int = 32,
    h_block_size: int = 64,
    v_block_size: int = 128,
    reduction: Literal["sum", "mean"] = "sum",
    num_warps: int = 4,
) -> tuple[Real[Scalar, ""], Real[Array, "B"]]:
  """Forward pass of the fused linear + softmax cross-entropy loss (Triton).

  The launch grid tiles (B, V); every program loops over H internally, so no
  (B, V) logit matrix is written to HBM.

  Args:
    x: Hidden states, shape (B, H).
    labels: Integer token indices, shape (B,).
    w: LM head weight matrix, shape (H, V).
    b_block_size: Tile size over B. B must be divisible by it.
    h_block_size: Tile size of the inner H accumulation loop.
    v_block_size: Tile size over V. V is zero-padded here to the next
      multiple, so it need not be divisible by v_block_size.
    reduction: "sum" or "mean" over tokens.
    num_warps: Triton warp count (tunable).

  Returns:
    `(loss, lse)`: the float32 scalar loss and the per-token log-sum-exp,
    which is kept as the residual for the backward pass.
  """
  _validate_inputs(x, labels, w, b_block_size, h_block_size, v_block_size)

  # float16 operands are promoted to float32 to avoid precision loss;
  # bfloat16 operands are passed through unchanged.
  if x.dtype == jnp.float16:
    x = x.astype(jnp.float32)
  if w.dtype == jnp.float16:
    w = w.astype(jnp.float32)

  b_dim, h_dim = x.shape
  v_dim = w.shape[1]
  num_b_blocks = pl.cdiv(b_dim, b_block_size)
  num_h_blocks = pl.cdiv(h_dim, h_block_size)
  num_v_blocks = pl.cdiv(v_dim, v_block_size)
  v_padded = num_v_blocks * v_block_size

  # Zero-pad V to a whole number of tiles; the kernel masks the padding.
  if v_padded != v_dim:
    w = jnp.pad(w, ((0, 0), (0, v_padded - v_dim)))

  kernel = partial(
      _lce_fwd_kernel,
      b_block_size=b_block_size,
      h_block_size=h_block_size,
      num_h_blocks=num_h_blocks,
      v_block_size=v_block_size,
      v_dim=v_dim,
  )

  # One float32 value per (token, V chunk) for each of the two outputs.
  per_chunk_shape = jax.ShapeDtypeStruct((b_dim, num_v_blocks), jnp.float32)
  input_specs = (
      pl.BlockSpec((b_block_size, h_dim), lambda b, v: (b, 0)),  # x
      pl.BlockSpec((b_block_size,), lambda b, v: (b,)),  # labels
      pl.BlockSpec((h_dim, v_block_size), lambda b, v: (0, v)),  # w (padded)
  )
  output_specs = (
      pl.BlockSpec((b_block_size, 1), lambda b, v: (b, v)),  # tile_lse
      pl.BlockSpec((b_block_size, 1), lambda b, v: (b, v)),  # correct_logit
  )
  tile_lse, correct_logit_contrib = block.pallas_call(
      kernel,
      name="pallas_triton_lce_fwd",
      grid=(num_b_blocks, num_v_blocks),
      out_shape=(per_chunk_shape, per_chunk_shape),
      in_specs=input_specs,
      out_specs=output_specs,
      compiler_params=plgpu.CompilerParams(num_warps=num_warps),
  )(x, labels, w)

  # Merge chunk-level results in plain JAX: a logsumexp of per-chunk
  # logsumexps gives the global per-token LSE; at most one chunk holds a
  # non-zero label logit, so a sum recovers it.
  lse = jax.nn.logsumexp(tile_lse, axis=-1)  # (B,)
  correct_logit = jnp.sum(correct_logit_contrib, axis=-1)  # (B,)

  per_token_loss = lse - correct_logit  # (B,) negative log-likelihood

  reduce_fn = jnp.sum if reduction == "sum" else jnp.mean
  loss = reduce_fn(per_token_loss)

  return loss.astype(jnp.float32), lse
# Tile sizes shared by every forward-kernel test case.
_TEST_BLOCK_SIZES = dict(b_block_size=32, h_block_size=64, v_block_size=128)


class PallasTritonLceFwdKernelTest(parameterized.TestCase):
  """Checks the Pallas/Triton forward kernel against the reference forward."""

  def setUp(self):
    super().setUp()
    if jax.default_backend() != "gpu":
      self.skipTest("GPU-only test.")

  @parameterized.named_parameters(
      dict(
          testcase_name="small_sum",
          b_dim=64, h_dim=128, v_dim=256,
          reduction="sum",
          **_TEST_BLOCK_SIZES,
      ),
      dict(
          testcase_name="small_mean",
          b_dim=64, h_dim=128, v_dim=256,
          reduction="mean",
          **_TEST_BLOCK_SIZES,
      ),
      dict(
          testcase_name="medium_sum",
          b_dim=128, h_dim=256, v_dim=512,
          reduction="sum",
          **_TEST_BLOCK_SIZES,
      ),
      dict(
          testcase_name="medium_mean",
          b_dim=128, h_dim=256, v_dim=512,
          reduction="mean",
          **_TEST_BLOCK_SIZES,
      ),
      dict(
          testcase_name="bfloat16",
          b_dim=64, h_dim=128, v_dim=256,
          reduction="sum",
          dtype=jnp.bfloat16,
          **_TEST_BLOCK_SIZES,
      ),
      dict(
          # V=300 is not divisible by v_block_size=128; last chunk is padded.
          testcase_name="v_not_divisible_by_block",
          b_dim=64, h_dim=128, v_dim=300,
          reduction="mean",
          **_TEST_BLOCK_SIZES,
      ),
  )
  def test_forward_matches_reference(
      self,
      b_dim,
      h_dim,
      v_dim,
      reduction,
      b_block_size,
      h_block_size,
      v_block_size,
      num_warps=4,
      dtype=jnp.float32,
  ):
    x, labels, w = test_utils.generate_random_data(
        jax.random.key(0), b_dim, h_dim, v_dim, dtype=dtype
    )

    ref_loss, ref_lse = reference.linear_softmax_cross_entropy_loss_fwd_reference(
        x, labels, w, reduction=reduction
    )
    kernel_loss, kernel_lse = (
        kernel.linear_softmax_cross_entropy_loss_fwd_pallas_triton(
            x,
            labels,
            w,
            b_block_size=b_block_size,
            h_block_size=h_block_size,
            v_block_size=v_block_size,
            num_warps=num_warps,
            reduction=reduction,
        )
    )

    # The LSE tolerance is looser than the loss tolerance for float32: per the
    # authors' note, the conftest sets xla_gpu_enable_triton_gemm=False, so
    # the reference x@w goes through cuBLAS while the kernel uses a Triton
    # tiled matmul, and per-token LSE differs by ~1.2e-2 at medium dims
    # (~4e-6 when both use Triton GEMM).
    if dtype == jnp.bfloat16:
      loss_tol = lse_tol = 5e-2
    else:
      loss_tol, lse_tol = 1e-4, 2e-2

    self.assertTrue(
        jnp.allclose(ref_loss, kernel_loss, atol=loss_tol, rtol=loss_tol),
        msg=f"loss mismatch: ref={ref_loss:.6f} kernel={kernel_loss:.6f}",
    )
    worst_lse_diff = jnp.max(jnp.abs(ref_lse - kernel_lse))
    self.assertTrue(
        jnp.allclose(ref_lse, kernel_lse, atol=lse_tol, rtol=lse_tol),
        msg=f"lse mismatch: max_diff={worst_lse_diff:.6f}",
    )


if __name__ == "__main__":
  absltest.main()
class PallasTritonLceOpTest(parameterized.TestCase):
  """Compares the Pallas/Triton LCE op (loss and gradients) with the reference op."""

  def setUp(self):
    super().setUp()
    if jax.default_backend() != "gpu":
      self.skipTest("GPU-only test.")

  def _assert_grad_close(self, label, expected, actual, *, atol, rtol):
    """Asserts two gradients agree in float32, reporting the max abs diff."""
    expected_f32 = expected.astype(jnp.float32)
    actual_f32 = actual.astype(jnp.float32)
    worst = float(jnp.max(jnp.abs(expected_f32 - actual_f32)))
    self.assertTrue(
        jnp.allclose(expected_f32, actual_f32, atol=atol, rtol=rtol),
        msg=f"{label} max_diff={worst:.6f}",
    )

  # Positional layout: (testcase_name, b_dim, h_dim, v_dim, reduction[, dtype]).
  @parameterized.named_parameters(
      ("small_sum", 64, 128, 256, "sum"),
      ("small_mean", 64, 128, 256, "mean"),
      ("medium_sum", 128, 256, 512, "sum"),
      ("medium_mean", 128, 256, 512, "mean"),
      ("bfloat16", 64, 128, 256, "sum", jnp.bfloat16),
  )
  def test_value_and_grad_matches_reference(
      self,
      b_dim,
      h_dim,
      v_dim,
      reduction,
      dtype=jnp.float32,
  ):
    x, labels, w = generate_random_data(
        jax.random.key(42), b_dim, h_dim, v_dim, dtype=dtype
    )
    triton_op = PallasTritonLinearSoftmaxCrossEntropyLoss(
        config=Config(b_block_size=32, h_block_size=64, v_block_size=128)
    )
    ref_op = LinearSoftmaxCrossEntropyLoss()

    # bfloat16 inputs are compared against a reference fed float32-upcast
    # copies of the same data (the kernel accumulates in float32 internally).
    if dtype == jnp.bfloat16:
      x_ref, w_ref = x.astype(jnp.float32), w.astype(jnp.float32)
    else:
      x_ref, w_ref = x, w

    kernel_loss, kernel_grads = jax.value_and_grad(triton_op, argnums=(0, 2))(
        x, labels, w, reduction=reduction
    )
    kernel_x_grad, kernel_w_grad = kernel_grads

    ref_loss, ref_grads = jax.value_and_grad(ref_op, argnums=(0, 2))(
        x_ref, labels, w_ref, reduction=reduction
    )
    ref_x_grad, ref_w_grad = ref_grads

    # Per the authors' note: the conftest sets xla_gpu_enable_triton_gemm=False,
    # so the reference x@w uses cuBLAS while the kernel uses a Triton tiled
    # matmul; float32 gradient differences of ~1e-2 were observed at medium
    # dims (~4e-6 when both use Triton GEMM).
    tol = dict(atol=2e-2, rtol=2e-2)

    self.assertTrue(
        jnp.allclose(
            ref_loss.astype(jnp.float32),
            kernel_loss.astype(jnp.float32),
            **tol,
        ),
        msg=f"loss: ref={float(ref_loss):.6f} kernel={float(kernel_loss):.6f}",
    )
    self._assert_grad_close("x_grad", ref_x_grad, kernel_x_grad, **tol)
    self._assert_grad_close("w_grad", ref_w_grad, kernel_w_grad, **tol)


if __name__ == "__main__":
  absltest.main()
+ owner: "Tokamax Team" + update_frequency_policy: QUARTERLY + workload { + action: "./ml_actions/benchmarking/actions/workload_executors/python" + action_inputs { key: "script_path" value: "tokamax/benchmarks/linear_softmax_cross_entropy_loss.py" } + action_inputs { key: "python_version" value: "3.11" } + } + + environment_configs { + id: "gpu-h100" + runner_label: "linux-x86-a3-8g-h100-1gpu" + container_image: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build-cuda13.0-cudnn9.15@sha256:943892a4ab8e9b58a9c7b4297f170d3f28fcb1d479e9835190d49dafdbd2992a" + workload_action_inputs { key: "extras_hw" value: "cuda" } + } + + environment_configs { + id: "gpu-b200" + runner_label: "linux-x86-a4-224-b200-1gpu" + container_image: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build-cuda13.0-cudnn9.15@sha256:943892a4ab8e9b58a9c7b4297f170d3f28fcb1d479e9835190d49dafdbd2992a" + workload_action_inputs { key: "extras_hw" value: "cuda" } + } + + environment_configs { + id: "tpu-v6e" + runner_label: "linux-x86-ct6e-44-1tpu" + container_image: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest@sha256:43c523372c4b7f7ce649a1ff204b908727bd338353303c0444af34cb305e5832" + workload_action_inputs { key: "extras_hw" value: "tpu" } + workload_action_inputs { key: "runtime_flags_hw" value: "--skip_implementations=triton,mosaic_gpu" } + } + + environment_configs { + id: "tpu-v7" + runner_label: "linux-x86-tpu7x-56-1tpu" + container_image: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest@sha256:43c523372c4b7f7ce649a1ff204b908727bd338353303c0444af34cb305e5832" + workload_action_inputs { key: "extras_hw" value: "tpu" } + workload_action_inputs { key: "runtime_flags_hw" value: "--skip_implementations=triton,mosaic_gpu" } + } + + metrics { + name: "linear_softmax_cross_entropy_loss/qwen3-8b/default/forward" + unit: "ms" + stats { stat: MEDIAN } + } + metrics { + name: 
"linear_softmax_cross_entropy_loss/qwen3-8b/default/forward_and_vjp" + unit: "ms" + stats { stat: MEDIAN } + } + metrics { + name: "linear_softmax_cross_entropy_loss/qwen3-8b/triton/forward" + unit: "ms" + stats { stat: MEDIAN } + } + metrics { + name: "linear_softmax_cross_entropy_loss/qwen3-8b/triton/forward_and_vjp" + unit: "ms" + stats { stat: MEDIAN } + } + metrics { + name: "linear_softmax_cross_entropy_loss/qwen3-8b/mosaic_gpu/forward" + unit: "ms" + stats { stat: MEDIAN } + } + metrics { + name: "linear_softmax_cross_entropy_loss/qwen3-8b/mosaic_gpu/forward_and_vjp" + unit: "ms" + stats { stat: MEDIAN } + } + metrics { + name: "linear_softmax_cross_entropy_loss/qwen3-8b/xla/forward" + unit: "ms" + stats { stat: MEDIAN } + } + metrics { + name: "linear_softmax_cross_entropy_loss/qwen3-8b/xla/forward_and_vjp" + unit: "ms" + stats { stat: MEDIAN } + } + + metrics { + name: "linear_softmax_cross_entropy_loss/llama3.1-8b/default/forward" + unit: "ms" + stats { stat: MEDIAN } + } + metrics { + name: "linear_softmax_cross_entropy_loss/llama3.1-8b/default/forward_and_vjp" + unit: "ms" + stats { stat: MEDIAN } + } + metrics { + name: "linear_softmax_cross_entropy_loss/llama3.1-8b/triton/forward" + unit: "ms" + stats { stat: MEDIAN } + } + metrics { + name: "linear_softmax_cross_entropy_loss/llama3.1-8b/triton/forward_and_vjp" + unit: "ms" + stats { stat: MEDIAN } + } + metrics { + name: "linear_softmax_cross_entropy_loss/llama3.1-8b/mosaic_gpu/forward" + unit: "ms" + stats { stat: MEDIAN } + } + metrics { + name: "linear_softmax_cross_entropy_loss/llama3.1-8b/mosaic_gpu/forward_and_vjp" + unit: "ms" + stats { stat: MEDIAN } + } + metrics { + name: "linear_softmax_cross_entropy_loss/llama3.1-8b/xla/forward" + unit: "ms" + stats { stat: MEDIAN } + } + metrics { + name: "linear_softmax_cross_entropy_loss/llama3.1-8b/xla/forward_and_vjp" + unit: "ms" + stats { stat: MEDIAN } + } + +} + benchmarks { name: "tokamax_triangle_multiplication" description: "Runs the Tokamax 
triangle_multiplication benchmark." diff --git a/tokamax/benchmarks/linear_softmax_cross_entropy_loss.py b/tokamax/benchmarks/linear_softmax_cross_entropy_loss.py new file mode 100644 index 00000000..b06f8bc1 --- /dev/null +++ b/tokamax/benchmarks/linear_softmax_cross_entropy_loss.py @@ -0,0 +1,117 @@ +# Copyright 2026 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Benchmarks for linear softmax cross-entropy loss.""" + +from absl import flags +from absl.testing import absltest +from absl.testing import parameterized +import jax +import jax.numpy as jnp +import tokamax +from tokamax.benchmarks import common + +_TENSORBOARD_OUTPUT_ENV_VAR = flags.DEFINE_string( + 'tensorboard_output_env_var', + 'TENSORBOARD_OUTPUT_DIR', + 'Environment variable to use to retrieve TensorBoard output directory.', +) +_SKIP_IMPLEMENTATIONS = flags.DEFINE_list( + 'skip_implementations', + [], + 'A comma-separated list of implementations to skip.', +) + + +# Representative shapes from real LLM vocabularies. 
# Each entry holds the keyword arguments for one benchmark case, passed to
# `tokamax.linear_softmax_cross_entropy_loss` as abstract shapes:
#   x:       (tokens, hidden) activations
#   labels:  (tokens,) integer class ids
#   weights: (hidden, vocab) projection matrix — presumably the LM head; the
#            op's own signature is not visible from this module.
EXAMPLES = {
    'qwen3-8b': {
        'x': jax.ShapeDtypeStruct((4096, 4096), jnp.bfloat16),
        'labels': jax.ShapeDtypeStruct((4096,), jnp.int32),
        'weights': jax.ShapeDtypeStruct((4096, 151936), jnp.bfloat16),
        'reduction': 'mean',
    },
    'gemma3-4b': {
        'x': jax.ShapeDtypeStruct((4096, 2560), jnp.bfloat16),
        'labels': jax.ShapeDtypeStruct((4096,), jnp.int32),
        'weights': jax.ShapeDtypeStruct((2560, 262144), jnp.bfloat16),
        'reduction': 'mean',
    },
    # NOTE(review): Gemma 3 is published in 1B/4B/12B/27B sizes and a hidden
    # size of 3840 looks like the 12B model. The key is kept unchanged so that
    # emitted metric tags stay stable — confirm and rename if appropriate.
    'gemma3-7b': {
        'x': jax.ShapeDtypeStruct((4096, 3840), jnp.bfloat16),
        'labels': jax.ShapeDtypeStruct((4096,), jnp.int32),
        'weights': jax.ShapeDtypeStruct((3840, 262144), jnp.bfloat16),
        'reduction': 'mean',
    },
    'llama3.1-8b': {
        'x': jax.ShapeDtypeStruct((4096, 4096), jnp.bfloat16),
        'labels': jax.ShapeDtypeStruct((4096,), jnp.int32),
        'weights': jax.ShapeDtypeStruct((4096, 128256), jnp.bfloat16),
        'reduction': 'mean',
    },
    'deepseek-v3-671b': {
        'x': jax.ShapeDtypeStruct((8192, 7168), jnp.bfloat16),
        'labels': jax.ShapeDtypeStruct((8192,), jnp.int32),
        # The published DeepSeek-V3 config uses vocab_size=129280; 128256 is
        # the Llama 3 vocabulary and must not be reused here.
        'weights': jax.ShapeDtypeStruct((7168, 129280), jnp.bfloat16),
        'reduction': 'mean',
    },
    'gpt-oss-120b': {
        'x': jax.ShapeDtypeStruct((4096, 2880), jnp.bfloat16),
        'labels': jax.ShapeDtypeStruct((4096,), jnp.int32),
        'weights': jax.ShapeDtypeStruct((2880, 201088), jnp.bfloat16),
        'reduction': 'mean',
    },
}


class LinearSoftmaxCrossEntropyLossBenchmark(parameterized.TestCase):
  """Benchmarks for linear softmax cross-entropy loss.

  One test case is generated per (implementation, benchmark mode, example
  shape) combination. Each case times the op and writes the measured
  evaluation times through `common.write_tensorboard_logs`.
  """

  @parameterized.product(
      implementation=(None, 'xla', 'triton', 'mosaic_gpu'),
      benchmark_mode=('forward', 'forward_and_vjp'),
      args_spec_name=tuple(EXAMPLES.keys()),
  )
  def test_linear_softmax_cross_entropy_loss(
      self, implementation, benchmark_mode, args_spec_name
  ):
    """Benchmarks the linear softmax cross-entropy loss operation.

    Args:
      implementation: Backend name forwarded to the op, or `None` to let the
        op choose its default backend.
      benchmark_mode: Either 'forward' or 'forward_and_vjp'; forwarded as
        `mode` to `tokamax.standardize_function`.
      args_spec_name: Key into `EXAMPLES` selecting the input shapes.
    """
    # `None` is reported as "default" in the metric tag, so the skip flag
    # accepts that spelling too. The historical spelling ('None') still works.
    impl_label = implementation or 'default'
    skipped = _SKIP_IMPLEMENTATIONS.value
    if str(implementation) in skipped or impl_label in skipped:
      self.skipTest(f'Skipping implementation {implementation}')

    gpu_only = implementation in ('triton', 'mosaic_gpu')
    if gpu_only and jax.default_backend() != 'gpu':
      self.skipTest(f'{implementation} implementation is GPU-only.')

    # Build a new kwargs dict so the shared EXAMPLES entry is never mutated.
    example = EXAMPLES[args_spec_name] | {'implementation': implementation}
    fn, args = tokamax.standardize_function(
        tokamax.linear_softmax_cross_entropy_loss,
        kwargs=example,
        mode=benchmark_mode,  # pytype: disable=wrong-arg-types
    )
    fn = jax.jit(fn)
    res = tokamax.benchmark(fn, args)

    # The flag value is the *name* of the environment variable that holds the
    # TensorBoard output directory (see the flag's help text), not the path.
    common.write_tensorboard_logs(
        tensorboard_output=_TENSORBOARD_OUTPUT_ENV_VAR.value,
        value=res.evaluation_times_ms,
        metric_tag=(
            f'linear_softmax_cross_entropy_loss/{args_spec_name}'
            f'/{impl_label}/{benchmark_mode}'
        ),
    )


if __name__ == '__main__':
  absltest.main()