From 6dda44a7ab4087bbd53d7fdeff19a848bbd5a542 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 11 Feb 2026 13:18:58 -0500 Subject: [PATCH 1/8] scattermoe lora support --- requirements.txt | 2 +- src/axolotl/integrations/kernels/args.py | 12 + .../integrations/kernels/libs/__init__.py | 0 .../kernels/libs/scattermoe_lora/__init__.py | 18 + .../libs/scattermoe_lora/kernels/__init__.py | 12 + .../libs/scattermoe_lora/kernels/lora_ops.py | 1677 +++++++++++++++++ .../libs/scattermoe_lora/kernels/ops.py | 645 +++++++ .../libs/scattermoe_lora/kernels/single.py | 95 + .../kernels/libs/scattermoe_lora/layers.py | 413 ++++ .../kernels/libs/scattermoe_lora/lora_ops.py | 99 + .../libs/scattermoe_lora/parallel_experts.py | 255 +++ .../scattermoe_lora/parallel_linear_lora.py | 474 +++++ src/axolotl/integrations/kernels/plugin.py | 16 +- 13 files changed, 3712 insertions(+), 6 deletions(-) create mode 100644 src/axolotl/integrations/kernels/libs/__init__.py create mode 100644 src/axolotl/integrations/kernels/libs/scattermoe_lora/__init__.py create mode 100644 src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/__init__.py create mode 100644 src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py create mode 100644 src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/ops.py create mode 100644 src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/single.py create mode 100644 src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py create mode 100644 src/axolotl/integrations/kernels/libs/scattermoe_lora/lora_ops.py create mode 100644 src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_experts.py create mode 100644 src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.py diff --git a/requirements.txt b/requirements.txt index 09b1f625b0..710e24d718 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,7 +18,7 @@ datasets==4.5.0 deepspeed>=0.18.3 trl==0.28.0 hf_xet==1.2.0 -kernels==0.11.5 
@model_validator(mode="before")
@classmethod
def disable_mlp_kernel_scattermoe(cls, data):
    """Turn off the LoRA MLP kernel when ScatterMoE is enabled.

    Runs as a ``mode="before"`` validator on the raw config mapping.

    Args:
        data: Raw config mapping; read with ``.get`` and updated in place.

    Returns:
        The same mapping. When both ``use_scattermoe`` and
        ``lora_mlp_kernel`` are ``True``, a warning is logged and
        ``lora_mlp_kernel`` is reset to ``False``; otherwise the mapping is
        returned unchanged.
    """
    if data.get("use_scattermoe") is True and data.get("lora_mlp_kernel") is True:
        LOG.warning(
            "Disabling lora_mlp_kernel when using scattermoe due to compatibility issues."
        )
        # Clear the very option that was tested and named in the warning.
        # Writing to a differently named key ("mlp_kernel") would leave
        # lora_mlp_kernel enabled and the warning untrue.
        data["lora_mlp_kernel"] = False

    return data
import layers +from .lora_ops import ParallelExperts +from .parallel_experts import flatten_sort_count, parallel_linear +from .parallel_linear_lora import ScatterMoELoRA, parallel_linear_lora + +__all__ = [ + "layers", + "ParallelExperts", + "flatten_sort_count", + "parallel_linear", + "ScatterMoELoRA", + "parallel_linear_lora", + "lora_ops", +] diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/__init__.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/__init__.py new file mode 100644 index 0000000000..eb502db712 --- /dev/null +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/__init__.py @@ -0,0 +1,12 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# Original work Copyright (c) Shawn Tan and ScatterMoE Contributors +# Adapted from https://github.com/shawntan/scattermoe +# See https://github.com/shawntan/scattermoe/blob/main/LICENSE +# +# Modifications and LoRA adaptation Copyright (c) Axolotl AI +# Licensed under the Apache License, Version 2.0 + +from . import lora_ops, ops + +__all__ = ["ops", "lora_ops"] diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py new file mode 100644 index 0000000000..77a4952487 --- /dev/null +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py @@ -0,0 +1,1677 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) Axolotl AI +# Licensed under the Apache License, Version 2.0 + +""" +Fused ScatterMoE + LoRA Triton Kernels +======================================= + +Provides fused forward and backward kernels for ScatterMoE with LoRA adapters. 
+ +Forward: Y = X @ W + scaling * (X @ A^T) @ B^T +Backward (LoRA training, W frozen): + - dX = dY @ W^T + scaling * (dY @ B) @ A (input gradient) + - dA = scaling * (dY @ B)^T @ X (LoRA A gradient) + - dB = scaling * dY^T @ (X @ A^T) (LoRA B gradient) + +LoRA weight layout (from PEFT ParamWrapper): + - A: [r*E, K] -- for expert e, rows [e*r : (e+1)*r] give A_e of shape [r, K] + - B: [N, r*E] -- for expert e, cols [e*r : (e+1)*r] give B_e of shape [N, r] + +Key design decisions: + - The forward kernel fuses X@W and X@A^T in the same K-loop for data reuse on X, + then computes (X@A^T) @ B^T in the epilogue. + - The backward dA/dB kernel operates on grouped (expert-contiguous) data and + iterates over tokens per expert, accumulating gradients in registers. + - R (LoRA rank) is a tl.constexpr, allowing tl.arange(0, R). We pad R to a + power-of-2 for Triton tile compatibility; typical ranks (4, 8, 16, 32, 64) + already satisfy this. +""" + +from itertools import product +from typing import Optional + +import torch +import triton +import triton.language as tl + +# ============================================================================= +# Configuration +# ============================================================================= + +BLOCK_M = 128 +ALLOW_TF32 = True + + +def _next_power_of_2(n: int) -> int: + """Round up to next power of 2.""" + n -= 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + return n + 1 + + +# Triton tl.dot requires minimum tile dimensions of 16 on modern GPUs. 
+MIN_TRITON_DOT_SIZE = 16 + + +def _block_r_for_rank(r: int) -> int: + """Compute BLOCK_R: next power-of-2 >= max(r, MIN_TRITON_DOT_SIZE).""" + return _next_power_of_2(max(r, MIN_TRITON_DOT_SIZE)) + + +# ============================================================================= +# Token Rounding: pad expert counts to BLOCK_M multiples +# ============================================================================= + + +def round_expert_counts( + sorted_expert_idxs: torch.Tensor, + sorted_scattered_idxs: torch.Tensor, + expert_offsets: torch.Tensor, + E: int, + block_m: int = BLOCK_M, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Pad each expert's token count to a multiple of block_m to eliminate + partial-tile waste in the backward kernel. + + Padding is done by duplicating the last valid token index for each expert. + The kernel's M_mask = M_idx < real_end_idx masks these padding entries, so + correctness is preserved (they contribute 0 to the accumulation via other=0.0). + + This only helps the backward dA/dB kernel where per-expert iteration is + explicit. The forward scatter2scatter kernel handles partial tiles via masking. 
+ + Args: + sorted_expert_idxs: Expert assignments sorted [M*k] + sorted_scattered_idxs: Original indices sorted [M*k] + expert_offsets: Cumulative token counts per expert [E] + E: Number of experts + block_m: Block size for token dimension (default: BLOCK_M) + + Returns: + padded_expert_idxs: [M_padded] expert assignments with padding + padded_scattered_idxs: [M_padded] original indices with padding + padded_offsets: [E] cumulative padded counts (for kernel iteration range) + real_offsets: [E] original cumulative counts (for M_mask in kernel) + """ + device = sorted_expert_idxs.device + + # Compute per-expert counts + counts = torch.zeros(E, dtype=torch.int64, device=device) + prev = 0 + for e in range(E): + curr = expert_offsets[e].item() + counts[e] = curr - prev + prev = curr + + # Round up each count to multiple of block_m + padded_counts = ((counts + block_m - 1) // block_m) * block_m + # Experts with 0 tokens stay at 0 + padded_counts = torch.where( + counts > 0, padded_counts, torch.zeros_like(padded_counts) + ) + total_padded = padded_counts.sum().item() + + padded_expert_idxs = torch.empty( + total_padded, dtype=sorted_expert_idxs.dtype, device=device + ) + padded_scattered_idxs = torch.empty( + total_padded, dtype=sorted_scattered_idxs.dtype, device=device + ) + + src_offset = 0 + dst_offset = 0 + for e in range(E): + count = counts[e].item() + padded_count = padded_counts[e].item() + + if count > 0: + # Copy original tokens + padded_expert_idxs[dst_offset : dst_offset + count] = sorted_expert_idxs[ + src_offset : src_offset + count + ] + padded_scattered_idxs[dst_offset : dst_offset + count] = ( + sorted_scattered_idxs[src_offset : src_offset + count] + ) + + # Pad with last valid token (masked out by kernel via M_mask) + if padded_count > count: + padded_expert_idxs[dst_offset + count : dst_offset + padded_count] = ( + sorted_expert_idxs[src_offset + count - 1] + ) + padded_scattered_idxs[ + dst_offset + count : dst_offset + padded_count + ] = 
sorted_scattered_idxs[src_offset + count - 1] + + src_offset += count + dst_offset += padded_count + + # Padded offsets: cumulative padded counts (for iteration range in kernel) + padded_offsets = padded_counts.cumsum(-1).to(expert_offsets.dtype) + # Real offsets: original cumulative counts (for M_mask in kernel) + real_offsets = expert_offsets.clone() + + return padded_expert_idxs, padded_scattered_idxs, padded_offsets, real_offsets + + +# ============================================================================= +# Autotuning: SMEM estimation and config pruning +# ============================================================================= + +_SMEM_CAPACITY: int | None = None + + +def _get_smem_capacity() -> int: + """Get device shared memory capacity (bytes). Cached after first call.""" + global _SMEM_CAPACITY + if _SMEM_CAPACITY is None: + props = triton.runtime.driver.active.utils.get_device_properties( + torch.cuda.current_device() + ) + _SMEM_CAPACITY = props["max_shared_mem"] + return _SMEM_CAPACITY + + +def _estimate_smem_usage( + num_stages: int, BLOCK_M: int, BLOCK_N: int, BLOCK_K: int, dtype_bytes: int = 2 +) -> int: + """Estimate shared memory in bytes for a GEMM-style tile. + + Formula: stages * BLOCK_K * (BLOCK_M + BLOCK_N) + BLOCK_M * BLOCK_N + Multiply by dtype_bytes (2 for fp16/bf16). + """ + return ( + num_stages * BLOCK_K * (BLOCK_M + BLOCK_N) + BLOCK_M * BLOCK_N + ) * dtype_bytes + + +# Conservative margin (bytes) subtracted from SMEM capacity to account for +# estimation inaccuracies and kernel overhead (registers spilled to SMEM, etc.) 
+_SMEM_SLACK = 10_000 + + +# ============================================================================= +# Forward Kernel: scatter2scatter with fused LoRA +# ============================================================================= + + +@triton.jit +def _compute_expert_block_lora( + E_idx, + E_mask, + M_in_idx, + N_block, + N_mask, + # Base weight + X_ptr, + stride_xm, + stride_xk, + W_ptr, + stride_we, + stride_wk, + stride_wn, + # LoRA weights + A_ptr, + stride_ar, + stride_ak, # A: [r*E, K], stride_ar = stride for r*E dim, stride_ak = stride for K dim + B_ptr, + stride_bn, + stride_br, # B: [N, r*E], stride_bn = stride for N dim, stride_br = stride for r*E dim + # Dimensions + K, + ACTUAL_R: tl.constexpr, # True LoRA rank (for indexing into weight arrays) + acc, + no_k_mask, + BLOCK_M: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_R: tl.constexpr, # Padded tile size >= max(ACTUAL_R, 16) + scaling, + allow_tf32: tl.constexpr, +): + """ + Compute Y_block = X_block @ W_e + scaling * (X_block @ A_e^T) @ B_e^T + + for tokens in this M-block assigned to expert E_idx. + + ACTUAL_R is the true LoRA rank used for indexing into A[e*r:(e+1)*r, :]. + BLOCK_R >= ACTUAL_R is the padded tile dimension (must be >= 16 for tl.dot). + When BLOCK_R > ACTUAL_R, loads are masked on the R dimension. 
+ """ + K_block = tl.arange(0, BLOCK_K) + R_block = tl.arange(0, BLOCK_R) + R_mask = R_block < ACTUAL_R # Mask for padding when BLOCK_R > ACTUAL_R + + # Base weight pointers: W[E_idx, :, :] is [K, N], load [BLOCK_K, BLOCK_N] + X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk + W_blk_ptrs = ( + W_ptr + + E_idx * stride_we + + K_block[:, None] * stride_wk + + N_block[None, :] * stride_wn + ) + + # LoRA A pointers: A[e*ACTUAL_R:(e+1)*ACTUAL_R, :] for expert e, shape [r, K] + A_expert_offset = E_idx * ACTUAL_R + A_blk_ptrs = ( + A_ptr + + (A_expert_offset + R_block)[:, None] * stride_ar + + K_block[None, :] * stride_ak + ) + + iters = tl.cdiv(K, BLOCK_K) + + # Accumulator for X @ A^T: [BLOCK_M, BLOCK_R] + xa_acc = tl.zeros((BLOCK_M, BLOCK_R), dtype=tl.float32) + + for i in range(iters): + if no_k_mask: + x = tl.load(X_blk_ptrs, mask=E_mask[:, None], other=0.0) + w = tl.load(W_blk_ptrs, mask=N_mask[None, :], other=0.0) + a = tl.load( + A_blk_ptrs, mask=R_mask[:, None], other=0.0 + ) # [BLOCK_R, BLOCK_K], masked on R dim + else: + K_mask = (i * BLOCK_K + K_block) < K + x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :], other=0.0) + w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :], other=0.0) + a = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0) + + # Base: acc += X @ W ([M, K] @ [K, N] -> [M, N]) + acc = tl.dot(x, w, acc, allow_tf32=allow_tf32) + + # LoRA: xa_acc += X @ A^T ([M, K] @ [K, R] -> [M, R]) + xa_acc = tl.dot(x, tl.trans(a), xa_acc, allow_tf32=allow_tf32) + + X_blk_ptrs += BLOCK_K * stride_xk + W_blk_ptrs += BLOCK_K * stride_wk + A_blk_ptrs += BLOCK_K * stride_ak + + # Epilogue: load B[e] and compute (X @ A^T) @ B^T + # B[e] is B[:, e*ACTUAL_R:(e+1)*ACTUAL_R], shape [N, r]. Load [BLOCK_N, BLOCK_R]. 
def _scatter2scatter_lora_configs():
    """Generate forward kernel autotune configs.

    Search space includes smaller tile sizes and fewer pipeline stages to
    support GPUs with limited shared memory (e.g. ~99KB on some GPUs).

    Search space:
        BLOCK_N: {32, 64, 128, 256}
        BLOCK_K: {32, 64, 128}
        num_warps: {4, 8}
        num_stages: {2, 3, 4, 5}

    BLOCK_M is not part of the search space; it is supplied separately at
    launch time.
    """
    return [
        triton.Config(
            {"BLOCK_N": block_n, "BLOCK_K": block_k},
            num_stages=stages,
            num_warps=warps,
        )
        for block_n, block_k, warps, stages in product(
            [32, 64, 128, 256],  # BLOCK_N
            [32, 64, 128],  # BLOCK_K
            [4, 8],  # num_warps
            [2, 3, 4, 5],  # num_stages
        )
    ]


def _prune_fwd_configs(configs, named_args, **kwargs):
    """Prune forward configs based on SMEM capacity.

    The estimate for each config is the base GEMM figure from
    ``_estimate_smem_usage`` (X and W tiles plus the output tile) plus:
      - an A tile [BLOCK_R, BLOCK_K] per pipeline stage (inner loop), and
      - a B tile [BLOCK_N, BLOCK_R] loaded once in the epilogue,
    both at 2 bytes per element. Configs whose estimate exceeds the device
    capacity minus ``_SMEM_SLACK`` are dropped.

    Args:
        configs: Candidate ``triton.Config`` objects.
        named_args: Mapping of kernel arguments supplied by the autotuner.
        **kwargs: Further launch arguments supplied by the autotuner.

    Returns:
        The configs that fit; if none fit, a one-element list holding the
        config with the smallest estimate.
    """
    smem_cap = _get_smem_capacity()

    # BLOCK_R is passed to the kernel by keyword, so look in kwargs first
    # (presumably where the autotuner forwards keyword launch arguments) and
    # only then in named_args; 64 is the fallback when neither carries it.
    block_r = kwargs.get("BLOCK_R", named_args.get("BLOCK_R", 64))

    scored = []
    for config in configs:
        block_n = config.kwargs["BLOCK_N"]
        block_k = config.kwargs["BLOCK_K"]
        # Base: stages * BLOCK_K * (BLOCK_M + BLOCK_N) + BLOCK_M * BLOCK_N
        smem_base = _estimate_smem_usage(config.num_stages, BLOCK_M, block_n, block_k)
        # A tile [BLOCK_R, BLOCK_K] loaded per stage in the inner loop
        smem_lora_loop = config.num_stages * block_r * block_k * 2
        # B tile [BLOCK_N, BLOCK_R] loaded once in epilogue
        smem_lora_epilogue = block_n * block_r * 2
        scored.append((smem_base + smem_lora_loop + smem_lora_epilogue, config))

    budget = smem_cap - _SMEM_SLACK
    pruned = [config for smem, config in scored if smem <= budget]
    if pruned:
        return pruned
    # All configs exceed SMEM — return the one with smallest estimated usage
    # (first one wins on ties, matching a stable sort).
    return [min(scored, key=lambda item: item[0])[1]]
expert_idxs_ptr, + # Dimensions + FAN_OUT: tl.constexpr, + M, + K: tl.constexpr, + N: tl.constexpr, + E: tl.constexpr, + ACTUAL_R: tl.constexpr, # True LoRA rank (for weight indexing) + # Block sizes + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_R: tl.constexpr, # Padded tile size >= max(ACTUAL_R, 16) + # Config + ACC_TYPE: tl.constexpr, + scaling, + allow_tf32: tl.constexpr, + x_grouped: tl.constexpr, + y_grouped: tl.constexpr, + NO_K_MASK: tl.constexpr, + NO_N_MASK: tl.constexpr, +): + """ + Fused scatter2scatter with LoRA: Y = X @ W + scaling * (X @ A^T) @ B^T + bias + """ + pid = tl.program_id(axis=0) + + N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N) + M_block_id = pid // N_BLOCK_COUNT + N_block_id = pid % N_BLOCK_COUNT + + M_block = M_block_id * BLOCK_M + tl.arange(0, BLOCK_M) + N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_block < N + M_boundary_mask = M_block < (FAN_OUT * M) + + E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_boundary_mask, other=E) + + no_k_mask = NO_K_MASK + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + + E_first_idx = tl.min(E_idxs) + E_last_idx = tl.minimum(tl.max(E_idxs), E - 1) + M_idx = tl.load(grouped_idx_ptr + M_block, mask=M_boundary_mask).to(tl.int32) + + for E_idx in range(E_first_idx, E_last_idx + 1): + E_mask = E_idxs == E_idx + if x_grouped: + M_in_idx = M_block + else: + M_in_idx = M_idx // FAN_OUT + + acc = _compute_expert_block_lora( + E_idx, + E_mask, + M_in_idx, + N_block, + N_mask, + X_ptr, + stride_xm, + stride_xk, + W_ptr, + stride_we, + stride_wk, + stride_wn, + LA_ptr, + stride_la_r, + stride_la_k, + LB_ptr, + stride_lb_n, + stride_lb_r, + K, + ACTUAL_R, + acc, + no_k_mask, + BLOCK_M, + BLOCK_K, + BLOCK_N, + BLOCK_R, + scaling, + allow_tf32=allow_tf32, + ) + + # Add bias if present + if Bias_ptr is not None: + B_blk_ptrs = ( + Bias_ptr + + E_idxs[:, None] * stride_bias_e + + N_block[None, :] * stride_bias_n + ) + acc += tl.load(B_blk_ptrs, 
def scatter2scatter_lora(
    X: torch.Tensor,
    W: torch.Tensor,
    sorted_expert_idxs: torch.Tensor,
    sorted_scattered_idxs: torch.Tensor,
    k: int,
    lora_A: torch.Tensor,
    lora_B: torch.Tensor,
    scaling: float,
    b: Optional[torch.Tensor] = None,
    x_grouped: bool = False,
    y_grouped: bool = False,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Launch the fused forward kernel.

    Computes Y[i] = X[i] @ W[e] + scaling * (X[i] @ A[e]^T) @ B[e]^T + b[e].

    Args:
        X: Input [M, K] or [M*k, K] if x_grouped
        W: Expert weights [E, K, N]
        sorted_expert_idxs: Expert assignments sorted [M*k]
        sorted_scattered_idxs: Original indices sorted [M*k]
        k: Fan-out (top-k)
        lora_A: LoRA A weights [r*E, K]
        lora_B: LoRA B weights [N, r*E]
        scaling: LoRA scaling factor (alpha/r)
        b: Optional bias [E, N]
        x_grouped: Input pre-grouped by expert
        y_grouped: Keep output grouped
        out: Optional pre-allocated output buffer; must be
            [len(sorted_expert_idxs), N]

    Returns:
        Y: Output [M*k, N] (``out`` itself when it was supplied)
    """
    num_routed = sorted_expert_idxs.size(0)
    assert sorted_scattered_idxs.size(0) == num_routed
    assert sorted_scattered_idxs.size(0) == X.size(0) * k

    E, K, N = W.size(0), W.size(1), W.size(2)
    # Per-expert rank: A stacks E blocks of R rows each.
    R = lora_A.size(0) // E
    # Tile size along the rank dimension, padded for Triton.
    BLOCK_R = _block_r_for_rank(R)

    if out is not None:
        assert out.size(0) == num_routed and out.size(1) == N
        output = out
    else:
        output = torch.empty((num_routed, N), device=X.device, dtype=X.dtype)

    # Without a bias the kernel receives a None pointer and zero strides.
    if b is not None:
        stride_be, stride_bn = b.stride()
    else:
        stride_be = stride_bn = 0

    # One program per (token block, output-column block) pair.
    grid = lambda META: (  # noqa: E731
        triton.cdiv(num_routed, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
    )

    _scatter2scatter_lora[grid](
        X,
        X.stride(0),
        X.stride(1),
        W,
        W.stride(0),
        W.stride(1),
        W.stride(2),
        output,
        output.stride(0),
        output.stride(1),
        b,
        stride_be,
        stride_bn,
        # A: [r*E, K] -> stride(0) is r*E dim stride, stride(1) is K dim stride
        lora_A,
        lora_A.stride(0),
        lora_A.stride(1),
        # B: [N, r*E] -> stride(0) is N dim stride, stride(1) is r*E dim stride
        lora_B,
        lora_B.stride(0),
        lora_B.stride(1),
        sorted_scattered_idxs,
        sorted_expert_idxs,
        FAN_OUT=k,
        M=X.size(0),
        K=K,
        N=N,
        E=E,
        ACTUAL_R=R,  # True LoRA rank for weight indexing
        BLOCK_M=BLOCK_M,
        BLOCK_R=BLOCK_R,  # Padded tile size >= max(R, 16)
        ACC_TYPE=tl.float32,
        scaling=scaling,
        allow_tf32=ALLOW_TF32,
        x_grouped=x_grouped,
        y_grouped=y_grouped,
    )

    return output
+ + Transpose mapping from forward: + Forward: X@W (K-loop), X@A^T (K-loop), (X@A^T)@B^T (epilogue) + Backward: DY@W^T (N-loop), DY@B (N-loop), (DY@B)@A (epilogue) + """ + N_block = tl.arange(0, BLOCK_N) + R_block = tl.arange(0, BLOCK_R) + R_mask = R_block < ACTUAL_R + + # DY pointers: DY is [M_total, N], load [BLOCK_M, BLOCK_N] + DY_blk_ptrs = ( + DY_ptr + M_in_idx[:, None] * stride_dym + N_block[None, :] * stride_dyn + ) + + # W^T pointers: W[e] is [K, N], W^T[e] is [N, K]. We load W^T as [BLOCK_N, BLOCK_K]. + # W stored as [E, K, N], so W^T[e][n, k] = W[e][k, n] = W_ptr + e*stride_we + k*stride_wk + n*stride_wn + # As [BLOCK_N, BLOCK_K] tile: row=n, col=k + WT_blk_ptrs = ( + W_ptr + + E_idx * stride_we + + N_block[:, None] * stride_wn # row = n dimension + + K_block[None, :] * stride_wk + ) # col = k dimension + + # B pointers: B[e] is B[:, e*R:(e+1)*R], shape [N, R]. Load [BLOCK_N, BLOCK_R]. + B_expert_offset = E_idx * ACTUAL_R + B_blk_ptrs = ( + B_ptr + + N_block[:, None] * stride_bn + + (B_expert_offset + R_block)[None, :] * stride_br + ) + + iters = tl.cdiv(N, BLOCK_N) + + # Accumulator for DY @ B: [BLOCK_M, BLOCK_R] + dy_b_acc = tl.zeros((BLOCK_M, BLOCK_R), dtype=tl.float32) + + for i in range(iters): + if no_n_mask: + dy = tl.load(DY_blk_ptrs, mask=E_mask[:, None], other=0.0) + wt = tl.load(WT_blk_ptrs, mask=K_mask[None, :], other=0.0) + b = tl.load(B_blk_ptrs, mask=R_mask[None, :], other=0.0) + else: + N_mask_iter = (i * BLOCK_N + N_block) < N + dy = tl.load( + DY_blk_ptrs, mask=E_mask[:, None] & N_mask_iter[None, :], other=0.0 + ) + wt = tl.load( + WT_blk_ptrs, mask=N_mask_iter[:, None] & K_mask[None, :], other=0.0 + ) + b = tl.load( + B_blk_ptrs, mask=N_mask_iter[:, None] & R_mask[None, :], other=0.0 + ) + + # Base: acc += DY @ W^T ([M, N] @ [N, K] -> [M, K]) + acc = tl.dot(dy, wt, acc, allow_tf32=allow_tf32) + + # LoRA: dy_b_acc += DY @ B ([M, N] @ [N, R] -> [M, R]) + dy_b_acc = tl.dot(dy, b, dy_b_acc, allow_tf32=allow_tf32) + + DY_blk_ptrs += BLOCK_N 
def _scatter2scatter_lora_dX_configs():
    """Generate backward dX kernel autotune configs.

    The inner loop is over N (not K as in forward). The output dimension is K.
    So BLOCK_K tiles the output and BLOCK_N tiles the reduction.

    Search space includes smaller tile sizes and fewer pipeline stages to
    support GPUs with limited shared memory (e.g. ~99KB on some GPUs).

    Search space:
        BLOCK_K: {32, 64, 128, 256} (output tile)
        BLOCK_N: {32, 64, 128, 256} (reduction tile)
        num_warps: {4, 8}
        num_stages: {2, 3, 4, 5}
    """
    return [
        triton.Config(
            {"BLOCK_K": block_k, "BLOCK_N": block_n},
            num_stages=stages,
            num_warps=warps,
        )
        for block_k, block_n, warps, stages in product(
            [32, 64, 128, 256],  # BLOCK_K (output dimension)
            [32, 64, 128, 256],  # BLOCK_N (reduction dimension)
            [4, 8],  # num_warps
            [2, 3, 4, 5],  # num_stages
        )
    ]


def _prune_dX_configs(configs, named_args, **kwargs):
    """Prune backward dX configs based on SMEM capacity.

    The estimate for each config is the base GEMM figure from
    ``_estimate_smem_usage`` (with BLOCK_K as the output tile and BLOCK_N as
    the reduction tile) plus:
      - a B tile [BLOCK_N, BLOCK_R] per pipeline stage (inner loop), and
      - an A tile [BLOCK_R, BLOCK_K] loaded once in the epilogue,
    both at 2 bytes per element. Configs whose estimate exceeds the device
    capacity minus ``_SMEM_SLACK`` are dropped.

    Args:
        configs: Candidate ``triton.Config`` objects.
        named_args: Mapping of kernel arguments supplied by the autotuner.
        **kwargs: Further launch arguments supplied by the autotuner.

    Returns:
        The configs that fit; if none fit, a one-element list holding the
        config with the smallest estimate.
    """
    smem_cap = _get_smem_capacity()

    # BLOCK_R is passed to the kernel by keyword, so look in kwargs first
    # (presumably where the autotuner forwards keyword launch arguments) and
    # only then in named_args; 64 is the fallback when neither carries it.
    block_r = kwargs.get("BLOCK_R", named_args.get("BLOCK_R", 64))

    scored = []
    for config in configs:
        block_k = config.kwargs["BLOCK_K"]
        block_n = config.kwargs["BLOCK_N"]
        # Base: stages * BLOCK_N * (BLOCK_M + BLOCK_K) + BLOCK_M * BLOCK_K
        # (argument order deliberately swaps K and N: N is the reduction dim)
        smem_base = _estimate_smem_usage(config.num_stages, BLOCK_M, block_k, block_n)
        # B tile [BLOCK_N, BLOCK_R] loaded per stage in the inner loop
        smem_lora_loop = config.num_stages * block_n * block_r * 2
        # A tile [BLOCK_R, BLOCK_K] loaded once in epilogue
        smem_lora_epilogue = block_r * block_k * 2
        scored.append((smem_base + smem_lora_loop + smem_lora_epilogue, config))

    budget = smem_cap - _SMEM_SLACK
    pruned = [config for smem, config in scored if smem <= budget]
    if pruned:
        return pruned
    # All configs exceed SMEM — return the one with smallest estimated usage
    # (first one wins on ties, matching a stable sort).
    return [min(scored, key=lambda item: item[0])[1]]
output, grouped) + DY_ptr, + stride_dym: tl.constexpr, + stride_dyn: tl.constexpr, + # Base weight: W [E, K, N] (we compute DY @ W^T) + W_ptr, + stride_we, + stride_wk: tl.constexpr, + stride_wn: tl.constexpr, + # Output: dX + DX_ptr, + stride_dxm: tl.constexpr, + stride_dxk: tl.constexpr, + # LoRA weights + LA_ptr, + stride_la_r, + stride_la_k, # A: [r*E, K] + LB_ptr, + stride_lb_n, + stride_lb_r, # B: [N, r*E] + # Routing + grouped_idx_ptr, + expert_idxs_ptr, + # Dimensions + FAN_OUT: tl.constexpr, + M, + K: tl.constexpr, + N: tl.constexpr, + E: tl.constexpr, + ACTUAL_R: tl.constexpr, + # Block sizes + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_R: tl.constexpr, + # Config + ACC_TYPE: tl.constexpr, + scaling, + allow_tf32: tl.constexpr, + dy_grouped: tl.constexpr, + dx_grouped: tl.constexpr, + NO_K_MASK: tl.constexpr, + NO_N_MASK: tl.constexpr, +): + """ + Fused backward dX = DY @ W^T + scaling * (DY @ B) @ A + + DY is in expert-grouped order (x_grouped=True). + dX is output in ungrouped or grouped order based on dx_grouped. 
def scatter2scatter_lora_dX(
    DY: torch.Tensor,
    W: torch.Tensor,
    sorted_expert_idxs: torch.Tensor,
    sorted_scattered_idxs: torch.Tensor,
    k: int,
    lora_A: torch.Tensor,
    lora_B: torch.Tensor,
    scaling: float,
    dy_grouped: bool = True,
    dx_grouped: bool = False,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Launch the fused input-gradient kernel.

        dX = DY @ W^T + scaling * (DY @ B) @ A

    A single launch stands in for the two-step path:
        1. base_ops.scatter2scatter(DY, W^T, x_grouped=True, ...)
        2. _compute_lora_input_grad(DY, A, B, ...)

    Args:
        DY: Gradient w.r.t. output [M*k, N] (grouped by expert)
        W: Expert weights [E, K, N] (passed untransposed; the kernel is
            expected to apply W^T itself)
        sorted_expert_idxs: Expert id per sorted row [M*k]
        sorted_scattered_idxs: Original row index per sorted row [M*k]
        k: Fan-out (top-k); only forwarded to the kernel when
            ``dy_grouped`` is False
        lora_A: LoRA A weights [r*E, K]
        lora_B: LoRA B weights [N, r*E]
        scaling: LoRA scaling factor
        dy_grouped: True when DY is already in expert-sorted order
        dx_grouped: True to write dX in expert-sorted order
        out: Optional pre-allocated [M*k, K] buffer to write into

    Returns:
        dX: Input gradient [M*k, K] (``out`` itself when it was supplied)
    """
    assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)

    num_experts = W.size(0)
    in_features = W.size(1)
    out_features = W.size(2)
    rank = lora_A.size(0) // num_experts

    block_r = _block_r_for_rank(rank)

    num_rows = sorted_expert_idxs.size(0)

    # Both layouts hand DY.size(0) to the kernel as M; the only thing that
    # changes is FAN_OUT (1 when DY is already expanded/grouped, k otherwise).
    dy_rows = DY.size(0)
    fan_out = 1 if dy_grouped else k

    if out is not None:
        assert out.size(0) == num_rows and out.size(1) == in_features
        dX = out
    else:
        dX = torch.empty(
            (num_rows, in_features), device=DY.device, dtype=DY.dtype
        )

    def launch_grid(meta):
        # 1-D grid: one program per (row tile, K tile) pair.
        row_tiles = triton.cdiv(num_rows, meta["BLOCK_M"])
        col_tiles = triton.cdiv(in_features, meta["BLOCK_K"])
        return (row_tiles * col_tiles,)

    _scatter2scatter_lora_dX[launch_grid](
        DY,
        DY.stride(0),
        DY.stride(1),
        W,
        W.stride(0),
        W.stride(1),
        W.stride(2),
        dX,
        dX.stride(0),
        dX.stride(1),
        lora_A,
        lora_A.stride(0),
        lora_A.stride(1),
        lora_B,
        lora_B.stride(0),
        lora_B.stride(1),
        sorted_scattered_idxs,
        sorted_expert_idxs,
        FAN_OUT=fan_out,
        M=dy_rows,
        K=in_features,
        N=out_features,
        E=num_experts,
        ACTUAL_R=rank,
        BLOCK_M=BLOCK_M,
        BLOCK_R=block_r,
        ACC_TYPE=tl.float32,
        scaling=scaling,
        allow_tf32=ALLOW_TF32,
        dy_grouped=dy_grouped,
        dx_grouped=dx_grouped,
    )

    return dX
def _prune_bwd_lora_configs(configs, named_args, **kwargs):
    """Prune backward (dA/dB) autotune configs by estimated SMEM use.

    The backward kernel loads X[BLOCK_M, BLOCK_K] and DY[BLOCK_M, BLOCK_N]
    in the inner loop, plus holds A[BLOCK_R, BLOCK_K] and B[BLOCK_N, BLOCK_R]
    for the full expert. We estimate SMEM based on the dominant terms.

    Args:
        configs: Candidate ``triton.Config`` objects; each must carry
            ``BLOCK_M``, ``BLOCK_K`` and ``BLOCK_N`` in ``kwargs``.
        named_args: Launch arguments that were bound positionally, by name.
        **kwargs: Launch arguments that were passed by keyword.

    Returns:
        The configs whose estimate fits within capacity minus slack. If none
        fit, a one-element list holding the config with the smallest
        estimate. An empty ``configs`` yields an empty list.
    """
    smem_cap = _get_smem_capacity()

    # The autotuner hands keyword launch arguments over in ``kwargs`` rather
    # than ``named_args``. BLOCK_R is passed by keyword by the launch
    # wrappers, so check both places before falling back to 64; otherwise the
    # real rank tile is never seen and the estimate is always for 64.
    block_r = named_args.get("BLOCK_R", kwargs.get("BLOCK_R", 64))

    scored = []
    for config in configs:
        block_m = config.kwargs["BLOCK_M"]
        block_k = config.kwargs["BLOCK_K"]
        block_n = config.kwargs["BLOCK_N"]
        # Inner loop loads X[M,K] and DY[M,N], pipeline over M iterations
        smem_base = _estimate_smem_usage(config.num_stages, block_m, block_n, block_k)
        # A[BLOCK_R, BLOCK_K] and B[BLOCK_N, BLOCK_R] held for the full expert.
        # NOTE(review): the factor 2 is presumably bytes per element
        # (16-bit operands) — confirm against _estimate_smem_usage.
        smem_lora = (block_r * block_k + block_n * block_r) * 2
        scored.append((smem_base + smem_lora, config))

    if not scored:
        # Nothing to choose from; avoid indexing into an empty list below.
        return []

    budget = smem_cap - _SMEM_SLACK
    pruned = [config for smem, config in scored if smem <= budget]
    if pruned:
        return pruned

    # All configs exceed SMEM — return the one with smallest estimated usage
    # (first one wins on ties, matching a stable sort).
    return [min(scored, key=lambda item: item[0])[1]]
gradients for each expert on grouped data. + + Grid: (E * cdiv(K, BLOCK_K), cdiv(N, BLOCK_N)) + + For expert e: + dA[e] = scaling * (dY @ B[e])^T @ X -> [r, K], accumulate over M tokens + dB[e] = scaling * dY^T @ (X @ A[e]^T) -> [N, r], accumulate over M tokens + + ACTUAL_R is the true LoRA rank. BLOCK_R >= ACTUAL_R is padded for tl.dot min size. + """ + pid0 = tl.program_id(axis=0) + pid1 = tl.program_id(axis=1) + + K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K) + E_idx = pid0 // K_BLOCK_COUNT + K_block_id = pid0 % K_BLOCK_COUNT + N_block_id = pid1 + + # Get expert's token range from cumulative offsets + if E_idx == 0: + start_idx = 0 + else: + start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32) + end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32) + num_tokens = end_idx - start_idx + + if num_tokens > 0: + M_block = tl.arange(0, BLOCK_M) + K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K) + K_mask = K_block < K + N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_block < N + R_block = tl.arange(0, BLOCK_R) + R_mask = R_block < ACTUAL_R # Mask for padding + + lora_offset = E_idx * ACTUAL_R + + # Load B[e]: [BLOCK_N, BLOCK_R] (masked on R and N, other=0 for padding) + B_blk_ptrs = ( + LB_ptr + + N_block[:, None] * stride_lb_n + + (lora_offset + R_block)[None, :] * stride_lb_r + ) + b_e = tl.load(B_blk_ptrs, mask=N_mask[:, None] & R_mask[None, :], other=0.0) + + # Load A[e]: [BLOCK_R, BLOCK_K] (masked on R and K, other=0 for padding) + A_blk_ptrs = ( + LA_ptr + + (lora_offset + R_block)[:, None] * stride_la_r + + K_block[None, :] * stride_la_k + ) + a_e = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0) + + # Accumulators + dA_acc = tl.zeros((BLOCK_R, BLOCK_K), dtype=ACC_TYPE) + dB_acc = tl.zeros((BLOCK_N, BLOCK_R), dtype=ACC_TYPE) + + iters = tl.cdiv(num_tokens, BLOCK_M) + for i in range(iters): + M_idx = start_idx + i * BLOCK_M + M_block + M_mask = M_idx < end_idx + + # Load X: [BLOCK_M, BLOCK_K] + X_blk_ptrs = 
def group_bwd_lora(
    DY: torch.Tensor,
    X: torch.Tensor,
    lora_A: torch.Tensor,
    lora_B: torch.Tensor,
    expert_offsets: torch.Tensor,
    E: int,
    scaling: float,
    sorted_scattered_idxs: Optional[torch.Tensor] = None,
    k: int = 1,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Launch the dA/dB kernel on inputs that are already grouped by expert.

    Args:
        DY: Gradient w.r.t. output [M_total, N] (grouped by expert)
        X: Input [M_total, K] (grouped by expert)
        lora_A: LoRA A weights [r*E, K]
        lora_B: LoRA B weights [N, r*E]
        expert_offsets: Cumulative token counts per expert [E]
        E: Number of experts
        scaling: LoRA scaling factor
        sorted_scattered_idxs: Accepted but not read by this function.
        k: Accepted but not read by this function.

    Returns:
        dA: Gradient for A [r*E, K]
        dB: Gradient for B [N, r*E]
    """
    rank = lora_A.size(0) // E
    in_features = X.size(1)
    out_features = DY.size(1)

    # The kernel accumulates with atomic adds, so both buffers start at zero.
    grad_A = torch.zeros_like(lora_A)
    grad_B = torch.zeros_like(lora_B)

    block_r = _block_r_for_rank(rank)

    def launch_grid(meta):
        # Axis 0: one program per (expert, K tile); axis 1: one per N tile.
        k_tiles = triton.cdiv(in_features, meta["BLOCK_K"])
        n_tiles = triton.cdiv(out_features, meta["BLOCK_N"])
        return (E * k_tiles, n_tiles)

    _group_bwd_lora[launch_grid](
        DY,
        DY.stride(0),
        DY.stride(1),
        X,
        X.stride(0),
        X.stride(1),
        lora_A,
        lora_A.stride(0),
        lora_A.stride(1),
        lora_B,
        lora_B.stride(0),
        lora_B.stride(1),
        grad_A,
        grad_A.stride(0),
        grad_A.stride(1),
        grad_B,
        grad_B.stride(0),
        grad_B.stride(1),
        expert_offsets,
        M=DY.size(0),
        K=in_features,
        N=out_features,
        ACTUAL_R=rank,  # True LoRA rank
        BLOCK_R=block_r,  # Padded tile size
        scaling=scaling,
        ACC_TYPE=tl.float32,
        allow_tf32=ALLOW_TF32,
    )

    return grad_A, grad_B
+ FAN_OUT: tl.constexpr, + # LoRA weights (needed for cross-terms) + LA_ptr, + stride_la_r, + stride_la_k, # A: [r*E, K] + LB_ptr, + stride_lb_n, + stride_lb_r, # B: [N, r*E] + # Gradient outputs + DLA_ptr, + stride_dla_r, + stride_dla_k, + DLB_ptr, + stride_dlb_n, + stride_dlb_r, + # Expert offsets + expert_offsets_ptr, + # Real expert offsets (for M_mask when using token rounding, else same as expert_offsets_ptr) + real_expert_offsets_ptr, + # Dimensions + M, + K: tl.constexpr, + N: tl.constexpr, + ACTUAL_R: tl.constexpr, + BLOCK_R: tl.constexpr, + scaling, + # Block sizes + BLOCK_M: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_N: tl.constexpr, + ACC_TYPE: tl.constexpr, + allow_tf32: tl.constexpr, + NO_K_MASK: tl.constexpr, + NO_N_MASK: tl.constexpr, +): + """ + Fused gather + LoRA gradient computation. Same as _group_bwd_lora but + reads X and DY from ungrouped buffers using sorted_scattered_idxs for + indirect indexing, eliminating the need for separate group() calls. + + Grid: (E * cdiv(K, BLOCK_K), cdiv(N, BLOCK_N)) + + For expert e: + dA[e] = scaling * (dY @ B[e])^T @ X -> [r, K] + dB[e] = scaling * dY^T @ (X @ A[e]^T) -> [N, r] + + Supports token rounding: expert_offsets_ptr gives the iteration range + (padded to BLOCK_M multiples), real_expert_offsets_ptr gives the real + token count for M_mask (to exclude padding tokens). 
+ + Key difference from _group_bwd_lora: + Instead of X_ptr[M_idx, :] and DY_ptr[M_idx, :] on pre-grouped data, + we load scatter_idx = sorted_scattered_idxs[M_idx], then: + X_token_idx = scatter_idx // FAN_OUT (X is [M, K], not expanded) + DY uses scatter_idx directly (DY is [M*k, N] or expanded via gate) + """ + pid0 = tl.program_id(axis=0) + pid1 = tl.program_id(axis=1) + + K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K) + E_idx = pid0 // K_BLOCK_COUNT + K_block_id = pid0 % K_BLOCK_COUNT + N_block_id = pid1 + + # Get expert's token range from cumulative offsets + # start_idx/end_idx from expert_offsets_ptr: iteration range (possibly padded) + # real_end_idx from real_expert_offsets_ptr: for M_mask (real token count) + if E_idx == 0: + start_idx = 0 + real_start_idx = 0 + else: + start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32) + real_start_idx = tl.load(real_expert_offsets_ptr + E_idx - 1).to(tl.int32) + end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32) + real_end_idx = tl.load(real_expert_offsets_ptr + E_idx).to(tl.int32) + num_tokens = end_idx - start_idx + + if num_tokens > 0: + M_block = tl.arange(0, BLOCK_M) + K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K) + K_mask = K_block < K + N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_block < N + R_block = tl.arange(0, BLOCK_R) + R_mask = R_block < ACTUAL_R + + lora_offset = E_idx * ACTUAL_R + + # Load B[e] and A[e] — same as non-fused kernel + B_blk_ptrs = ( + LB_ptr + + N_block[:, None] * stride_lb_n + + (lora_offset + R_block)[None, :] * stride_lb_r + ) + b_e = tl.load(B_blk_ptrs, mask=N_mask[:, None] & R_mask[None, :], other=0.0) + + A_blk_ptrs = ( + LA_ptr + + (lora_offset + R_block)[:, None] * stride_la_r + + K_block[None, :] * stride_la_k + ) + a_e = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0) + + # Accumulators + dA_acc = tl.zeros((BLOCK_R, BLOCK_K), dtype=ACC_TYPE) + dB_acc = tl.zeros((BLOCK_N, BLOCK_R), dtype=ACC_TYPE) + + 
real_num_tokens = real_end_idx - real_start_idx + iters = tl.cdiv(num_tokens, BLOCK_M) + for i in range(iters): + M_idx = start_idx + i * BLOCK_M + M_block + # Use real token count for masking (excludes padding tokens) + M_local = i * BLOCK_M + M_block + M_mask = M_local < real_num_tokens + + # Fused gather: load scatter indices, then indirect-load X and DY + scatter_idx = tl.load( + sorted_scattered_idxs_ptr + M_idx, mask=M_mask, other=0 + ).to(tl.int32) + X_token_idx = scatter_idx // FAN_OUT # X is [M, K], not expanded by k + + # Load X via indirect index: [BLOCK_M, BLOCK_K] + X_blk_ptrs = ( + X_ptr + X_token_idx[:, None] * stride_xm + K_block[None, :] * stride_xk + ) + x = tl.load(X_blk_ptrs, mask=M_mask[:, None] & K_mask[None, :], other=0.0) + + # Load DY via scatter index: DY is [M*k, N] + DY_blk_ptrs = ( + DY_ptr + + scatter_idx[:, None] * stride_dym + + N_block[None, :] * stride_dyn + ) + dy = tl.load(DY_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :], other=0.0) + + # X @ A[e]^T: [M, K] @ [K, R] -> [M, R] + xa = tl.dot(x, tl.trans(a_e), allow_tf32=allow_tf32) + + # dY @ B[e]: [M, N] @ [N, R] -> [M, R] + dy_b = tl.dot(dy, b_e, allow_tf32=allow_tf32) + + dy_b_cast = dy_b.to(x.dtype) + xa_cast = xa.to(dy.dtype) + + # dA += (dY @ B)^T @ X: [R, M] @ [M, K] -> [R, K] + dA_acc += tl.dot(tl.trans(dy_b_cast), x, allow_tf32=allow_tf32) + + # dB += dY^T @ (X @ A^T): [N, M] @ [M, R] -> [N, R] + dB_acc += tl.dot(tl.trans(dy), xa_cast, allow_tf32=allow_tf32) + + # Store dA with scaling (atomic add since multiple N_blocks contribute) + DLA_blk_ptrs = ( + DLA_ptr + + (lora_offset + R_block)[:, None] * stride_dla_r + + K_block[None, :] * stride_dla_k + ) + tl.atomic_add( + DLA_blk_ptrs, + (dA_acc * scaling).to(DLA_ptr.dtype.element_ty), + mask=R_mask[:, None] & K_mask[None, :], + ) + + # Store dB with scaling (atomic add since multiple K_blocks contribute) + DLB_blk_ptrs = ( + DLB_ptr + + N_block[:, None] * stride_dlb_n + + (lora_offset + R_block)[None, :] * stride_dlb_r + 
def group_bwd_lora_fused(
    DY: torch.Tensor,
    X: torch.Tensor,
    lora_A: torch.Tensor,
    lora_B: torch.Tensor,
    expert_offsets: torch.Tensor,
    sorted_scattered_idxs: torch.Tensor,
    E: int,
    k: int,
    scaling: float,
    real_expert_offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Launch the gather-on-load dA/dB kernel on ungrouped inputs.

    Intended to give the same result as
    group(X) + group(DY) + group_bwd_lora(DY, X, ...) without materialising
    the intermediate grouped buffers.

    Args:
        DY: Gradient w.r.t. output [M*k, N] (ungrouped, original token order)
        X: Input [M, K] (ungrouped, original token order)
        lora_A: LoRA A weights [r*E, K]
        lora_B: LoRA B weights [N, r*E]
        expert_offsets: Cumulative token counts per expert [E]
            (or padded offsets if using token rounding)
        sorted_scattered_idxs: Maps grouped position -> original position
            [M*k] (or padded version if using token rounding)
        E: Number of experts
        k: Fan-out (top-k), forwarded to the kernel as FAN_OUT
        scaling: LoRA scaling factor
        real_expert_offsets: Unpadded cumulative counts used for row masking
            under token rounding. When None, ``expert_offsets`` is passed
            for both kernel arguments.

    Returns:
        dA: Gradient for A [r*E, K]
        dB: Gradient for B [N, r*E]
    """
    rank = lora_A.size(0) // E
    in_features = X.size(1)
    out_features = DY.size(1)

    # The kernel accumulates with atomic adds, so both buffers start at zero.
    grad_A = torch.zeros_like(lora_A)
    grad_B = torch.zeros_like(lora_B)

    block_r = _block_r_for_rank(rank)

    mask_offsets = (
        expert_offsets if real_expert_offsets is None else real_expert_offsets
    )

    def launch_grid(meta):
        # Axis 0: one program per (expert, K tile); axis 1: one per N tile.
        k_tiles = triton.cdiv(in_features, meta["BLOCK_K"])
        n_tiles = triton.cdiv(out_features, meta["BLOCK_N"])
        return (E * k_tiles, n_tiles)

    _group_bwd_lora_fused[launch_grid](
        DY,
        DY.stride(0),
        DY.stride(1),
        X,
        X.stride(0),
        X.stride(1),
        sorted_scattered_idxs,
        FAN_OUT=k,
        LA_ptr=lora_A,
        stride_la_r=lora_A.stride(0),
        stride_la_k=lora_A.stride(1),
        LB_ptr=lora_B,
        stride_lb_n=lora_B.stride(0),
        stride_lb_r=lora_B.stride(1),
        DLA_ptr=grad_A,
        stride_dla_r=grad_A.stride(0),
        stride_dla_k=grad_A.stride(1),
        DLB_ptr=grad_B,
        stride_dlb_n=grad_B.stride(0),
        stride_dlb_r=grad_B.stride(1),
        expert_offsets_ptr=expert_offsets,
        real_expert_offsets_ptr=mask_offsets,
        M=sorted_scattered_idxs.size(0),
        K=in_features,
        N=out_features,
        ACTUAL_R=rank,
        BLOCK_R=block_r,
        scaling=scaling,
        ACC_TYPE=tl.float32,
        allow_tf32=ALLOW_TF32,
    )

    return grad_A, grad_B
@triton.autotune(
    configs=_scatter2scatter_configs(),
    key=["M", "N", "K"],
)
@triton.heuristics(
    {
        "NO_K_MASK": lambda args: (args["K"] % args["BLOCK_K"]) == 0,
        "NO_N_MASK": lambda args: (args["N"] % args["BLOCK_N"]) == 0,
    }
)
@triton.jit
def _scatter2scatter(
    X_ptr,
    stride_xm: tl.constexpr,
    stride_xk: tl.constexpr,
    W_ptr,
    stride_we,
    stride_wk: tl.constexpr,
    stride_wn: tl.constexpr,
    Y_ptr,
    stride_ym: tl.constexpr,
    stride_yn: tl.constexpr,
    B_ptr,
    stride_be: tl.constexpr,
    stride_bn: tl.constexpr,
    grouped_idx_ptr,
    expert_idxs_ptr,
    # block_start_idx_ptr,
    FAN_OUT: tl.constexpr,
    M,
    K: tl.constexpr,
    N: tl.constexpr,
    E: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    ACC_TYPE: tl.constexpr,
    # OUT_M,
    allow_tf32: tl.constexpr,
    x_grouped: tl.constexpr,
    y_grouped: tl.constexpr,
    NO_K_MASK: tl.constexpr,
    NO_N_MASK: tl.constexpr,
):
    """
    Per-expert matmul over one [BLOCK_M, BLOCK_N] output tile.

    Launched on a 1-D grid; the program id is split into a row-tile index
    and an N-tile index. For every expert id present in the tile's rows the
    accumulator is updated by ``_compute_expert_block`` with a row mask
    selecting that expert, then an optional per-expert term read through
    ``B_ptr`` is added and the tile is stored to ``Y_ptr``.

    Row addressing:
        x_grouped: input rows are read at the tile position; otherwise at
            ``grouped_idx // FAN_OUT``.
        y_grouped: output rows are written at the tile position; otherwise
            at ``grouped_idx``.

    NO_K_MASK and NO_N_MASK are supplied by the heuristics decorator but are
    not read in this body; the K flag is recomputed locally as ``no_k_mask``.
    """
    pid = tl.program_id(axis=0)

    # Split the flat program id into (row tile, N tile).
    N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N)
    M_block_id = pid // N_BLOCK_COUNT
    N_block_id = pid % N_BLOCK_COUNT

    M_block = M_block_id * BLOCK_M + tl.arange(0, BLOCK_M)
    N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
    N_mask = N_block < N
    M_boundary_mask = M_block < (FAN_OUT * M)
    # Rows past the end read as E, a value no loop iteration below matches.
    E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_boundary_mask, other=E)

    no_k_mask = K % BLOCK_K == 0

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    E_first_idx = tl.min(E_idxs)
    # Clamp so the out-of-range fill value E is never iterated as an expert.
    E_last_idx = tl.minimum(tl.max(E_idxs), E - 1)
    M_idx = tl.load(grouped_idx_ptr + M_block, mask=M_boundary_mask).to(tl.int32)
    for E_idx in range(E_first_idx, E_last_idx + 1):
        # Only rows routed to this expert contribute in this iteration.
        E_mask = E_idxs == E_idx
        E_M_idx = M_idx
        if x_grouped:
            M_in_idx = M_block
        else:
            M_in_idx = E_M_idx // FAN_OUT
        acc = _compute_expert_block(
            E_idx,
            E_mask,
            M_in_idx,
            N_block,
            N_mask,
            X_ptr,
            stride_xm,
            stride_xk,
            W_ptr,
            stride_we,
            stride_wk,
            stride_wn,
            K,
            acc,
            no_k_mask,
            BLOCK_K,
            allow_tf32=allow_tf32,
        )

    if B_ptr is not None:
        # Gather one row of B per output row, indexed by that row's expert id.
        B_blk_ptrs = B_ptr + E_idxs[:, None] * stride_be + N_block[None, :] * stride_bn
        acc += tl.load(B_blk_ptrs, mask=M_boundary_mask[:, None] & N_mask[None, :])

    if y_grouped:
        M_out_idx = M_block
    else:
        M_out_idx = M_idx
    Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn)
    tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :])
# Db is listed as mutated as well as DW: the kernel stores the bias gradient
# into it whenever a bias buffer is supplied, and an undeclared mutation is
# invisible to torch.compile's functionalization.
@torch.library.custom_op("scattermoe::groupXtY", mutates_args={"DW", "Db"})
def groupXtY_compileable(
    E: int,
    DW: torch.Tensor,
    Db: Optional[torch.Tensor],
    DY: torch.Tensor,
    X: torch.Tensor,
    expert_offsets: torch.Tensor,
) -> None:
    """
    Launch the grouped X^T @ DY kernel, writing results in place.

    Args:
        E: Number of experts; scales the first grid axis.
        DW: Output buffer for the weight gradient; indexed with three
            strides, so it must be 3-D. Written in place.
        Db: Optional output buffer for the bias gradient; must be 2-D when
            given. Written in place when not None.
        DY: Output gradient; ``DY.size(0)`` is passed as M and
            ``DY.size(-1)`` as N.
        X: Input activations; ``X.size(-1)`` is passed as K.
        expert_offsets: Per-expert cumulative row offsets read by the kernel.

    Returns:
        None. Results are written into ``DW`` (and ``Db``).
    """

    def grid(META):
        # Axis 0: one program per (expert, K tile); axis 1: one per N tile.
        return (
            E * triton.cdiv(META["K"], META["BLOCK_K"]),
            triton.cdiv(META["N"], META["BLOCK_N"]),
        )

    if Db is None:
        # No bias buffer: the strides are placeholders the kernel never uses.
        stride_dbe = 0
        stride_dbn = 0
    else:
        stride_dbe, stride_dbn = Db.stride()

    _groupXtY[grid](
        # DY_ptr, stride_dym, stride_dyk,
        DY,
        DY.stride(0),
        DY.stride(1),
        # X_ptr, stride_xm, stride_xn,
        X,
        X.stride(0),
        X.stride(1),
        # DW_ptr, stride_dwe, stride_dwk, stride_dwn,
        DW,
        DW.stride(0),
        DW.stride(1),
        DW.stride(2),
        # Db_ptr, stride_dbe, stride_dbn,
        Db,
        stride_dbe,
        stride_dbn,
        # expert_offsets_ptr,
        expert_offsets,
        # K: tl.constexpr, N: tl.constexpr,
        M=DY.size(0),
        N=DY.size(-1),
        K=X.size(-1),
        # ACC_TYPE: tl.constexpr,
        ACC_TYPE=tl.float32,
        allow_tf32=ALLOW_TF32,
    )
tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32) + end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32) + + if end_idx > start_idx: + M_block = tl.max_contiguous(start_idx + tl.arange(0, BLOCK_M), BLOCK_M) + + K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K) + K_mask = K_block < K + K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K) + + N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_block < N + N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N) + + M_idxs = M_block + xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm + dy_blk_ptrs = ( + DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk + ) + if (Db_ptr is not None) and (K_block_id == 0): + _xty_and_bias( + E_idx, + start_idx, + end_idx, + M_block, + K_block, + K_mask, + N_block, + N_mask, + dy_blk_ptrs, + stride_dym, + xt_blk_ptrs, + stride_xm, + DW_ptr, + stride_dwe, + stride_dwk, + stride_dwn, + Db_ptr, + stride_dbe, + stride_dbn, + BLOCK_M, + BLOCK_N, + BLOCK_K, + ACC_TYPE, + allow_tf32, + NO_K_MASK, + NO_N_MASK, + compute_bias=True, + ) + else: + _xty_and_bias( + E_idx, + start_idx, + end_idx, + M_block, + K_block, + K_mask, + N_block, + N_mask, + dy_blk_ptrs, + stride_dym, + xt_blk_ptrs, + stride_xm, + DW_ptr, + stride_dwe, + stride_dwk, + stride_dwn, + Db_ptr, + stride_dbe, + stride_dbn, + BLOCK_M, + BLOCK_N, + BLOCK_K, + ACC_TYPE, + allow_tf32, + NO_K_MASK, + NO_N_MASK, + compute_bias=False, + ) + + +@triton.jit +def _xty_and_bias( + E_idx, + start_idx, + end_idx, + M_block, + K_block, + K_mask, + N_block, + N_mask, + dy_blk_ptrs, + stride_dym, + xt_blk_ptrs, + stride_xm, + DW_ptr, + stride_dwe, + stride_dwk, + stride_dwn, + Db_ptr, + stride_dbe, + stride_dbn, + BLOCK_M, + BLOCK_N, + BLOCK_K, + ACC_TYPE, + allow_tf32, + NO_K_MASK, + NO_N_MASK, + compute_bias: tl.constexpr, +): + if compute_bias: + db_acc = tl.zeros((BLOCK_N,), dtype=ACC_TYPE) + else: + db_acc = None + + acc = 
@torch.library.custom_op("scattermoe::group", mutates_args={"Y"})
def group_compileable(
    A: torch.Tensor,
    K: int,
    N: int,
    Y: torch.Tensor,
    coeff: Optional[torch.Tensor],
    has_coeff: bool,
    fan_out: int,
    sorted_expert_idxs: torch.Tensor,
) -> None:
    """
    Launch the row-gather kernel that copies rows of ``A`` into ``Y``.

    Args:
        A: Source matrix; must be 2-D (two strides are passed).
        K: Number of columns to copy.
        N: Number of output rows.
        Y: Destination buffer, written in place; must be 2-D.
        coeff: Optional per-row scale. Annotated Optional because the
            ``group`` wrapper passes None when no coefficient is given, and
            the op schema is inferred from this annotation.
        has_coeff: Whether the kernel should multiply by ``coeff``.
        fan_out: Divisor applied to the gathered index to get the source row.
        sorted_expert_idxs: Index tensor read per output row.

    Returns:
        None. Results are written into ``Y``.
    """

    def grid(META):
        # 1-D grid: one program per BLOCK_N rows of output.
        return (triton.cdiv(META["N"], META["BLOCK_N"]),)

    _group[grid](
        # A_ptr, stride_an, stride_ai,
        A,
        A.stride(0),
        A.stride(1),
        has_coeff,
        coeff,
        fan_out,
        # Y_ptr, stride_yn, stride_yk,
        Y,
        Y.stride(0),
        Y.stride(1),
        # grouped_idx_ptr,
        sorted_expert_idxs,
        # N: tl.constexpr, K: tl.constexpr,
        N,
        K,
    )
# Single-token expert matmul: each program computes one BLOCK_N-wide slice of
# one output row, Y[pid1, n_block] = X[in_idx, :] @ W[expert_idxs[pid1], :, n_block].
# NOTE(review): every load/store here is unmasked, so the launcher appears to
# be responsible for K % BLOCK_K == 0 and N % BLOCK_N == 0 - confirm.
@triton.jit
def _single2scatter(
    X_ptr,
    stride_xm,
    stride_xk,
    W_ptr,
    stride_we,
    stride_wk,
    stride_wn,
    Y_ptr,
    stride_ym,
    stride_yn,
    expert_idxs_ptr,
    FAN_OUT: tl.constexpr,
    K: tl.constexpr,
    N: tl.constexpr,
    E: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    ACC_TYPE: tl.constexpr,
):
    pid0 = tl.program_id(axis=0)
    pid1 = tl.program_id(axis=1)

    N_block_id = pid0
    # With FAN_OUT == 1 every output row has its own input row; otherwise
    # all output rows read input row 0.
    if FAN_OUT == 1:
        in_idx = pid1
    else:
        in_idx = 0
    out_idx = pid1

    K_block = tl.arange(0, BLOCK_K)
    N_block = tl.max_contiguous(
        tl.multiple_of((N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)) % N, BLOCK_N),
        BLOCK_N,
    )
    # Expert assigned to this output row.
    E_idx = tl.load(expert_idxs_ptr + pid1)
    X_blk_ptrs = X_ptr + in_idx * stride_xm + K_block[:, None] * stride_xk
    W_blk_ptrs = (
        W_ptr
        + E_idx * stride_we
        + K_block[:, None] * stride_wk
        + N_block[None, :] * stride_wn
    )
    acc = tl.zeros((1, BLOCK_N), dtype=ACC_TYPE)
    for _K_block_id in range(0, tl.cdiv(K, BLOCK_K)):
        x = tl.load(X_blk_ptrs)
        w = tl.load(W_blk_ptrs)
        # Vector-matrix product as a broadcast multiply reduced over K.
        acc += tl.sum(x * w, axis=0)[None, :]
        X_blk_ptrs += BLOCK_K * stride_xk
        W_blk_ptrs += BLOCK_K * stride_wk
    Y_blk_ptrs = Y_ptr + out_idx * stride_ym + N_block[None, :] * stride_yn
    tl.store(Y_blk_ptrs, acc)
def single2scatter(X, W, expert_idxs):
    """Apply the selected experts' weights to a single token (or one row per expert).

    Args:
        X: Input of shape ``[1, xdim]`` (shared by all selected experts) or
            ``[k, xdim]`` (one row per selected expert).
        W: Expert weights of shape ``[E, xdim, ydim]``.
        expert_idxs: Index tensor whose second dimension ``k`` is the number
            of selected experts.

    Returns:
        Tensor of shape ``[k, ydim]`` with the dtype and device of ``X``.

    Raises:
        ValueError: If ``xdim`` or ``ydim`` is not a multiple of the 128-wide
            tile. The ``_single2scatter`` kernel loads and stores without
            masks and the grid is ``ydim // BLOCK_N``, so other shapes would
            read out of range or leave output columns uninitialised.
    """
    E, xdim, ydim = W.size()
    k = expert_idxs.size(1)
    assert X.size(0) == k or X.size(0) == 1

    BLOCK_N = 128
    BLOCK_K = 128
    if xdim % BLOCK_K != 0 or ydim % BLOCK_N != 0:
        raise ValueError(
            "single2scatter requires W's input and output dims to be multiples "
            f"of {BLOCK_K} and {BLOCK_N}; got xdim={xdim}, ydim={ydim}"
        )

    Y = torch.empty((k, ydim), device=X.device, dtype=X.dtype)
    # One program per (BLOCK_N output columns, selected expert).
    grid = ydim // BLOCK_N, k
    _single2scatter[grid](
        X,
        X.stride(0),
        X.stride(1),
        W,
        W.stride(0),
        W.stride(1),
        W.stride(2),
        Y,
        Y.stride(0),
        Y.stride(1),
        expert_idxs,
        FAN_OUT=Y.size(0) // X.size(0),
        K=xdim,
        N=ydim,
        E=E,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
        ACC_TYPE=tl.float32,
    )
    return Y
def peft_lora_B_to_scattermoe(peft_B, num_experts, rank):
    """Reorder a peft ``lora_B`` from rank-major to expert-major columns.

    peft lays the ``E * r`` columns out as ``[out, r, E]`` (rank-major);
    scattermoe slices expert ``e`` as columns ``[e*r : (e+1)*r]``
    (expert-major). The result has the same shape, ``[out, E * r]``, with
    the two inner axes swapped.
    """
    out_features = peft_B.shape[0]
    rank_major = peft_B.reshape(out_features, rank, num_experts)
    expert_major = rank_major.transpose(1, 2).contiguous()
    return expert_major.reshape(out_features, num_experts * rank)
def peft_down_proj_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
    """Convert peft LoRA weights for **down_proj** to scattermoe layout.

    down_proj param: ``[E, hidden, inter]``
    scattermoe W = param.T = ``[E, inter, hidden]`` -> K=inter, N=hidden

    peft lora_A is ``[r*E, hidden]`` (expert-major rows) and lora_B is
    ``[inter, r*E]`` (rank-major columns, i.e. ``[inter, r, E]``).
    scattermoe needs lora_A ``[r*E, K=inter]`` and lora_B ``[N=hidden, r*E]``,
    both expert-major. Because scattermoe works on the transposed weight, the
    roles of A and B swap: per expert ``smoe_A_e = peft_B_e.T`` and
    ``smoe_B_e = peft_A_e.T``.

    Both results are plain transposes, so they are built with two tensor ops
    instead of a per-expert Python loop.

    Returns:
        ``(smoe_A, smoe_B)``, contiguous, on ``peft_A``'s device and dtype.
    """
    inter = peft_B.shape[0]
    # [inter, r, E] -> [E, r, inter] -> [E*r, inter]: row e*r + j holds
    # column (rank j, expert e) of peft_B, i.e. the transposed expert-major B.
    smoe_A = (
        peft_B.reshape(inter, rank, num_experts)
        .permute(2, 1, 0)
        .reshape(num_experts * rank, inter)
        .to(device=peft_A.device, dtype=peft_A.dtype)
        .contiguous()
    )
    # peft_A rows are already expert-major, so a transpose is all that is needed.
    smoe_B = peft_A.t().contiguous()
    return smoe_A, smoe_B
+ """ + if hasattr(gate_module, "base_layer") and hasattr(gate_module, "lora_A"): + base_gate = gate_module.base_layer + lora_A, lora_B, scaling = get_lora_params_from_wrapper(gate_module) + if lora_A is not None: + # gate weight: [num_experts, hidden_size] + # lora_A: [r, hidden_size], lora_B: [num_experts, r] + # delta = scaling * B @ A = [num_experts, hidden_size] + delta = scaling * (lora_B @ lora_A) + gate_weight = base_gate.weight + delta + else: + gate_weight = base_gate.weight + return base_gate, gate_weight + else: + # No wrapping — gate is the original module + return gate_module, gate_module.weight + + +def _unwrap_experts_lora(experts_module): + """Walk a peft ``ParamWrapper`` chain on ``self.experts``. + + When peft targets ``experts.gate_up_proj`` and ``experts.down_proj`` via + ``target_parameters``, ``self.experts`` becomes a nested chain:: + + ParamWrapper(down_proj) + -> base_layer: ParamWrapper(gate_up_proj) + -> base_layer: OlmoeExperts (the real module) + + This function walks the chain, collects LoRA params keyed by + ``parameter_name``, and returns the base experts module. + + Returns: + (base_experts, gup_lora, down_lora) + + Each ``*_lora`` is either ``(smoe_A, smoe_B, scaling)`` or ``None``. + A/B are already in scattermoe layout. 
+ """ + # Collect ParamWrapper layers by their parameter_name + wrappers = {} + module = experts_module + while hasattr(module, "base_layer") and hasattr(module, "lora_A"): + param_name = getattr(module, "parameter_name", None) + if param_name is not None: + wrappers[param_name] = module + module = module.base_layer + + base_experts = module + + if not wrappers: + return base_experts, None, None + + # Determine num_experts from base module + num_experts = getattr(base_experts, "num_experts", None) + if num_experts is None: + # Fallback: infer from parameter shape + gup = getattr(base_experts, "gate_up_proj", None) + if gup is not None: + num_experts = gup.shape[0] + + # Extract gate_up_proj LoRA + gup_lora = None + gup_wrapper = wrappers.get("gate_up_proj") + if gup_wrapper is not None: + lora_A, lora_B, scaling = get_lora_params_from_wrapper(gup_wrapper) + if lora_A is not None: + rank = lora_A.shape[0] // num_experts + smoe_A = lora_A # already expert-major [r*E, K] + smoe_B = peft_lora_B_to_scattermoe(lora_B, num_experts, rank) + gup_lora = (smoe_A, smoe_B, scaling) + + # Extract down_proj LoRA (needs A<->B swap due to transposition) + down_lora = None + down_wrapper = wrappers.get("down_proj") + if down_wrapper is not None: + lora_A, lora_B, scaling = get_lora_params_from_wrapper(down_wrapper) + if lora_A is not None: + rank = lora_A.shape[0] // num_experts + smoe_A, smoe_B = peft_down_proj_lora_to_scattermoe( + lora_A, lora_B, num_experts, rank + ) + down_lora = (smoe_A, smoe_B, scaling) + + return base_experts, gup_lora, down_lora + + +# ============================================================================= +# Layer classes +# ============================================================================= + + +class ScatterMoEGatedMLP(nn.Module): + def forward(self, layer_input): + """ + Forward pass of the mixture of experts layer. + + Args: + layer_input (Tensor): + Input tensor. + + Returns: + Tensor: + Output tensor. + Tensor: + Router logits. 
class ScatterMoEGatedMLP(nn.Module):
    def forward(self, layer_input):
        """
        Run the gated mixture-of-experts MLP with ScatterMoE kernels.

        Args:
            layer_input (Tensor):
                Input of shape ``[batch, length, emb_size]``.

        Returns:
            Tensor:
                Output with the same shape as ``layer_input``.
        """
        bsz, length, emb_size = layer_input.size()
        tokens = layer_input.reshape(-1, emb_size)
        router = self.router

        # Top-k routing: softmax in float32, keep the k best, renormalise.
        logits = router.layer(tokens)
        probs = F.softmax(logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(probs, router.top_k, dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        routing_weights = routing_weights.to(tokens.dtype)

        sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count(
            selected_experts, num_experts=router.num_experts
        )

        # Fused gate/up projection; output stays grouped by expert.
        projected = parallel_linear(
            tokens,
            self.input_linear.weight.transpose(2, 1),
            router.top_k,
            sorted_expert_idxs,
            sorted_scattered_idxs,
            expert_offsets,
            grouped_in=False,
            grouped_out=True,
        )
        gates, up = torch.chunk(projected, 2, dim=-1)
        hidden = self.activation(gates) * up

        # Down projection; scatters back to token order weighted by routing_weights.
        combined = parallel_linear(
            hidden,
            self.output_linear.weight.transpose(2, 1),
            1,
            sorted_expert_idxs,
            sorted_scattered_idxs,
            expert_offsets,
            grouped_in=True,
            grouped_out=False,
            gates=routing_weights,
        )
        return combined.view(bsz, length, emb_size)
class HFScatterMoEGatedMLP(nn.Module):
    """
    ScatterMoE-accelerated forward pass for HF MoEs (OLMoE / Qwen2MoE).

    Used as a kernel layer via the HF ``kernels`` library. The ``forward``
    method replaces the original ``OlmoeSparseMoeBlock.forward``.

    Supports both full-parameter training and LoRA fine-tuning:

    * **Full-param**: uses ``parallel_linear`` (base ScatterMoE kernel)
    * **LoRA**: detects peft ``ParamWrapper`` on ``self.experts``, extracts
      adapter weights, and uses ``parallel_linear_lora`` (fused kernel)
    """

    @staticmethod
    def forward(self: nn.Module, layer_input: torch.Tensor):
        """
        Forward pass using ScatterMoE kernels.

        Args:
            self: The MoeSparseMoeBlock module containing:
                - self.gate: Router (or peft ParamWrapper wrapping it)
                - self.experts: Experts module (or peft ParamWrapper chain)
                - self.shared_expert: Optional shared expert (e.g. Qwen2MoE)
                - self.shared_expert_gate: Optional shared expert gate
            layer_input: Input tensor [batch_size, seq_len, hidden_size];
                it does not need to be contiguous.

        Returns:
            Tensor: [batch_size, seq_len, hidden_size]
        """
        batch_size, sequence_length, hidden_dim = layer_input.shape
        # reshape (not view) so non-contiguous activations are accepted,
        # matching ScatterMoEGatedMLP.forward.
        hidden_states_flat = layer_input.reshape(-1, hidden_dim)

        # ====================================================================
        # Shared Expert (if present, e.g. Qwen2MoE)
        # ====================================================================
        # peft wraps individual linear layers inside shared_expert with
        # standard LoRA, so calling the modules applies it transparently.
        shared_expert = getattr(self, "shared_expert", None)
        if shared_expert is not None:
            shared_expert_output = shared_expert(hidden_states_flat)
            shared_expert_gate_output = torch.sigmoid(
                self.shared_expert_gate(hidden_states_flat)
            )
            shared_expert_output = shared_expert_output * shared_expert_gate_output
        else:
            shared_expert_output = None

        # ====================================================================
        # Router Computation (with optional gate LoRA)
        # ====================================================================
        base_gate, gate_weight = _unwrap_gate_lora(self.gate)
        router_logits = F.linear(hidden_states_flat, gate_weight)
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)

        top_k = base_gate.top_k
        num_experts = base_gate.num_experts
        routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

        if base_gate.norm_topk_prob:
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        routing_weights = routing_weights.to(hidden_states_flat.dtype)

        sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count(
            selected_experts, num_experts=num_experts
        )

        # ====================================================================
        # Detect LoRA (peft ParamWrapper) and extract adapter weights
        # ====================================================================
        experts, gup_lora, down_lora = _unwrap_experts_lora(self.experts)

        # ====================================================================
        # Gate + Up projection
        # ====================================================================
        gate_up_W = experts.gate_up_proj.transpose(2, 1)  # [E, hidden, 2*inter]

        if gup_lora is not None:
            gup_A, gup_B, gup_scaling = gup_lora
            gup = parallel_linear_lora(
                hidden_states_flat,
                gate_up_W,
                top_k,
                sorted_expert_idxs,
                sorted_scattered_idxs,
                expert_offsets,
                lora_A=gup_A,
                lora_B=gup_B,
                scaling=gup_scaling,
                grouped_in=False,
                grouped_out=True,
            )
        else:
            gup = parallel_linear(
                hidden_states_flat,
                gate_up_W,
                top_k,
                sorted_expert_idxs,
                sorted_scattered_idxs,
                expert_offsets,
                grouped_in=False,
                grouped_out=True,
            )

        gates, h = gup.chunk(2, dim=-1)
        h = experts.act_fn(gates) * h

        # ====================================================================
        # Down projection
        # ====================================================================
        down_W = experts.down_proj.transpose(2, 1)  # [E, inter, hidden]

        if down_lora is not None:
            down_A, down_B, down_scaling = down_lora
            expert_output = parallel_linear_lora(
                h,
                down_W,
                1,
                sorted_expert_idxs,
                sorted_scattered_idxs,
                expert_offsets,
                lora_A=down_A,
                lora_B=down_B,
                scaling=down_scaling,
                grouped_in=True,
                grouped_out=False,
                gates=routing_weights,
            )
        else:
            expert_output = parallel_linear(
                h,
                down_W,
                1,
                sorted_expert_idxs,
                sorted_scattered_idxs,
                expert_offsets,
                grouped_in=True,
                grouped_out=False,
                gates=routing_weights,
            )

        # ====================================================================
        # Combine with shared expert and reshape
        # ====================================================================
        if shared_expert_output is not None:
            expert_output = expert_output + shared_expert_output

        return expert_output.reshape(batch_size, sequence_length, hidden_dim)
class ParallelExperts(nn.Module):
    """
    Parallel Experts with fused LoRA support.

    Drop-in replacement for the original ParallelExperts. When LoRA parameters
    are attached via set_lora(), the forward pass uses a fused kernel:
        Y = X @ W + scaling * (X @ A^T) @ B^T
    """

    def __init__(
        self,
        num_experts: int,
        input_size: int,
        output_size: int,
        bias: bool = False,
    ) -> None:
        super().__init__()
        # Stored as [E, output, input]; forward() passes the [E, input, output] view.
        self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
        if bias:
            self.bias = nn.Parameter(torch.empty(num_experts, output_size))
        else:
            self.bias = None
        self.num_experts = num_experts
        self.input_size = input_size
        self.output_size = output_size
        # Adapter state is plain attributes (not Parameters); set via set_lora().
        self._lora_A: Optional[torch.Tensor] = None
        self._lora_B: Optional[torch.Tensor] = None
        self._lora_scaling: Optional[float] = None
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialise weights from N(0, 0.02^2) and zero the bias."""
        nn.init.normal_(self.weight, std=0.02)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def extra_repr(self) -> str:
        return (
            f"num_experts={self.num_experts}, "
            f"input_size={self.input_size}, "
            f"output_size={self.output_size}"
        )

    def set_lora(self, lora_A: torch.Tensor, lora_B: torch.Tensor, scaling: float):
        """Attach LoRA parameters for fused computation."""
        self._lora_A = lora_A
        self._lora_B = lora_B
        self._lora_scaling = scaling

    def clear_lora(self):
        """Remove LoRA parameters."""
        self._lora_A = None
        self._lora_B = None
        self._lora_scaling = None

    def _effective_scaling(self) -> float:
        """Return the attached scaling, or 1.0 when none is attached.

        Tests ``is None`` rather than truthiness so an explicit scaling of
        0.0 is honoured instead of being replaced by 1.0.
        """
        return 1.0 if self._lora_scaling is None else self._lora_scaling

    def forward(
        self,
        inputs: torch.Tensor,
        k: int,
        sorted_expert_idxs: torch.Tensor,
        sorted_scattered_idxs: torch.Tensor,
        expert_offsets: torch.Tensor,
        gates: Optional[torch.Tensor] = None,
        grouped_in: bool = False,
        grouped_out: bool = False,
    ) -> torch.Tensor:
        """Run the fused expert linear, passing along any attached LoRA weights."""
        return parallel_linear_lora(
            inputs,
            self.weight.permute(0, 2, 1),  # [E, input, output]
            k,
            sorted_expert_idxs,
            sorted_scattered_idxs,
            expert_offsets,
            lora_A=self._lora_A,
            lora_B=self._lora_B,
            scaling=self._effective_scaling(),
            expert_biases=self.bias,
            gates=gates,
            grouped_in=grouped_in,
            grouped_out=grouped_out,
        )
input, output] + k, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + lora_A=self._lora_A, + lora_B=self._lora_B, + scaling=self._lora_scaling or 1.0, + expert_biases=self.bias, + gates=gates, + grouped_in=grouped_in, + grouped_out=grouped_out, + ) diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_experts.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_experts.py new file mode 100644 index 0000000000..2aae050bb6 --- /dev/null +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_experts.py @@ -0,0 +1,255 @@ +# SPDX-License-Identifier: Apache-2.0 +# Adapted from https://github.com/shawntan/scattermoe +# Copyright (c) Shawn Tan and ScatterMoE Contributors +# Licensed under the Apache License, Version 2.0 +# See https://github.com/shawntan/scattermoe/blob/main/LICENSE + +from typing import Optional + +import torch +import torch.nn as nn + +from . import kernels + + +@torch.library.custom_op("scattermoe::bincount", mutates_args={}) +def compileable_bincount(x: torch.Tensor, minlength: int) -> torch.Tensor: + return x.bincount(minlength=minlength) + + +@compileable_bincount.register_fake +def _(x: torch.Tensor, minlength: int) -> torch.Tensor: + return torch.empty(minlength, dtype=torch.long, device=x.device) + + +@torch.compile +def flatten_sort_count(expert_idxs: torch.Tensor, num_experts: int): + with torch.no_grad(): + flattened_expert_idxs = expert_idxs.flatten() + sorted_expert_idxs, sorted_scattered_idxs = torch.sort(flattened_expert_idxs) + expert_counts = compileable_bincount( + flattened_expert_idxs, minlength=num_experts + ) + expert_offsets = expert_counts.cumsum(-1) + return sorted_expert_idxs, sorted_scattered_idxs, expert_offsets + + +class ParallelLinear(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: torch.Tensor, + expert_weights: torch.Tensor, + k: int, + sorted_expert_idxs: torch.Tensor, + sorted_scattered_idxs: torch.Tensor, + expert_offsets: 
class ParallelLinear(torch.autograd.Function):
    """Autograd wrapper around the ScatterMoE expert matmul kernels.

    Forward runs ``scatter2scatter`` and optionally combines the per-token
    expert outputs with ``gates``. Backward produces gradients for the input,
    expert weights, expert biases and gates.
    """

    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor,
        expert_weights: torch.Tensor,
        k: int,
        sorted_expert_idxs: torch.Tensor,
        sorted_scattered_idxs: torch.Tensor,
        expert_offsets: torch.Tensor,
        expert_biases: Optional[torch.Tensor] = None,
        gates: Optional[torch.Tensor] = None,
        grouped_in: bool = False,
        grouped_out: bool = False,
    ):
        with torch.device(x.device):
            output = kernels.ops.scatter2scatter(
                X=x,
                W=expert_weights,
                b=expert_biases,
                k=k,
                sorted_expert_idxs=sorted_expert_idxs,
                sorted_scattered_idxs=sorted_scattered_idxs,
                x_grouped=grouped_in,
                y_grouped=grouped_out,
            )
            if gates is not None:
                # View as [tokens, fan, features] and take the gate-weighted
                # sum over the fan axis with a batched [1, fan] @ [fan, features].
                output_expanded = output.view(
                    gates.size(0), gates.size(1), output.size(-1)
                )
                output = (gates.unsqueeze(1) @ output_expanded).squeeze(1)
            else:
                output_expanded = None

            # output_expanded is kept for the gate gradient and later reused
            # as scratch space in backward.
            ctx.save_for_backward(
                x,
                expert_weights,
                expert_biases,
                sorted_expert_idxs,
                sorted_scattered_idxs,
                expert_offsets,
                gates,
                output_expanded,
            )
            ctx.grouped_in = grouped_in
            ctx.grouped_out = grouped_out
            ctx.k = k
            return output

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        with torch.device(grad_out.device):
            (
                x,
                expert_weights,
                expert_biases,
                sorted_expert_idxs,
                sorted_scattered_idxs,
                expert_offsets,
                gates,
                output_expanded,
            ) = ctx.saved_tensors
            k = ctx.k
            grouped_in = ctx.grouped_in
            grouped_out = ctx.grouped_out

            if gates is not None:
                # d_gates[t, j] = <output_expanded[t, j, :], grad_out[t, :]>
                d_gates = (output_expanded @ grad_out.unsqueeze(-1)).squeeze(-1)
                gates_flat = gates.flatten()
                gate_fan = gates.size(1)
                # The expanded forward output is no longer needed after
                # d_gates; its storage becomes the grouped-gradient buffer.
                grouped_grad_out = output_expanded.flatten(
                    0, 1
                )  # reuse expanded buffer later
            else:
                d_gates = None
                gates_flat = None
                gate_fan = 1
                grouped_grad_out = None

            if grouped_out:
                grouped_grad_out = grad_out
            else:
                # Gather grad_out into expert-sorted order, scaling each row
                # by its gate value when gates are present.
                grouped_grad_out = kernels.ops.group(
                    grad_out,
                    sorted_scattered_idxs,
                    fan_out=gate_fan,
                    coeff=gates_flat,
                    out=grouped_grad_out,
                )
            if grouped_in:
                grouped_x = x
                d_expanded_input = None
            else:
                grouped_x = kernels.ops.group(x, sorted_scattered_idxs, fan_out=k)
                # grouped_x is consumed by group_bwd_W below, after which its
                # buffer is overwritten with the input gradient.
                d_expanded_input = grouped_x

            d_weights, d_biases = kernels.ops.group_bwd_W(
                DY=grouped_grad_out,
                X=grouped_x,
                expert_offsets=expert_offsets,
                E=expert_weights.size(0),
                has_bias=expert_biases is not None,
            )

            # Input gradient: grouped dY times the transposed expert weights.
            d_expanded_input = kernels.ops.scatter2scatter(
                X=grouped_grad_out,
                x_grouped=True,
                W=expert_weights.permute(0, 2, 1),
                sorted_expert_idxs=sorted_expert_idxs,
                sorted_scattered_idxs=sorted_scattered_idxs,
                k=1,
                y_grouped=grouped_in,
                out=d_expanded_input,  # Reuse grouped_x buffer
            )

            if k == 1:
                d_input = d_expanded_input
            else:
                # Each input row fed k experts; sum their gradients.
                d_input = d_expanded_input.view(
                    x.size(0), k, d_expanded_input.size(-1)
                ).sum(-2)
            # One entry per forward() argument after ctx, in the same order.
            return (
                # x, expert_weights,
                d_input,
                d_weights,
                # k, sorted_expert_idxs, sorted_scattered_idxs, expert_offsets,
                None,
                None,
                None,
                None,
                # bias, gates
                d_biases,
                d_gates,
                # grouped_in, grouped_out,
                None,
                None,
            )


def parallel_linear(
    inputs,
    expert_weights,
    k,
    sorted_expert_idxs,
    sorted_scattered_idxs,
    expert_offsets,
    expert_biases=None,
    gates=None,
    grouped_in=False,
    grouped_out=False,
):
    """Functional entry point for ``ParallelLinear``.

    ``autograd.Function.apply`` takes positional arguments only, so this
    wrapper provides the keyword/default interface.
    """
    results = ParallelLinear.apply(
        inputs,
        expert_weights,
        k,
        sorted_expert_idxs,
        sorted_scattered_idxs,
        expert_offsets,
        expert_biases,
        gates,
        grouped_in,
        grouped_out,
    )
    return results
class ParallelExperts(nn.Module):
    """Bank of per-expert linear layers evaluated with ``parallel_linear``."""

    def __init__(self, num_experts, input_size, output_size, bias=False) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
        self.bias = (
            nn.Parameter(torch.empty(num_experts, output_size)) if bias else None
        )

        self.num_experts = num_experts
        self.input_size = input_size
        self.output_size = output_size
        self.reset_parameters()

    def extra_repr(self):
        return (
            f"num_experts={self.num_experts}, "
            f"input_size={self.input_size}, "
            f"output_size={self.output_size}"
        )

    def reset_parameters(self) -> None:
        """Draw weights from a normal with std 0.02; zero the bias if present."""
        nn.init.normal_(self.weight, std=0.02)
        if self.bias is None:
            return
        nn.init.zeros_(self.bias)

    def forward(
        self,
        inputs,
        k,
        sorted_expert_idxs,
        sorted_scattered_idxs,
        expert_offsets,
        gates=None,
        grouped_in=False,
        grouped_out=False,
    ):
        # The parameter is stored [E, output, input]; the kernel takes [E, input, output].
        weight_in_out = self.weight.permute(0, 2, 1)
        return parallel_linear(
            inputs,
            weight_in_out,
            k,
            sorted_expert_idxs,
            sorted_scattered_idxs,
            expert_offsets,
            expert_biases=self.bias,
            gates=gates,
            grouped_in=grouped_in,
            grouped_out=grouped_out,
        )
class ScatterMoELoRA(torch.autograd.Function):
    """
    Autograd function for fused ScatterMoE + LoRA with frozen expert weights.

    This function is optimized for the LoRA fine-tuning scenario where:
    - Expert weights W are frozen (requires_grad=False)
    - Only LoRA A and B matrices receive gradients
    - Input gradients are computed for upstream layer backprop

    Note that backward never computes a real gradient for ``expert_weights``
    or ``expert_biases``: the weight gradient is zeros (or ``None``) and the
    bias gradient is always ``None``.
    """

    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor,
        expert_weights: torch.Tensor,
        k: int,
        sorted_expert_idxs: torch.Tensor,
        sorted_scattered_idxs: torch.Tensor,
        expert_offsets: torch.Tensor,
        lora_A: torch.Tensor,
        lora_B: torch.Tensor,
        scaling: float,
        expert_biases: Optional[torch.Tensor] = None,
        gates: Optional[torch.Tensor] = None,
        grouped_in: bool = False,
        grouped_out: bool = False,
        use_fused_dX: bool = False,
        use_fused_gather: bool = False,
    ):
        with torch.device(x.device):
            # Fused forward: Y = X @ W + scaling * (X @ A^T) @ B^T
            output = scatter2scatter_lora(
                X=x,
                W=expert_weights,
                sorted_expert_idxs=sorted_expert_idxs,
                sorted_scattered_idxs=sorted_scattered_idxs,
                k=k,
                lora_A=lora_A,
                lora_B=lora_B,
                scaling=scaling,
                b=expert_biases,
                x_grouped=grouped_in,
                y_grouped=grouped_out,
            )

            # Handle gating (weighted combination of top-k expert outputs)
            if gates is not None:
                output_expanded = output.view(
                    gates.size(0), gates.size(1), output.size(-1)
                )
                output = (gates.unsqueeze(1) @ output_expanded).squeeze(1)
            else:
                output_expanded = None

            ctx.save_for_backward(
                x,
                expert_weights,
                lora_A,
                lora_B,
                expert_biases,
                sorted_expert_idxs,
                sorted_scattered_idxs,
                expert_offsets,
                gates,
                output_expanded,
            )
            # Non-tensor state needed by backward.
            ctx.grouped_in = grouped_in
            ctx.grouped_out = grouped_out
            ctx.k = k
            ctx.scaling = scaling
            ctx.use_fused_dX = use_fused_dX
            ctx.use_fused_gather = use_fused_gather

            return output

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        with torch.device(grad_out.device):
            (
                x,
                expert_weights,
                lora_A,
                lora_B,
                expert_biases,
                sorted_expert_idxs,
                sorted_scattered_idxs,
                expert_offsets,
                gates,
                output_expanded,
            ) = ctx.saved_tensors

            k = ctx.k
            scaling = ctx.scaling
            grouped_in = ctx.grouped_in
            grouped_out = ctx.grouped_out
            E = expert_weights.size(0)

            # ------------------------------------------------------------------
            # Gate gradients (if using top-k gating with routing weights)
            # ------------------------------------------------------------------
            if gates is not None:
                # d_gates[t, j] = output_expanded[t, j, :] . grad_out[t, :]
                d_gates = (output_expanded @ grad_out.unsqueeze(-1)).squeeze(-1)
                gates_flat = gates.flatten()
                gate_fan = gates.size(1)
                # Reuse output_expanded buffer for grouped_grad_out
                grouped_grad_out = output_expanded.flatten(0, 1)
            else:
                d_gates = None
                gates_flat = None
                gate_fan = 1
                grouped_grad_out = None

            # Fused gather is only possible when data isn't already grouped
            # and there are no gate coefficients (which require a multiplicative gather).
            #
            # Heuristic: fused gather eliminates group() calls but uses random
            # access (via sorted_scattered_idxs) in the inner GEMM loop. At
            # large problem sizes the GEMM dominates runtime and sequential
            # access from group() is faster than the random scatter loads.
            # Disable fused gather when the total workload exceeds a threshold.
            M_total = sorted_scattered_idxs.size(0)
            K_dim = x.size(-1)
            N_dim = expert_weights.size(-1)
            fuse_gather_workload = M_total * max(K_dim, N_dim)
            _FUSE_GATHER_THRESHOLD = 2**24  # ~16M elements

            can_fuse_gather = (
                ctx.use_fused_gather
                and not grouped_in
                and not grouped_out
                and gates is None
                and fuse_gather_workload < _FUSE_GATHER_THRESHOLD
            )

            if can_fuse_gather:
                # ------------------------------------------------------------------
                # Fused path: group_bwd_lora_fused reads ungrouped DY and X
                # directly, so x is never grouped here.
                # ------------------------------------------------------------------
                # No grouped_x buffer exists to reuse for the input gradient.
                d_expanded_input = None

                d_lora_A, d_lora_B = group_bwd_lora_fused(
                    DY=grad_out,
                    X=x,
                    lora_A=lora_A,
                    lora_B=lora_B,
                    expert_offsets=expert_offsets,
                    sorted_scattered_idxs=sorted_scattered_idxs,
                    E=E,
                    k=k,
                    scaling=scaling,
                )

                # When using fused gather, we need grouped_grad_out only if
                # the dX path is NOT fused (original path needs it grouped).
                # If fused dX is also enabled, it can read ungrouped DY directly.
                if not ctx.use_fused_dX:
                    grouped_grad_out = base_ops.group(
                        grad_out,
                        sorted_scattered_idxs,
                        fan_out=gate_fan,
                        coeff=gates_flat,
                        out=grouped_grad_out,
                    )
            else:
                # ------------------------------------------------------------------
                # Original path: explicit group() calls
                # ------------------------------------------------------------------
                if grouped_out:
                    grouped_grad_out = grad_out
                else:
                    grouped_grad_out = base_ops.group(
                        grad_out,
                        sorted_scattered_idxs,
                        fan_out=gate_fan,
                        coeff=gates_flat,
                        out=grouped_grad_out,
                    )

                if grouped_in:
                    grouped_x = x
                    d_expanded_input = None
                else:
                    grouped_x = base_ops.group(x, sorted_scattered_idxs, fan_out=k)
                    d_expanded_input = grouped_x  # Will be overwritten; reuse buffer

                d_lora_A, d_lora_B = group_bwd_lora(
                    DY=grouped_grad_out,
                    X=grouped_x,
                    lora_A=lora_A,
                    lora_B=lora_B,
                    expert_offsets=expert_offsets,
                    E=E,
                    scaling=scaling,
                )

            # ------------------------------------------------------------------
            # Input gradient: dX = dY @ W^T + scaling * (dY @ B) @ A
            # ------------------------------------------------------------------
            if ctx.use_fused_dX:
                if can_fuse_gather:
                    # Fully fused: read ungrouped DY via scatter pattern
                    d_expanded_input = scatter2scatter_lora_dX(
                        DY=grad_out,
                        W=expert_weights,
                        sorted_expert_idxs=sorted_expert_idxs,
                        sorted_scattered_idxs=sorted_scattered_idxs,
                        k=1,
                        lora_A=lora_A,
                        lora_B=lora_B,
                        scaling=scaling,
                        dy_grouped=False,
                        dx_grouped=grouped_in,
                        out=d_expanded_input,
                    )
                else:
                    # Fused dX only: read from pre-grouped DY
                    d_expanded_input = scatter2scatter_lora_dX(
                        DY=grouped_grad_out,
                        W=expert_weights,
                        sorted_expert_idxs=sorted_expert_idxs,
                        sorted_scattered_idxs=sorted_scattered_idxs,
                        k=1,
                        lora_A=lora_A,
                        lora_B=lora_B,
                        scaling=scaling,
                        dy_grouped=True,
                        dx_grouped=grouped_in,
                        out=d_expanded_input,
                    )
            else:
                # Original path: separate base scatter2scatter + LoRA Python loop
                d_expanded_input = base_ops.scatter2scatter(
                    X=grouped_grad_out,
                    x_grouped=True,
                    W=expert_weights.permute(0, 2, 1),  # [E, N, K]
                    sorted_expert_idxs=sorted_expert_idxs,
                    sorted_scattered_idxs=sorted_scattered_idxs,
                    k=1,
                    y_grouped=grouped_in,
                    out=d_expanded_input,
                )

                # LoRA part: dX_lora = scaling * (dY @ B) @ A
                if scaling != 0.0:
                    d_input_lora_grouped = _compute_lora_input_grad(
                        grouped_grad_out,
                        lora_A,
                        lora_B,
                        expert_offsets,
                        E,
                        scaling,
                    )
                    if grouped_in:
                        d_expanded_input = d_expanded_input + d_input_lora_grouped
                    else:
                        # The base term was scattered back to ungrouped order,
                        # so scatter the grouped LoRA term the same way first.
                        d_input_lora_ungrouped = torch.zeros_like(d_expanded_input)
                        d_input_lora_ungrouped[sorted_scattered_idxs] = (
                            d_input_lora_grouped
                        )
                        d_expanded_input = d_expanded_input + d_input_lora_ungrouped

            # Reduce over top-k if k > 1
            if k == 1:
                d_input = d_expanded_input
            else:
                d_input = d_expanded_input.view(
                    x.size(0), k, d_expanded_input.size(-1)
                ).sum(-2)

            # W is frozen during LoRA training -- skip weight gradient.
            # If W does require grad, zeros are returned rather than a real
            # gradient.
            d_weights = (
                torch.zeros_like(expert_weights)
                if expert_weights.requires_grad
                else None
            )
            d_biases = None

            # One entry per forward() argument after ctx, in the same order.
            return (
                d_input,  # x
                d_weights,  # expert_weights
                None,  # k
                None,  # sorted_expert_idxs
                None,  # sorted_scattered_idxs
                None,  # expert_offsets
                d_lora_A,  # lora_A
                d_lora_B,  # lora_B
                None,  # scaling
                d_biases,  # expert_biases
                d_gates,  # gates
                None,  # grouped_in
                None,  # grouped_out
                None,  # use_fused_dX
                None,  # use_fused_gather
            )
+ Each expert e: dX_e = scaling * (dY_e @ B_e) @ A_e + """ + R = lora_A.size(0) // E + K = lora_A.size(1) + M_total = grouped_grad_out.size(0) + + d_input_lora = torch.zeros( + (M_total, K), device=grouped_grad_out.device, dtype=grouped_grad_out.dtype + ) + + prev_offset = 0 + for e in range(E): + curr_offset = expert_offsets[e].item() + if curr_offset > prev_offset: + dy_e = grouped_grad_out[prev_offset:curr_offset] # [M_e, N] + a_e = lora_A[e * R : (e + 1) * R, :] # [r, K] + b_e = lora_B[:, e * R : (e + 1) * R] # [N, r] + + # dX_e = scaling * (dY_e @ B_e) @ A_e + dy_b = dy_e @ b_e # [M_e, r] + dx_e = scaling * (dy_b @ a_e) # [M_e, K] + d_input_lora[prev_offset:curr_offset] = dx_e + + prev_offset = curr_offset + + return d_input_lora + + +# ============================================================================= +# Helper: Extract LoRA params from PEFT ParamWrapper +# ============================================================================= + + +def get_lora_params_from_wrapper(module) -> tuple: + """ + Extract LoRA parameters from a PEFT ParamWrapper. 
+ + Returns: + (lora_A, lora_B, scaling) if LoRA is active, else (None, None, None) + """ + if not hasattr(module, "lora_A") or not hasattr(module, "lora_B"): + return None, None, None + + active_adapters = getattr(module, "active_adapters", ["default"]) + if not active_adapters: + return None, None, None + + adapter_name = active_adapters[0] + + lora_A_dict = getattr(module, "lora_A", {}) + lora_B_dict = getattr(module, "lora_B", {}) + scaling_dict = getattr(module, "scaling", {}) + + if adapter_name not in lora_A_dict: + return None, None, None + + lora_A = lora_A_dict[adapter_name].weight + lora_B = lora_B_dict[adapter_name].weight + scaling = scaling_dict[adapter_name] + + return lora_A, lora_B, scaling + + +# ============================================================================= +# Drop-in replacement for parallel_linear +# ============================================================================= + + +def parallel_linear_lora( + inputs: torch.Tensor, + expert_weights: torch.Tensor, + k: int, + sorted_expert_idxs: torch.Tensor, + sorted_scattered_idxs: torch.Tensor, + expert_offsets: torch.Tensor, + lora_A: Optional[torch.Tensor] = None, + lora_B: Optional[torch.Tensor] = None, + scaling: float = 1.0, + expert_biases: Optional[torch.Tensor] = None, + gates: Optional[torch.Tensor] = None, + grouped_in: bool = False, + grouped_out: bool = False, + use_fused_dX: bool = False, + use_fused_gather: bool = False, +): + """ + Drop-in replacement for parallel_linear that supports LoRA. + + If lora_A and lora_B are provided, uses fused LoRA kernel. + Otherwise falls back to standard scatter2scatter. 
+ """ + if lora_A is not None and lora_B is not None: + return ScatterMoELoRA.apply( + inputs, + expert_weights, + k, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + lora_A, + lora_B, + scaling, + expert_biases, + gates, + grouped_in, + grouped_out, + use_fused_dX, + use_fused_gather, + ) + else: + from .parallel_experts import ParallelLinear + + return ParallelLinear.apply( + inputs, + expert_weights, + k, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + expert_biases, + gates, + grouped_in, + grouped_out, + ) diff --git a/src/axolotl/integrations/kernels/plugin.py b/src/axolotl/integrations/kernels/plugin.py index c7fb79ff64..9bcf3930b3 100644 --- a/src/axolotl/integrations/kernels/plugin.py +++ b/src/axolotl/integrations/kernels/plugin.py @@ -1,5 +1,8 @@ +from pathlib import Path + from kernels import ( - LayerRepository, + # LayerRepository, + LocalLayerRepository, Mode, register_kernel_mapping, replace_kernel_forward_from_hub, @@ -19,16 +22,19 @@ def pre_model_load(self, cfg): self._kernelize_model(cfg.model_config_type) def _register_kernels(self): + plugin_root = Path(__file__).parent register_kernel_mapping( { "HFScatterMoEParallelExperts": { "cuda": { - Mode.TRAINING: LayerRepository( - repo_id="axolotl-ai-co/scattermoe", + Mode.TRAINING: LocalLayerRepository( + repo_path=plugin_root / "libs" / "scattermoe_lora", + package_name="scattermoe_lora", layer_name="HFScatterMoEGatedMLP", ), - Mode.INFERENCE: LayerRepository( - repo_id="axolotl-ai-co/scattermoe", + Mode.INFERENCE: LocalLayerRepository( + repo_path=plugin_root / "libs" / "scattermoe_lora", + package_name="scattermoe_lora", layer_name="HFScatterMoEGatedMLP", ), }, From 512934e6e0b4882bcc502c8954c43450c1315893 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 11 Feb 2026 14:26:14 -0500 Subject: [PATCH 2/8] fsdp, bf16, dim fixes --- .../libs/scattermoe_lora/kernels/lora_ops.py | 106 ++++++++++++------ .../kernels/libs/scattermoe_lora/layers.py | 78 +++++++------ 
.../scattermoe_lora/parallel_linear_lora.py | 6 +- src/axolotl/loaders/patch_manager.py | 2 +- 4 files changed, 120 insertions(+), 72 deletions(-) diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py index 77a4952487..352676e378 100644 --- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py @@ -269,24 +269,33 @@ def _compute_expert_block_lora( # Accumulator for X @ A^T: [BLOCK_M, BLOCK_R] xa_acc = tl.zeros((BLOCK_M, BLOCK_R), dtype=tl.float32) + # Determine the input element type for consistent casting. + # Masked tl.load with other=0.0 can upcast bf16->fp32 in some Triton versions, + # causing dtype mismatches in tl.dot. We cast all tiles to the same type. + INPUT_DTYPE = X_ptr.dtype.element_ty + for i in range(iters): if no_k_mask: - x = tl.load(X_blk_ptrs, mask=E_mask[:, None], other=0.0) - w = tl.load(W_blk_ptrs, mask=N_mask[None, :], other=0.0) - a = tl.load( - A_blk_ptrs, mask=R_mask[:, None], other=0.0 - ) # [BLOCK_R, BLOCK_K], masked on R dim + x = tl.load(X_blk_ptrs, mask=E_mask[:, None], other=0.0).to(INPUT_DTYPE) + w = tl.load(W_blk_ptrs, mask=N_mask[None, :], other=0.0).to(INPUT_DTYPE) + a = tl.load(A_blk_ptrs, mask=R_mask[:, None], other=0.0).to(INPUT_DTYPE) else: K_mask = (i * BLOCK_K + K_block) < K - x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :], other=0.0) - w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :], other=0.0) - a = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0) + x = tl.load( + X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :], other=0.0 + ).to(INPUT_DTYPE) + w = tl.load( + W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :], other=0.0 + ).to(INPUT_DTYPE) + a = tl.load( + A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0 + ).to(INPUT_DTYPE) # Base: acc += X @ W 
([M, K] @ [K, N] -> [M, N]) - acc = tl.dot(x, w, acc, allow_tf32=allow_tf32) + acc += tl.dot(x, w, allow_tf32=allow_tf32).to(tl.float32) # LoRA: xa_acc += X @ A^T ([M, K] @ [K, R] -> [M, R]) - xa_acc = tl.dot(x, tl.trans(a), xa_acc, allow_tf32=allow_tf32) + xa_acc += tl.dot(x, tl.trans(a), allow_tf32=allow_tf32).to(tl.float32) X_blk_ptrs += BLOCK_K * stride_xk W_blk_ptrs += BLOCK_K * stride_wk @@ -325,7 +334,7 @@ def _scatter2scatter_lora_configs(): BLOCK_N: {32, 64, 128, 256} BLOCK_K: {32, 64, 128} num_warps: {4, 8} - num_stages: {2, 3, 4, 5} + num_stages: {3, 4, 5} BLOCK_M is fixed at 128 (module-level constant, not autotuned in the scatter2scatter pattern). @@ -335,7 +344,7 @@ def _scatter2scatter_lora_configs(): [32, 64, 128, 256], # BLOCK_N [32, 64, 128], # BLOCK_K [4, 8], # num_warps - [2, 3, 4, 5], # num_stages + [3, 4, 5], # num_stages ): configs.append( triton.Config( @@ -714,28 +723,31 @@ def _compute_expert_block_lora_dX( # Accumulator for DY @ B: [BLOCK_M, BLOCK_R] dy_b_acc = tl.zeros((BLOCK_M, BLOCK_R), dtype=tl.float32) + # Determine the input element type for consistent casting. 
+ INPUT_DTYPE = DY_ptr.dtype.element_ty + for i in range(iters): if no_n_mask: - dy = tl.load(DY_blk_ptrs, mask=E_mask[:, None], other=0.0) - wt = tl.load(WT_blk_ptrs, mask=K_mask[None, :], other=0.0) - b = tl.load(B_blk_ptrs, mask=R_mask[None, :], other=0.0) + dy = tl.load(DY_blk_ptrs, mask=E_mask[:, None], other=0.0).to(INPUT_DTYPE) + wt = tl.load(WT_blk_ptrs, mask=K_mask[None, :], other=0.0).to(INPUT_DTYPE) + b = tl.load(B_blk_ptrs, mask=R_mask[None, :], other=0.0).to(INPUT_DTYPE) else: N_mask_iter = (i * BLOCK_N + N_block) < N dy = tl.load( DY_blk_ptrs, mask=E_mask[:, None] & N_mask_iter[None, :], other=0.0 - ) + ).to(INPUT_DTYPE) wt = tl.load( WT_blk_ptrs, mask=N_mask_iter[:, None] & K_mask[None, :], other=0.0 - ) + ).to(INPUT_DTYPE) b = tl.load( B_blk_ptrs, mask=N_mask_iter[:, None] & R_mask[None, :], other=0.0 - ) + ).to(INPUT_DTYPE) # Base: acc += DY @ W^T ([M, N] @ [N, K] -> [M, K]) - acc = tl.dot(dy, wt, acc, allow_tf32=allow_tf32) + acc += tl.dot(dy, wt, allow_tf32=allow_tf32).to(tl.float32) # LoRA: dy_b_acc += DY @ B ([M, N] @ [N, R] -> [M, R]) - dy_b_acc = tl.dot(dy, b, dy_b_acc, allow_tf32=allow_tf32) + dy_b_acc += tl.dot(dy, b, allow_tf32=allow_tf32).to(tl.float32) DY_blk_ptrs += BLOCK_N * stride_dyn WT_blk_ptrs += BLOCK_N * stride_wn @@ -774,14 +786,14 @@ def _scatter2scatter_lora_dX_configs(): BLOCK_K: {32, 64, 128, 256} (output tile) BLOCK_N: {32, 64, 128, 256} (reduction tile) num_warps: {4, 8} - num_stages: {2, 3, 4, 5} + num_stages: {3, 4, 5} """ configs = [] for block_k, block_n, warps, stages in product( [32, 64, 128, 256], # BLOCK_K (output dimension) [32, 64, 128, 256], # BLOCK_N (reduction dimension) [4, 8], # num_warps - [2, 3, 4, 5], # num_stages + [3, 4, 5], # num_stages ): configs.append( triton.Config( @@ -1083,7 +1095,7 @@ def _group_bwd_lora_configs(): BLOCK_K: {32, 64, 128, 256} BLOCK_N: {32, 64, 128, 256} num_warps: {4, 8} - num_stages: {2, 3, 4, 5} + num_stages: {3, 4, 5} The backward kernel also uses BLOCK_R (from LoRA rank), 
but that is determined by the rank and not autotunable. @@ -1094,7 +1106,7 @@ def _group_bwd_lora_configs(): [32, 64, 128, 256], # BLOCK_K [32, 64, 128, 256], # BLOCK_N [4, 8], # num_warps - [2, 3, 4, 5], # num_stages + [3, 4, 5], # num_stages ): configs.append( triton.Config( @@ -1227,13 +1239,18 @@ def _group_bwd_lora( lora_offset = E_idx * ACTUAL_R + # Determine input element type for consistent casting. + INPUT_DTYPE = X_ptr.dtype.element_ty + # Load B[e]: [BLOCK_N, BLOCK_R] (masked on R and N, other=0 for padding) B_blk_ptrs = ( LB_ptr + N_block[:, None] * stride_lb_n + (lora_offset + R_block)[None, :] * stride_lb_r ) - b_e = tl.load(B_blk_ptrs, mask=N_mask[:, None] & R_mask[None, :], other=0.0) + b_e = tl.load(B_blk_ptrs, mask=N_mask[:, None] & R_mask[None, :], other=0.0).to( + INPUT_DTYPE + ) # Load A[e]: [BLOCK_R, BLOCK_K] (masked on R and K, other=0 for padding) A_blk_ptrs = ( @@ -1241,7 +1258,9 @@ def _group_bwd_lora( + (lora_offset + R_block)[:, None] * stride_la_r + K_block[None, :] * stride_la_k ) - a_e = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0) + a_e = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0).to( + INPUT_DTYPE + ) # Accumulators dA_acc = tl.zeros((BLOCK_R, BLOCK_K), dtype=ACC_TYPE) @@ -1256,13 +1275,17 @@ def _group_bwd_lora( X_blk_ptrs = ( X_ptr + M_idx[:, None] * stride_xm + K_block[None, :] * stride_xk ) - x = tl.load(X_blk_ptrs, mask=M_mask[:, None] & K_mask[None, :], other=0.0) + x = tl.load( + X_blk_ptrs, mask=M_mask[:, None] & K_mask[None, :], other=0.0 + ).to(INPUT_DTYPE) # Load dY: [BLOCK_M, BLOCK_N] DY_blk_ptrs = ( DY_ptr + M_idx[:, None] * stride_dym + N_block[None, :] * stride_dyn ) - dy = tl.load(DY_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :], other=0.0) + dy = tl.load( + DY_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :], other=0.0 + ).to(INPUT_DTYPE) # X @ A[e]^T: [M, K] @ [K, R] -> [M, R] xa = tl.dot(x, tl.trans(a_e), allow_tf32=allow_tf32) @@ -1272,8 +1295,8 @@ def 
_group_bwd_lora( # Cast intermediates to input dtype for subsequent tl.dot calls # (tl.dot requires both operands to have the same dtype) - dy_b_cast = dy_b.to(x.dtype) - xa_cast = xa.to(dy.dtype) + dy_b_cast = dy_b.to(INPUT_DTYPE) + xa_cast = xa.to(INPUT_DTYPE) # dA += (dY @ B)^T @ X: [R, M] @ [M, K] -> [R, K] dA_acc += tl.dot(tl.trans(dy_b_cast), x, allow_tf32=allow_tf32) @@ -1499,20 +1522,27 @@ def _group_bwd_lora_fused( lora_offset = E_idx * ACTUAL_R + # Determine input element type for consistent casting. + INPUT_DTYPE = X_ptr.dtype.element_ty + # Load B[e] and A[e] — same as non-fused kernel B_blk_ptrs = ( LB_ptr + N_block[:, None] * stride_lb_n + (lora_offset + R_block)[None, :] * stride_lb_r ) - b_e = tl.load(B_blk_ptrs, mask=N_mask[:, None] & R_mask[None, :], other=0.0) + b_e = tl.load(B_blk_ptrs, mask=N_mask[:, None] & R_mask[None, :], other=0.0).to( + INPUT_DTYPE + ) A_blk_ptrs = ( LA_ptr + (lora_offset + R_block)[:, None] * stride_la_r + K_block[None, :] * stride_la_k ) - a_e = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0) + a_e = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0).to( + INPUT_DTYPE + ) # Accumulators dA_acc = tl.zeros((BLOCK_R, BLOCK_K), dtype=ACC_TYPE) @@ -1536,7 +1566,9 @@ def _group_bwd_lora_fused( X_blk_ptrs = ( X_ptr + X_token_idx[:, None] * stride_xm + K_block[None, :] * stride_xk ) - x = tl.load(X_blk_ptrs, mask=M_mask[:, None] & K_mask[None, :], other=0.0) + x = tl.load( + X_blk_ptrs, mask=M_mask[:, None] & K_mask[None, :], other=0.0 + ).to(INPUT_DTYPE) # Load DY via scatter index: DY is [M*k, N] DY_blk_ptrs = ( @@ -1544,7 +1576,9 @@ def _group_bwd_lora_fused( + scatter_idx[:, None] * stride_dym + N_block[None, :] * stride_dyn ) - dy = tl.load(DY_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :], other=0.0) + dy = tl.load( + DY_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :], other=0.0 + ).to(INPUT_DTYPE) # X @ A[e]^T: [M, K] @ [K, R] -> [M, R] xa = tl.dot(x, tl.trans(a_e), 
allow_tf32=allow_tf32) @@ -1552,8 +1586,8 @@ def _group_bwd_lora_fused( # dY @ B[e]: [M, N] @ [N, R] -> [M, R] dy_b = tl.dot(dy, b_e, allow_tf32=allow_tf32) - dy_b_cast = dy_b.to(x.dtype) - xa_cast = xa.to(dy.dtype) + dy_b_cast = dy_b.to(INPUT_DTYPE) + xa_cast = xa.to(INPUT_DTYPE) # dA += (dY @ B)^T @ X: [R, M] @ [M, K] -> [R, K] dA_acc += tl.dot(tl.trans(dy_b_cast), x, allow_tf32=allow_tf32) diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py index 330e133eaa..b850590408 100644 --- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py @@ -59,40 +59,50 @@ def peft_lora_B_to_scattermoe(peft_B, num_experts, rank): ) -def peft_down_proj_lora_to_scattermoe(peft_A, peft_B, num_experts, rank): - """Convert peft LoRA weights for **down_proj** to scattermoe layout. +def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank): + """Convert peft LoRA weights to scattermoe layout (with A<->B swap). + + peft operates on the parameter in its native storage layout ``[E, dim1, dim2]`` + where ``in_features=dim1, out_features=dim2``. ScatterMoE transposes the + parameter (``W = param.transpose(2, 1)``) giving ``[E, dim2, dim1]`` with + ``K=dim2, N=dim1``. Because of this transposition, peft's A and B roles + are swapped relative to scattermoe's convention. - down_proj param: ``[E, hidden, inter]`` - scattermoe W = param.T = ``[E, inter, hidden]`` -> K=inter, N=hidden + peft gives: + lora_A ``[r*E, dim1]``, lora_B ``[dim2, r*E]`` - peft assigns: in_features=hidden (dim 1), out_features=inter (dim 2). - peft lora_A: ``[r*E, hidden]``, lora_B: ``[inter, r*E]`` + scattermoe needs: + lora_A ``[r*E, K=dim2]``, lora_B ``[N=dim1, r*E]`` - scattermoe needs: lora_A ``[r*E, K=inter]``, lora_B ``[N=hidden, r*E]`` + This function swaps A<->B and converts B from rank-major to expert-major. 
- The roles of A and B are swapped because peft operates in the parameter's - native layout while scattermoe uses the transposed view. + Works for **both** gate_up_proj and down_proj since the transposition + issue is the same for any parameter. """ peft_B_expert_major = peft_lora_B_to_scattermoe(peft_B, num_experts, rank) - K_inter = peft_B.shape[0] # inter - N_hidden = peft_A.shape[1] # hidden + K = peft_B.shape[0] # dim2 -> becomes scattermoe K + N = peft_A.shape[1] # dim1 -> becomes scattermoe N smoe_A = torch.zeros( - rank * num_experts, K_inter, device=peft_A.device, dtype=peft_A.dtype + rank * num_experts, K, device=peft_A.device, dtype=peft_A.dtype ) smoe_B = torch.zeros( - N_hidden, rank * num_experts, device=peft_A.device, dtype=peft_A.dtype + N, rank * num_experts, device=peft_A.device, dtype=peft_A.dtype ) for e in range(num_experts): s = e * rank - # peft's A_e: [r, hidden], B_e: [inter, r] - A_e = peft_A[s : s + rank, :] # [r, hidden] - B_e = peft_B_expert_major[:, s : s + rank] # [inter, r] - smoe_A[s : s + rank, :] = B_e.T # [r, inter] - smoe_B[:, s : s + rank] = A_e.T # [hidden, r] + A_e = peft_A[s : s + rank, :] # [r, dim1] + B_e = peft_B_expert_major[:, s : s + rank] # [dim2, r] + smoe_A[s : s + rank, :] = B_e.T # [r, dim2=K] + smoe_B[:, s : s + rank] = A_e.T # [dim1=N, r] return smoe_A, smoe_B +def peft_down_proj_lora_to_scattermoe(peft_A, peft_B, num_experts, rank): + """Deprecated alias for :func:`peft_lora_to_scattermoe`.""" + return peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank) + + # ============================================================================= # ParamWrapper unwrapping # ============================================================================= @@ -106,15 +116,17 @@ def _unwrap_gate_lora(gate_module): ParamWrapper(weight) -> base_layer: OlmoeTopKRouter (the real module) - This function detects the wrapping and returns the base router plus - the LoRA-adjusted weight tensor. 
+ This function detects the wrapping and returns the base router, its + weight tensor, and an optional LoRA delta tensor. Returns: - (base_gate, gate_weight) + (base_gate, gate_weight, gate_lora_delta_or_None) ``base_gate`` is the original router module (with ``.top_k``, ``.num_experts``, ``.norm_topk_prob``). - ``gate_weight`` is the router weight with any LoRA delta applied. + ``gate_weight`` is the base router weight (may be a DTensor under FSDP). + ``gate_lora_delta_or_None`` is the LoRA delta tensor if LoRA is active, + else ``None``. Kept separate to avoid mixing DTensor + Tensor in an add. """ if hasattr(gate_module, "base_layer") and hasattr(gate_module, "lora_A"): base_gate = gate_module.base_layer @@ -124,13 +136,12 @@ def _unwrap_gate_lora(gate_module): # lora_A: [r, hidden_size], lora_B: [num_experts, r] # delta = scaling * B @ A = [num_experts, hidden_size] delta = scaling * (lora_B @ lora_A) - gate_weight = base_gate.weight + delta + return base_gate, base_gate.weight, delta else: - gate_weight = base_gate.weight - return base_gate, gate_weight + return base_gate, base_gate.weight, None else: # No wrapping — gate is the original module - return gate_module, gate_module.weight + return gate_module, gate_module.weight, None def _unwrap_experts_lora(experts_module): @@ -174,15 +185,14 @@ def _unwrap_experts_lora(experts_module): if gup is not None: num_experts = gup.shape[0] - # Extract gate_up_proj LoRA + # Extract gate_up_proj LoRA (needs A<->B swap due to transposition) gup_lora = None gup_wrapper = wrappers.get("gate_up_proj") if gup_wrapper is not None: lora_A, lora_B, scaling = get_lora_params_from_wrapper(gup_wrapper) if lora_A is not None: rank = lora_A.shape[0] // num_experts - smoe_A = lora_A # already expert-major [r*E, K] - smoe_B = peft_lora_B_to_scattermoe(lora_B, num_experts, rank) + smoe_A, smoe_B = peft_lora_to_scattermoe(lora_A, lora_B, num_experts, rank) gup_lora = (smoe_A, smoe_B, scaling) # Extract down_proj LoRA (needs A<->B swap 
due to transposition) @@ -192,9 +202,7 @@ def _unwrap_experts_lora(experts_module): lora_A, lora_B, scaling = get_lora_params_from_wrapper(down_wrapper) if lora_A is not None: rank = lora_A.shape[0] // num_experts - smoe_A, smoe_B = peft_down_proj_lora_to_scattermoe( - lora_A, lora_B, num_experts, rank - ) + smoe_A, smoe_B = peft_lora_to_scattermoe(lora_A, lora_B, num_experts, rank) down_lora = (smoe_A, smoe_B, scaling) return base_experts, gup_lora, down_lora @@ -313,8 +321,12 @@ def forward(self: nn.Module, layer_input: torch.Tensor): # ==================================================================== # Router Computation (with optional gate LoRA) # ==================================================================== - base_gate, gate_weight = _unwrap_gate_lora(self.gate) + base_gate, gate_weight, gate_lora_delta = _unwrap_gate_lora(self.gate) router_logits = F.linear(hidden_states_flat, gate_weight) + if gate_lora_delta is not None: + router_logits = router_logits + F.linear( + hidden_states_flat, gate_lora_delta + ) routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) top_k = base_gate.top_k diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.py index cedaa671f6..794af0cc8a 100644 --- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.py +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.py @@ -358,13 +358,15 @@ def _compute_lora_input_grad( (M_total, K), device=grouped_grad_out.device, dtype=grouped_grad_out.dtype ) + compute_dtype = grouped_grad_out.dtype + prev_offset = 0 for e in range(E): curr_offset = expert_offsets[e].item() if curr_offset > prev_offset: dy_e = grouped_grad_out[prev_offset:curr_offset] # [M_e, N] - a_e = lora_A[e * R : (e + 1) * R, :] # [r, K] - b_e = lora_B[:, e * R : (e + 1) * R] # [N, r] + a_e = lora_A[e * R : (e + 1) * R, :].to(compute_dtype) # 
[r, K] + b_e = lora_B[:, e * R : (e + 1) * R].to(compute_dtype) # [N, r] # dX_e = scaling * (dY_e @ B_e) @ A_e dy_b = dy_e @ b_e # [M_e, r] diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 2222600200..62dcbde7af 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -329,7 +329,7 @@ def _apply_multipack_patches(self): else: has_remote_code = False - if has_remote_code and self.cfg.trust_remote_code is False: + if has_remote_code and self.cfg.trust_remote_code is not None: # If explicitly set in YAML, prefer that has_remote_code = self.cfg.trust_remote_code From 9f9116a97f552b3f0d354d3599d034241b2366cd Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 12 Feb 2026 00:01:21 -0500 Subject: [PATCH 3/8] expert weights aren't needed in save for bwd since they are frozen --- .../kernels/libs/scattermoe_lora/layers.py | 64 ++++++++++++++----- .../scattermoe_lora/parallel_linear_lora.py | 23 ++++--- 2 files changed, 60 insertions(+), 27 deletions(-) diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py index b850590408..3813ebb703 100644 --- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py @@ -75,26 +75,34 @@ def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank): lora_A ``[r*E, K=dim2]``, lora_B ``[N=dim1, r*E]`` This function swaps A<->B and converts B from rank-major to expert-major. + Uses vectorized tensor operations (no Python loop over experts). Works for **both** gate_up_proj and down_proj since the transposition issue is the same for any parameter. 
""" - peft_B_expert_major = peft_lora_B_to_scattermoe(peft_B, num_experts, rank) + peft_B_em = peft_lora_B_to_scattermoe(peft_B, num_experts, rank) - K = peft_B.shape[0] # dim2 -> becomes scattermoe K - N = peft_A.shape[1] # dim1 -> becomes scattermoe N - smoe_A = torch.zeros( - rank * num_experts, K, device=peft_A.device, dtype=peft_A.dtype + dim1 = peft_A.shape[1] # peft in_features -> scattermoe N + dim2 = peft_B_em.shape[0] # peft out_features -> scattermoe K + + # smoe_A: per expert, transpose B_e [dim2, r] -> [r, dim2] + # [dim2, E*r] -> [dim2, E, r] -> [E, r, dim2] -> [E*r, dim2] + smoe_A = ( + peft_B_em.reshape(dim2, num_experts, rank) + .permute(1, 2, 0) + .contiguous() + .reshape(rank * num_experts, dim2) ) - smoe_B = torch.zeros( - N, rank * num_experts, device=peft_A.device, dtype=peft_A.dtype + + # smoe_B: per expert, transpose A_e [r, dim1] -> [dim1, r] + # [E*r, dim1] -> [E, r, dim1] -> [dim1, E, r] -> [dim1, E*r] + smoe_B = ( + peft_A.reshape(num_experts, rank, dim1) + .permute(2, 0, 1) + .contiguous() + .reshape(dim1, num_experts * rank) ) - for e in range(num_experts): - s = e * rank - A_e = peft_A[s : s + rank, :] # [r, dim1] - B_e = peft_B_expert_major[:, s : s + rank] # [dim2, r] - smoe_A[s : s + rank, :] = B_e.T # [r, dim2=K] - smoe_B[:, s : s + rank] = A_e.T # [dim1=N, r] + return smoe_A, smoe_B @@ -144,6 +152,26 @@ def _unwrap_gate_lora(gate_module): return gate_module, gate_module.weight, None +def _get_cached_smoe_lora(wrapper, lora_A, lora_B, num_experts, rank, scaling): + """Get scattermoe-layout LoRA weights, using a per-wrapper cache. + + The conversion ``peft_lora_to_scattermoe`` allocates new tensors (up to + ~100 MB for Qwen3-30B-A3B). Under gradient checkpointing, each layer's + forward is replayed during backward, doubling the allocation cost. 
+ Caching avoids this: we store the converted weights on the wrapper module + and invalidate when the optimizer updates the parameters (detected via + ``_version`` counters that PyTorch increments on every in-place op). + """ + version = lora_A._version + lora_B._version + cache = getattr(wrapper, "_smoe_lora_cache", None) + if cache is not None and cache[0] == version: + smoe_A, smoe_B = cache[1], cache[2] + else: + smoe_A, smoe_B = peft_lora_to_scattermoe(lora_A, lora_B, num_experts, rank) + wrapper._smoe_lora_cache = (version, smoe_A, smoe_B) + return (smoe_A, smoe_B, scaling) + + def _unwrap_experts_lora(experts_module): """Walk a peft ``ParamWrapper`` chain on ``self.experts``. @@ -192,8 +220,9 @@ def _unwrap_experts_lora(experts_module): lora_A, lora_B, scaling = get_lora_params_from_wrapper(gup_wrapper) if lora_A is not None: rank = lora_A.shape[0] // num_experts - smoe_A, smoe_B = peft_lora_to_scattermoe(lora_A, lora_B, num_experts, rank) - gup_lora = (smoe_A, smoe_B, scaling) + gup_lora = _get_cached_smoe_lora( + gup_wrapper, lora_A, lora_B, num_experts, rank, scaling + ) # Extract down_proj LoRA (needs A<->B swap due to transposition) down_lora = None @@ -202,8 +231,9 @@ def _unwrap_experts_lora(experts_module): lora_A, lora_B, scaling = get_lora_params_from_wrapper(down_wrapper) if lora_A is not None: rank = lora_A.shape[0] // num_experts - smoe_A, smoe_B = peft_lora_to_scattermoe(lora_A, lora_B, num_experts, rank) - down_lora = (smoe_A, smoe_B, scaling) + down_lora = _get_cached_smoe_lora( + down_wrapper, lora_A, lora_B, num_experts, rank, scaling + ) return base_experts, gup_lora, down_lora diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.py index 794af0cc8a..f108db50dc 100644 --- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.py +++ 
b/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.py @@ -92,16 +92,21 @@ def forward( ctx.save_for_backward( x, - expert_weights, lora_A, lora_B, - expert_biases, sorted_expert_idxs, sorted_scattered_idxs, expert_offsets, gates, output_expanded, ) + # Store frozen weights as plain Python attributes instead of + # save_for_backward. This avoids: + # 1. Version-check conflicts with FSDP unshard/reshard + # 2. Pinning all-gathered parameters via saved_tensors hooks + # 3. Interfering with activation offloading pack/unpack hooks + # Safe because expert_weights are frozen (requires_grad=False). + ctx.expert_weights = expert_weights ctx.grouped_in = grouped_in ctx.grouped_out = grouped_out ctx.k = k @@ -116,16 +121,15 @@ def backward(ctx, grad_out: torch.Tensor): with torch.device(grad_out.device): ( x, - expert_weights, lora_A, lora_B, - expert_biases, sorted_expert_idxs, sorted_scattered_idxs, expert_offsets, gates, output_expanded, ) = ctx.saved_tensors + expert_weights = ctx.expert_weights k = ctx.k scaling = ctx.scaling @@ -292,13 +296,12 @@ def backward(ctx, grad_out: torch.Tensor): scaling, ) if grouped_in: - d_expanded_input = d_expanded_input + d_input_lora_grouped + d_expanded_input.add_(d_input_lora_grouped) else: - d_input_lora_ungrouped = torch.zeros_like(d_expanded_input) - d_input_lora_ungrouped[sorted_scattered_idxs] = ( - d_input_lora_grouped - ) - d_expanded_input = d_expanded_input + d_input_lora_ungrouped + # Scatter-add LoRA gradient directly into d_expanded_input. + # Avoids allocating a zeros_like + add result (~2× 256 MB + # for Qwen3-30B-A3B) that caused OOM at peak memory. 
+ d_expanded_input[sorted_scattered_idxs] += d_input_lora_grouped # Reduce over top-k if k > 1 if k == 1: From a1a5627869422235e94931e2ccf82a7a6c6c1bf7 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 12 Feb 2026 00:34:24 -0500 Subject: [PATCH 4/8] use sonicmoe optim options --- .../integrations/kernels/libs/scattermoe_lora/layers.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py index 3813ebb703..cf86c84753 100644 --- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py @@ -395,6 +395,8 @@ def forward(self: nn.Module, layer_input: torch.Tensor): scaling=gup_scaling, grouped_in=False, grouped_out=True, + use_fused_dX=True, + use_fused_gather=True, ) else: gup = parallel_linear( @@ -428,9 +430,11 @@ def forward(self: nn.Module, layer_input: torch.Tensor): lora_A=down_A, lora_B=down_B, scaling=down_scaling, + gates=routing_weights, grouped_in=True, grouped_out=False, - gates=routing_weights, + use_fused_dX=True, + use_fused_gather=True, ) else: expert_output = parallel_linear( From 8cc7cc6db0e576caa2fdd4bf42b69adbeee932be Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 12 Feb 2026 01:21:17 -0500 Subject: [PATCH 5/8] update save model from upstream --- src/axolotl/core/trainers/base.py | 24 ++------ .../libs/scattermoe_lora/kernels/lora_ops.py | 56 +++++++++++++------ .../kernels/libs/scattermoe_lora/layers.py | 28 ++-------- .../scattermoe_lora/parallel_linear_lora.py | 41 +++++++------- 4 files changed, 70 insertions(+), 79 deletions(-) diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 77e7b573b3..74ea0ec365 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -719,13 +719,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): 
output_dir = output_dir if output_dir is not None else self.args.output_dir os.makedirs(output_dir, exist_ok=True) LOG.info(f"Saving model checkpoint to {output_dir}") - if state_dict is None: - state_dict = self.accelerator.get_state_dict(self.model) - if state_dict is not None: - state_dict = { - k: v.clone() if isinstance(v, torch.Tensor) else v - for k, v in state_dict.items() - } + supported_classes = ( (PreTrainedModel,) if not is_peft_available() @@ -736,6 +730,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): if not isinstance(self.model, supported_classes): if state_dict is None: state_dict = self.model.state_dict() + if isinstance( self.accelerator.unwrap_model(self.model, keep_torch_compile=False), supported_classes, @@ -745,6 +740,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): ).save_pretrained( output_dir, state_dict=state_dict, + is_main_process=self.accelerator.is_main_process, ) else: LOG.info( @@ -756,11 +752,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): metadata={"format": "pt"}, ) else: - self.model.save_pretrained( - output_dir, - state_dict=state_dict, - is_main_process=self.accelerator.is_main_process, - ) + self.model.save_pretrained(output_dir, state_dict=state_dict) if self.processing_class is not None: self.processing_class.save_pretrained(output_dir) @@ -772,11 +764,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): LOG.info( "Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`" ) - save_jinja_files = True - if self.axolotl_cfg: - save_jinja_files = self.axolotl_cfg.tokenizer_save_jinja_files - self.data_collator.tokenizer.save_pretrained( - output_dir, save_jinja_files=save_jinja_files - ) + self.data_collator.tokenizer.save_pretrained(output_dir) + # Good practice: save your training arguments together with the trained model torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) diff --git 
a/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py index 352676e378..5d47c20408 100644 --- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py @@ -1425,7 +1425,7 @@ def grid(META): ) @triton.jit def _group_bwd_lora_fused( - # Inputs (ungrouped) + # Inputs (ungrouped or grouped) DY_ptr, stride_dym, stride_dyn, @@ -1468,11 +1468,22 @@ def _group_bwd_lora_fused( allow_tf32: tl.constexpr, NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr, + # Whether DY is already in grouped (expert-sorted) order + dy_grouped: tl.constexpr = False, ): """ Fused gather + LoRA gradient computation. Same as _group_bwd_lora but - reads X and DY from ungrouped buffers using sorted_scattered_idxs for - indirect indexing, eliminating the need for separate group() calls. + reads X from ungrouped buffers using sorted_scattered_idxs for indirect + indexing, eliminating the need for a separate group(X) call. + + When dy_grouped=False (default): both X and DY are read via indirect + indexing through sorted_scattered_idxs. This eliminates both group() + calls entirely. + + When dy_grouped=True: DY is already in grouped order (e.g. gate_up_proj + backward where grouped_out=True) and is read directly. Only X uses + indirect indexing. This avoids the group(X) allocation while + still supporting the grouped DY case. Grid: (E * cdiv(K, BLOCK_K), cdiv(N, BLOCK_N)) @@ -1483,12 +1494,6 @@ def _group_bwd_lora_fused( Supports token rounding: expert_offsets_ptr gives the iteration range (padded to BLOCK_M multiples), real_expert_offsets_ptr gives the real token count for M_mask (to exclude padding tokens). 
- - Key difference from _group_bwd_lora: - Instead of X_ptr[M_idx, :] and DY_ptr[M_idx, :] on pre-grouped data, - we load scatter_idx = sorted_scattered_idxs[M_idx], then: - X_token_idx = scatter_idx // FAN_OUT (X is [M, K], not expanded) - DY uses scatter_idx directly (DY is [M*k, N] or expanded via gate) """ pid0 = tl.program_id(axis=0) pid1 = tl.program_id(axis=1) @@ -1556,7 +1561,7 @@ def _group_bwd_lora_fused( M_local = i * BLOCK_M + M_block M_mask = M_local < real_num_tokens - # Fused gather: load scatter indices, then indirect-load X and DY + # Fused gather: load scatter indices for indirect X access scatter_idx = tl.load( sorted_scattered_idxs_ptr + M_idx, mask=M_mask, other=0 ).to(tl.int32) @@ -1570,12 +1575,17 @@ def _group_bwd_lora_fused( X_blk_ptrs, mask=M_mask[:, None] & K_mask[None, :], other=0.0 ).to(INPUT_DTYPE) - # Load DY via scatter index: DY is [M*k, N] - DY_blk_ptrs = ( - DY_ptr - + scatter_idx[:, None] * stride_dym - + N_block[None, :] * stride_dyn - ) + # Load DY: indirect via scatter_idx when ungrouped, direct via M_idx when grouped + if dy_grouped: + DY_blk_ptrs = ( + DY_ptr + M_idx[:, None] * stride_dym + N_block[None, :] * stride_dyn + ) + else: + DY_blk_ptrs = ( + DY_ptr + + scatter_idx[:, None] * stride_dym + + N_block[None, :] * stride_dyn + ) dy = tl.load( DY_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :], other=0.0 ).to(INPUT_DTYPE) @@ -1631,6 +1641,7 @@ def group_bwd_lora_fused( k: int, scaling: float, real_expert_offsets: Optional[torch.Tensor] = None, + dy_grouped: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: """ Fused gather + LoRA gradient computation. Same result as @@ -1638,8 +1649,13 @@ def group_bwd_lora_fused( the intermediate grouped buffers. Args: - DY: Gradient w.r.t. output [M*k, N] (ungrouped, original token order) - X: Input [M, K] (ungrouped, original token order) + DY: Gradient w.r.t. output [M*k, N]. 
+ If dy_grouped=False: ungrouped (original token order), read via + indirect indexing through sorted_scattered_idxs. + If dy_grouped=True: already in grouped (expert-sorted) order, + read directly. + X: Input [M, K] (ungrouped, original token order). Always read via + indirect indexing through sorted_scattered_idxs. lora_A: LoRA A weights [r*E, K] lora_B: LoRA B weights [N, r*E] expert_offsets: Cumulative token counts per expert [E] @@ -1651,6 +1667,9 @@ def group_bwd_lora_fused( scaling: LoRA scaling factor real_expert_offsets: Original cumulative counts for M_mask when using token rounding. If None, expert_offsets is used for both. + dy_grouped: Whether DY is already in grouped order (default False). + When True, avoids indirect indexing for DY, used for gate_up_proj + backward where grouped_out=True. Returns: dA: Gradient for A [r*E, K] @@ -1706,6 +1725,7 @@ def grid(META): scaling=scaling, ACC_TYPE=tl.float32, allow_tf32=ALLOW_TF32, + dy_grouped=dy_grouped, ) return dA, dB diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py index cf86c84753..cd2ad132a6 100644 --- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py @@ -152,23 +152,9 @@ def _unwrap_gate_lora(gate_module): return gate_module, gate_module.weight, None -def _get_cached_smoe_lora(wrapper, lora_A, lora_B, num_experts, rank, scaling): - """Get scattermoe-layout LoRA weights, using a per-wrapper cache. - - The conversion ``peft_lora_to_scattermoe`` allocates new tensors (up to - ~100 MB for Qwen3-30B-A3B). Under gradient checkpointing, each layer's - forward is replayed during backward, doubling the allocation cost. - Caching avoids this: we store the converted weights on the wrapper module - and invalidate when the optimizer updates the parameters (detected via - ``_version`` counters that PyTorch increments on every in-place op). 
- """ - version = lora_A._version + lora_B._version - cache = getattr(wrapper, "_smoe_lora_cache", None) - if cache is not None and cache[0] == version: - smoe_A, smoe_B = cache[1], cache[2] - else: - smoe_A, smoe_B = peft_lora_to_scattermoe(lora_A, lora_B, num_experts, rank) - wrapper._smoe_lora_cache = (version, smoe_A, smoe_B) +def _convert_smoe_lora(lora_A, lora_B, num_experts, rank, scaling): + """Convert peft LoRA weights to scattermoe layout.""" + smoe_A, smoe_B = peft_lora_to_scattermoe(lora_A, lora_B, num_experts, rank) return (smoe_A, smoe_B, scaling) @@ -220,9 +206,7 @@ def _unwrap_experts_lora(experts_module): lora_A, lora_B, scaling = get_lora_params_from_wrapper(gup_wrapper) if lora_A is not None: rank = lora_A.shape[0] // num_experts - gup_lora = _get_cached_smoe_lora( - gup_wrapper, lora_A, lora_B, num_experts, rank, scaling - ) + gup_lora = _convert_smoe_lora(lora_A, lora_B, num_experts, rank, scaling) # Extract down_proj LoRA (needs A<->B swap due to transposition) down_lora = None @@ -231,9 +215,7 @@ def _unwrap_experts_lora(experts_module): lora_A, lora_B, scaling = get_lora_params_from_wrapper(down_wrapper) if lora_A is not None: rank = lora_A.shape[0] // num_experts - down_lora = _get_cached_smoe_lora( - down_wrapper, lora_A, lora_B, num_experts, rank, scaling - ) + down_lora = _convert_smoe_lora(lora_A, lora_B, num_experts, rank, scaling) return base_experts, gup_lora, down_lora diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.py index f108db50dc..5d00e1230d 100644 --- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.py +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.py @@ -107,6 +107,7 @@ def forward( # 3. Interfering with activation offloading pack/unpack hooks # Safe because expert_weights are frozen (requires_grad=False). 
ctx.expert_weights = expert_weights + ctx.expert_biases = expert_biases ctx.grouped_in = grouped_in ctx.grouped_out = grouped_out ctx.k = k @@ -153,14 +154,16 @@ def backward(ctx, grad_out: torch.Tensor): gate_fan = 1 grouped_grad_out = None - # Fused gather is only possible when data isn't already grouped - # and there are no gate coefficients (which require a multiplicative gather). + # ------------------------------------------------------------------ + # LoRA gradients (dA, dB) and setup for dX + # ------------------------------------------------------------------ + # Fused gather uses sorted_scattered_idxs for indirect X access + # in the Triton kernel, avoiding the group(x) allocation. # - # Heuristic: fused gather eliminates group() calls but uses random - # access (via sorted_scattered_idxs) in the inner GEMM loop. At - # large problem sizes the GEMM dominates runtime and sequential - # access from group() is faster than the random scatter loads. - # Disable fused gather when the total workload exceeds a threshold. 
+ # can_fuse_gather: X is ungrouped and not too large for scatter loads + # - When gates is None and grouped_out=False: both DY and X ungrouped + # - When grouped_out=True (gate_up_proj): DY already grouped, X ungrouped + # -> use dy_grouped=True in the fused kernel M_total = sorted_scattered_idxs.size(0) K_dim = x.size(-1) N_dim = expert_weights.size(-1) @@ -169,18 +172,15 @@ def backward(ctx, grad_out: torch.Tensor): can_fuse_gather = ( ctx.use_fused_gather - and not grouped_in - and not grouped_out - and gates is None + and not grouped_in # X must be ungrouped for scatter access + and gates is None # gate coeff requires multiplicative gather and fuse_gather_workload < _FUSE_GATHER_THRESHOLD ) if can_fuse_gather: # ------------------------------------------------------------------ - # Fused path: skip BOTH group() calls entirely + # Fused path: skip group(x) entirely # ------------------------------------------------------------------ - # group_bwd_lora_fused reads ungrouped DY and X directly - # scatter2scatter_lora_dX (if used) reads ungrouped DY via scatter pattern d_expanded_input = None d_lora_A, d_lora_B = group_bwd_lora_fused( @@ -193,12 +193,14 @@ def backward(ctx, grad_out: torch.Tensor): E=E, k=k, scaling=scaling, + dy_grouped=grouped_out, ) - # When using fused gather, we need grouped_grad_out only if - # the dX path is NOT fused (original path needs it grouped). - # If fused dX is also enabled, it can read ungrouped DY directly. 
- if not ctx.use_fused_dX: + # Prepare grouped_grad_out for the dX path (needed by both + # the fused dX kernel when grouped_out=True, and the non-fused path) + if grouped_out: + grouped_grad_out = grad_out + elif not ctx.use_fused_dX: grouped_grad_out = base_ops.group( grad_out, sorted_scattered_idxs, @@ -242,7 +244,7 @@ def backward(ctx, grad_out: torch.Tensor): # Input gradient: dX = dY @ W^T + scaling * (dY @ B) @ A # ------------------------------------------------------------------ if ctx.use_fused_dX: - if can_fuse_gather: + if can_fuse_gather and not grouped_out: # Fully fused: read ungrouped DY via scatter pattern d_expanded_input = scatter2scatter_lora_dX( DY=grad_out, @@ -299,8 +301,7 @@ def backward(ctx, grad_out: torch.Tensor): d_expanded_input.add_(d_input_lora_grouped) else: # Scatter-add LoRA gradient directly into d_expanded_input. - # Avoids allocating a zeros_like + add result (~2× 256 MB - # for Qwen3-30B-A3B) that caused OOM at peak memory. + # Avoids allocating a zeros_like + add result d_expanded_input[sorted_scattered_idxs] += d_input_lora_grouped # Reduce over top-k if k > 1 From 417688c27aaa64ee6cd659bdf0e7498ba8ef727e Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 15 Feb 2026 10:25:22 -0500 Subject: [PATCH 6/8] fixes per code review feedback and add tests --- src/axolotl/integrations/kernels/args.py | 1 + .../libs/scattermoe_lora/kernels/ops.py | 12 +- .../libs/scattermoe_lora/kernels/single.py | 11 +- .../kernels/libs/scattermoe_lora/layers.py | 2 - .../kernels/libs/scattermoe_lora/lora_ops.py | 2 +- .../libs/scattermoe_lora/parallel_experts.py | 2 - src/axolotl/integrations/kernels/plugin.py | 1 - tests/integrations/test_scattermoe_lora.py | 323 ++++++++++++++++++ 8 files changed, 338 insertions(+), 16 deletions(-) create mode 100644 tests/integrations/test_scattermoe_lora.py diff --git a/src/axolotl/integrations/kernels/args.py b/src/axolotl/integrations/kernels/args.py index 78050ddc92..e8cf7208a0 100644 --- 
a/src/axolotl/integrations/kernels/args.py +++ b/src/axolotl/integrations/kernels/args.py @@ -42,6 +42,7 @@ def disable_mlp_kernel_scattermoe(cls, data): LOG.warning( "Disabling lora_mlp_kernel when using scattermoe due to compatibility issues." ) + data["lora_mlp_kernel"] = False data["mlp_kernel"] = False return data diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/ops.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/ops.py index 6850dc6b40..6aa432770d 100644 --- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/ops.py +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/ops.py @@ -223,9 +223,9 @@ def grid(META): if b is None: b = None - stride_be = stride_bk = 0 + stride_be = stride_bn = 0 else: - stride_be, stride_bk = b.stride() + stride_be, stride_bn = b.stride() _scatter2scatter[grid]( # X_ptr, stride_xm, stride_xk, @@ -241,10 +241,10 @@ def grid(META): output, output.stride(0), output.stride(1), - # B_ptr, stride_be, stride_bk + # B_ptr, stride_be, stride_bn b, stride_be, - stride_bk, + stride_bn, grouped_idx_ptr=sorted_scattered_idxs, expert_idxs_ptr=sorted_expert_idxs, # block_start_idx_ptr=padded_block_idxs, @@ -280,7 +280,7 @@ def group_bwd_W(DY, X, expert_offsets, E, has_bias=False): return DW, Db -@torch.library.custom_op("scattermoe::groupXtY", mutates_args={"DW"}) +@torch.library.custom_op("scattermoe::groupXtY", mutates_args={"DW", "Db"}) def groupXtY_compileable( E: int, DW: torch.Tensor, @@ -560,7 +560,7 @@ def group_compileable( K: int, N: int, Y: torch.Tensor, - coeff: torch.Tensor, + coeff: Optional[torch.Tensor], has_coeff: bool, fan_out: int, sorted_expert_idxs: torch.Tensor, diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/single.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/single.py index 20c0dcf183..9f0270aa67 100644 --- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/single.py +++ 
b/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/single.py @@ -53,15 +53,18 @@ def _single2scatter( + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn ) + N_mask = N_block < N acc = tl.zeros((1, BLOCK_N), dtype=ACC_TYPE) for _K_block_id in range(0, tl.cdiv(K, BLOCK_K)): - x = tl.load(X_blk_ptrs) - w = tl.load(W_blk_ptrs) + K_mask = K_block < K + x = tl.load(X_blk_ptrs, mask=K_mask[:, None], other=0.0) + w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :], other=0.0) acc += tl.sum(x * w, axis=0)[None, :] X_blk_ptrs += BLOCK_K * stride_xk W_blk_ptrs += BLOCK_K * stride_wk + K_block += BLOCK_K Y_blk_ptrs = Y_ptr + out_idx * stride_ym + N_block[None, :] * stride_yn - tl.store(Y_blk_ptrs, acc) + tl.store(Y_blk_ptrs, acc, mask=N_mask[None, :]) def single2scatter(X, W, expert_idxs): @@ -71,7 +74,7 @@ def single2scatter(X, W, expert_idxs): Y = torch.empty((k, ydim), device=X.device, dtype=X.dtype) BLOCK_N = 128 BLOCK_K = 128 - grid = ydim // BLOCK_N, k + grid = triton.cdiv(ydim, BLOCK_N), k _single2scatter[grid]( X, X.stride(0), diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py index cd2ad132a6..a425774833 100644 --- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py @@ -237,8 +237,6 @@ def forward(self, layer_input): Returns: Tensor: Output tensor. - Tensor: - Router logits. 
""" bsz, length, emb_size = layer_input.size() layer_input = layer_input.reshape(-1, emb_size) diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/lora_ops.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/lora_ops.py index e4b40660ab..aec68311b6 100644 --- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/lora_ops.py +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/lora_ops.py @@ -91,7 +91,7 @@ def forward( expert_offsets, lora_A=self._lora_A, lora_B=self._lora_B, - scaling=self._lora_scaling or 1.0, + scaling=self._lora_scaling if self._lora_scaling is not None else 1.0, expert_biases=self.bias, gates=gates, grouped_in=grouped_in, diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_experts.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_experts.py index 2aae050bb6..7a1eef472b 100644 --- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_experts.py +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_experts.py @@ -99,7 +99,6 @@ def backward(ctx, grad_out: torch.Tensor): k = ctx.k grouped_in = ctx.grouped_in grouped_out = ctx.grouped_out - # print("backward") if gates is not None: # calculate gates gradient @@ -158,7 +157,6 @@ def backward(ctx, grad_out: torch.Tensor): d_input = d_expanded_input.view( x.size(0), k, d_expanded_input.size(-1) ).sum(-2) - # print("backward end.") return ( # x, expert_weights, d_input, diff --git a/src/axolotl/integrations/kernels/plugin.py b/src/axolotl/integrations/kernels/plugin.py index 9bcf3930b3..56d0448d5f 100644 --- a/src/axolotl/integrations/kernels/plugin.py +++ b/src/axolotl/integrations/kernels/plugin.py @@ -1,7 +1,6 @@ from pathlib import Path from kernels import ( - # LayerRepository, LocalLayerRepository, Mode, register_kernel_mapping, diff --git a/tests/integrations/test_scattermoe_lora.py b/tests/integrations/test_scattermoe_lora.py new file mode 100644 index 0000000000..859119c819 --- /dev/null +++ 
b/tests/integrations/test_scattermoe_lora.py @@ -0,0 +1,323 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) Axolotl AI +# Licensed under the Apache License, Version 2.0 + +""" +Unit tests for scattermoe-lora code-review fixes. + +Tests cover: +- KernelsArgs validator: disable_mlp_kernel_scattermoe +- CPU_Offloaded_Gradient_Checkpointer: tuple vs plain tensor backward +- ParallelExperts: scaling=0.0 not treated as falsy +- single2scatter: non-aligned K/N dimensions +- group_compileable: coeff=None accepted +- HFScatterMoEGatedMLP / ScatterMoEGatedMLP: return value contract +""" + +from unittest.mock import patch + +import pytest +import torch + +# ============================================================================ +# 1. KernelsArgs: disable_mlp_kernel_scattermoe validator +# ============================================================================ + + +class TestKernelsArgsValidator: + """Test that disable_mlp_kernel_scattermoe sets both flags correctly. + + These tests call the validator classmethod directly on raw dicts, + since lora_mlp_kernel / mlp_kernel are not declared model fields. 
+ """ + + def test_disables_lora_mlp_kernel_when_scattermoe(self): + """lora_mlp_kernel=True gets set to False when use_scattermoe=True.""" + from axolotl.integrations.kernels.args import KernelsArgs + + data = { + "use_kernels": True, + "use_scattermoe": True, + "lora_mlp_kernel": True, + } + result = KernelsArgs.disable_mlp_kernel_scattermoe(data) + assert result["lora_mlp_kernel"] is False + assert result["mlp_kernel"] is False + + def test_mlp_kernel_disabled_without_lora(self): + """Even without lora_mlp_kernel, mlp_kernel should be disabled.""" + from axolotl.integrations.kernels.args import KernelsArgs + + data = { + "use_kernels": True, + "use_scattermoe": True, + } + result = KernelsArgs.disable_mlp_kernel_scattermoe(data) + assert result["mlp_kernel"] is False + # lora_mlp_kernel was not in data, should not be added + assert "lora_mlp_kernel" not in result + + def test_lora_mlp_kernel_false_unchanged(self): + """lora_mlp_kernel=False should stay False (no warning, no change).""" + from axolotl.integrations.kernels.args import KernelsArgs + + data = { + "use_kernels": True, + "use_scattermoe": True, + "lora_mlp_kernel": False, + } + result = KernelsArgs.disable_mlp_kernel_scattermoe(data) + assert result["lora_mlp_kernel"] is False + + def test_no_change_when_scattermoe_disabled(self): + """When use_scattermoe is not True, nothing should be changed.""" + from axolotl.integrations.kernels.args import KernelsArgs + + data = { + "use_kernels": True, + "use_scattermoe": False, + "lora_mlp_kernel": True, + } + result = KernelsArgs.disable_mlp_kernel_scattermoe(data) + assert result["lora_mlp_kernel"] is True + + +class TestParallelExpertsScaling: + """Test that scaling=0.0 is preserved and not overridden to 1.0.""" + + def test_scaling_zero_preserved(self): + """scaling=0.0 should be passed as 0.0, not replaced with 1.0.""" + pytest.importorskip("triton") + from axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops import ( + ParallelExperts, + ) + + pe = 
ParallelExperts(num_experts=2, input_size=4, output_size=4) + pe.set_lora( + lora_A=torch.randn(4, 4), + lora_B=torch.randn(4, 4), + scaling=0.0, + ) + assert pe._lora_scaling == 0.0 + + # Patch parallel_linear_lora to capture the scaling arg + with patch( + "axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops.parallel_linear_lora" + ) as mock_pll: + mock_pll.return_value = torch.randn(4, 4) + # Create dummy routing tensors + pe.forward( + inputs=torch.randn(2, 4), + k=1, + sorted_expert_idxs=torch.tensor([0, 0, 1, 1]), + sorted_scattered_idxs=torch.tensor([0, 1, 0, 1]), + expert_offsets=torch.tensor([2, 4]), + ) + # Check that scaling=0.0 was passed, not 1.0 + call_kwargs = mock_pll.call_args + assert ( + call_kwargs.kwargs.get("scaling") == 0.0 + or call_kwargs[1].get("scaling") == 0.0 + ), f"Expected scaling=0.0 but got {call_kwargs}" + + def test_scaling_none_defaults_to_one(self): + """scaling=None (no LoRA attached) should default to 1.0.""" + pytest.importorskip("triton") + from axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops import ( + ParallelExperts, + ) + + pe = ParallelExperts(num_experts=2, input_size=4, output_size=4) + # No set_lora called, so _lora_scaling is None + + with patch( + "axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops.parallel_linear_lora" + ) as mock_pll: + mock_pll.return_value = torch.randn(4, 4) + pe.forward( + inputs=torch.randn(2, 4), + k=1, + sorted_expert_idxs=torch.tensor([0, 0, 1, 1]), + sorted_scattered_idxs=torch.tensor([0, 1, 0, 1]), + expert_offsets=torch.tensor([2, 4]), + ) + call_kwargs = mock_pll.call_args + scaling_val = call_kwargs.kwargs.get("scaling") or call_kwargs[1].get( + "scaling" + ) + assert scaling_val == 1.0, ( + f"Expected scaling=1.0 for None but got {scaling_val}" + ) + + def test_scaling_positive_preserved(self): + """Normal positive scaling should be preserved.""" + pytest.importorskip("triton") + from axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops import ( + 
ParallelExperts, + ) + + pe = ParallelExperts(num_experts=2, input_size=4, output_size=4) + pe.set_lora( + lora_A=torch.randn(4, 4), + lora_B=torch.randn(4, 4), + scaling=0.5, + ) + + with patch( + "axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops.parallel_linear_lora" + ) as mock_pll: + mock_pll.return_value = torch.randn(4, 4) + pe.forward( + inputs=torch.randn(2, 4), + k=1, + sorted_expert_idxs=torch.tensor([0, 0, 1, 1]), + sorted_scattered_idxs=torch.tensor([0, 1, 0, 1]), + expert_offsets=torch.tensor([2, 4]), + ) + call_kwargs = mock_pll.call_args + scaling_val = call_kwargs.kwargs.get("scaling") or call_kwargs[1].get( + "scaling" + ) + assert scaling_val == 0.5 + + +# ============================================================================ +# 4. single2scatter: non-aligned K/N dimensions (GPU only) +# ============================================================================ + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +class TestSingle2ScatterBounds: + """Test single2scatter with non-aligned dimensions.""" + + def test_non_aligned_k(self): + """K not a multiple of BLOCK_K should produce correct results.""" + from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.single import ( + single2scatter, + ) + + E, K, N = 2, 100, 128 # K=100 not a multiple of 128 + W = torch.randn(E, K, N, device="cuda", dtype=torch.float32) + X = torch.randn(1, K, device="cuda", dtype=torch.float32) + expert_idxs = torch.tensor([[0, 1]], device="cuda", dtype=torch.long) + + Y = single2scatter(X, W, expert_idxs) + assert Y.shape == (2, N) + + # Verify against manual computation + Y_ref_0 = X[0] @ W[0] + Y_ref_1 = X[0] @ W[1] + torch.testing.assert_close(Y[0], Y_ref_0, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(Y[1], Y_ref_1, atol=1e-2, rtol=1e-2) + + def test_non_aligned_n(self): + """N not a multiple of BLOCK_N should produce correct results.""" + from 
axolotl.integrations.kernels.libs.scattermoe_lora.kernels.single import ( + single2scatter, + ) + + E, K, N = 2, 128, 100 # N=100 not a multiple of 128 + W = torch.randn(E, K, N, device="cuda", dtype=torch.float32) + X = torch.randn(1, K, device="cuda", dtype=torch.float32) + expert_idxs = torch.tensor([[0, 1]], device="cuda", dtype=torch.long) + + Y = single2scatter(X, W, expert_idxs) + assert Y.shape == (2, N) + + Y_ref_0 = X[0] @ W[0] + Y_ref_1 = X[0] @ W[1] + torch.testing.assert_close(Y[0], Y_ref_0, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(Y[1], Y_ref_1, atol=1e-2, rtol=1e-2) + + def test_non_aligned_both(self): + """Both K and N not aligned should produce correct results.""" + from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.single import ( + single2scatter, + ) + + E, K, N = 2, 100, 100 # Neither aligned to 128 + W = torch.randn(E, K, N, device="cuda", dtype=torch.float32) + X = torch.randn(1, K, device="cuda", dtype=torch.float32) + expert_idxs = torch.tensor([[0, 1]], device="cuda", dtype=torch.long) + + Y = single2scatter(X, W, expert_idxs) + assert Y.shape == (2, N) + + Y_ref_0 = X[0] @ W[0] + Y_ref_1 = X[0] @ W[1] + torch.testing.assert_close(Y[0], Y_ref_0, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(Y[1], Y_ref_1, atol=1e-2, rtol=1e-2) + + +# ============================================================================ +# 5. 
group_compileable: coeff=None accepted +# ============================================================================ + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +class TestGroupCoeffNone: + """Test that group() works with coeff=None.""" + + def test_group_with_none_coeff(self): + """group() should accept coeff=None without errors.""" + from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import group + + M, K = 4, 32 + A = torch.randn(M, K, device="cuda", dtype=torch.float32) + sorted_expert_idxs = torch.tensor([0, 1, 2, 3], device="cuda", dtype=torch.long) + + # This should not raise a TypeError + Y = group(A, sorted_expert_idxs, coeff=None, fan_out=1) + assert Y.shape == (M, K) + + def test_group_with_coeff(self): + """group() should also work with actual coeff values.""" + from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import group + + M, K = 4, 32 + A = torch.randn(M, K, device="cuda", dtype=torch.float32) + sorted_expert_idxs = torch.tensor([0, 1, 2, 3], device="cuda", dtype=torch.long) + coeff = torch.ones(M, device="cuda", dtype=torch.float32) * 0.5 + + Y = group(A, sorted_expert_idxs, coeff=coeff, fan_out=1) + assert Y.shape == (M, K) + + +# ============================================================================ +# 6. 
Layer return value contracts +# ============================================================================ + + +class TestLayerReturnValues: + """Test that layer forward methods return the correct types.""" + + def test_hf_scatter_moe_returns_single_tensor(self): + """HFScatterMoEGatedMLP.forward should return a single tensor, not a tuple.""" + pytest.importorskip("triton") + # Verify the forward method signature and return annotation + import inspect + + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + HFScatterMoEGatedMLP, + ) + + sig = inspect.signature(HFScatterMoEGatedMLP.forward) + # It's a staticmethod taking (self, layer_input) + params = list(sig.parameters.keys()) + assert "self" in params + assert "layer_input" in params + + def test_scatter_moe_gated_mlp_docstring_no_router_logits(self): + """ScatterMoEGatedMLP.forward docstring should not mention router logits as return.""" + pytest.importorskip("triton") + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + ScatterMoEGatedMLP, + ) + + docstring = ScatterMoEGatedMLP.forward.__doc__ + assert docstring is not None + # The docstring should mention output tensor but NOT router logits + assert "Output tensor" in docstring or "output tensor" in docstring.lower() + assert "Router logits" not in docstring, ( + "Docstring should not mention 'Router logits' in Returns section" + ) From b1020e6809ca77438ae43225f91f188a274e41b4 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 20 Feb 2026 09:04:23 -0500 Subject: [PATCH 7/8] revert removal of CP fix --- src/axolotl/core/trainers/base.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 74ea0ec365..414abeb4d6 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -720,6 +720,15 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): os.makedirs(output_dir, exist_ok=True) 
LOG.info(f"Saving model checkpoint to {output_dir}") + # fix for Context Parallel save + if state_dict is None: + state_dict = self.accelerator.get_state_dict(self.model) + if state_dict is not None: + state_dict = { + k: v.clone() if isinstance(v, torch.Tensor) else v + for k, v in state_dict.items() + } + supported_classes = ( (PreTrainedModel,) if not is_peft_available() From a00b11e0d7235ecb7c21054cb4b4408c0681227a Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 24 Feb 2026 11:30:23 -0500 Subject: [PATCH 8/8] misc fixes --- src/axolotl/utils/data/lock.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/src/axolotl/utils/data/lock.py b/src/axolotl/utils/data/lock.py index afd1547af6..9699f60e80 100644 --- a/src/axolotl/utils/data/lock.py +++ b/src/axolotl/utils/data/lock.py @@ -54,15 +54,19 @@ def _increment_counter(self): def cleanup(self): """Clean up ready flag when last process is done.""" - with FileLock(str(self.lock_file_path)): - counter_content = self.counter_path.read_text().strip() - count = int(counter_content) if counter_content else 0 - count -= 1 + try: + with FileLock(str(self.lock_file_path)): + counter_content = self.counter_path.read_text().strip() + count = int(counter_content) if counter_content else 0 + count -= 1 - if count <= 0: - # Last process cleans everything up - self.ready_flag_path.unlink(missing_ok=True) - self.counter_path.unlink(missing_ok=True) - else: - # Still have active processes - self.counter_path.write_text(str(count)) + if count <= 0: + # Last process cleans everything up + self.ready_flag_path.unlink(missing_ok=True) + self.counter_path.unlink(missing_ok=True) + else: + # Still have active processes + self.counter_path.write_text(str(count)) + except FileNotFoundError: + # Lock file might have already been deleted by another process + pass