24 changes: 15 additions & 9 deletions atom/model_ops/attentions/deepseek_v4_attn.py
@@ -762,6 +762,21 @@ def _build_v4_indexer_meta(
np.cumsum(n_committed_per_seq, dtype=np.int64),
]
)
# Empty-batch guard: when no seq has committed K yet
# (`cu_committed_cpu[-1] == 0`, e.g. fresh prefill with prompt
# shorter than the CSA `ratio`), `cp_gather_indexer_k_quant_cache`
# would launch with grid.x = 0 and fail with HIP "invalid
# configuration argument". Bump the last cumsum by one so the
# kernel sees a single dummy row to gather (charged to the last
# seq's first cache block). Downstream readers
# (`fp8_mqa_logits` + `top_k_per_row_prefill`) honor per-token
# `cu_starts`/`cu_ends` derived from `cu_committed_gpu[:-1]` and
# `n_committed_per_seq`, both of which remain 0 — so the dummy
# row is never read and the output is `-1` sentinels everywhere,
# matching the all-empty semantics. Pure host-side scalar
# arithmetic on a value already host-synced two lines up; no new
# CG/torch.compile graph branch is introduced.
cu_committed_cpu[-1] = max(int(cu_committed_cpu[-1]), 1)
total_committed = int(cu_committed_cpu[-1])

# FP8 write-side slot_mapping (independent of `total_committed` —
@@ -772,15 +787,6 @@ def _build_v4_indexer_meta(
csa_compress_plan_cpu, scheduled_bs, k_per_block, ratio
)

# All-empty batch: forward_batched short-circuits on
# `total_committed == 0` and returns -1; the FP8 read side is unused.
if total_committed == 0:
return {
"total_committed": 0,
"cu_committed_gpu": None,
"compress_slot_mapping_gpu": compress_slot_mapping_gpu,
}

# batch_id_per_token + n_committed_csa: reuse the shared GPU
# tensors set in `_attach_v4_per_fwd_meta` (which MUST run before
# this helper — see prepare_decode/prefill ordering). int64
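Reviewer note: a standalone sketch of the guard's arithmetic, using plain NumPy and made-up sequence counts. The `cu_starts`/`cu_ends` names follow the comment in the hunk above, not the real kernel signatures. The point is that bumping only the final cumsum entry gives the gather kernel a non-zero grid while every per-sequence range derived from the unbumped prefix stays empty.

```python
import numpy as np

# Three sequences, none with committed K yet (fresh prefill, prompt shorter than ratio).
n_committed_per_seq = np.array([0, 0, 0], dtype=np.int64)

cu_committed = np.concatenate(([0], np.cumsum(n_committed_per_seq)))
# cu_committed == [0, 0, 0, 0]; a kernel sized by cu_committed[-1] would launch with grid.x = 0.

# The guard: bump only the last entry so the gather kernel sees one dummy row.
cu_committed[-1] = max(int(cu_committed[-1]), 1)
total_committed = int(cu_committed[-1])     # 1, so grid.x >= 1

# Downstream readers build per-sequence ranges from the unbumped prefix and the
# per-seq counts, so every range is still empty and the dummy row is never read.
cu_starts = cu_committed[:-1]               # [0, 0, 0]
cu_ends = cu_starts + n_committed_per_seq   # [0, 0, 0]
assert (cu_ends - cu_starts).sum() == 0
```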
83 changes: 63 additions & 20 deletions atom/model_ops/fused_moe_triton.py
@@ -18,7 +18,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
from contextlib import contextmanager
from typing import Any
import logging
from math import prod
@@ -41,6 +43,45 @@
)


@contextmanager
def _amd_smem_safe_tile():
"""Cap matmul_ogs tile size on AMD CDNA4 to fit MI355X's 160 KiB LDS.

triton_kernels' AMD opt_flags has a special-case
`if cdna4 and block_m == 128: block_n = 512`, which makes BLOCK_M*BLOCK_N
= 64K FP32 entries — large enough that triton 3.6+/3.7+ spills the
accumulator into LDS and overflows the 160 KiB budget (observed 269 KiB
on V4-Pro FP8 MoE). triton 3.5 happened to keep more of the acc in
registers and slipped under the limit, hence the version-dependent OOM.

Pin block_m/block_n via ATOM_TRITON_MOE_BLOCK_M / ATOM_TRITON_MOE_BLOCK_N
(defaults 64/256) so BLOCK_M*BLOCK_N stays at 16K entries. Default block_n
in compute_block_nk is already capped at 256 except for that single cdna4
branch, so this only sidesteps the bad path on gfx950.
"""
if get_gfx() != "gfx950" or not has_triton_kernels():
yield
return
try:
from triton_kernels.matmul_ogs_details.opt_flags import (
update_opt_flags_constraints,
reset_opt_flags_constraints,
)
except ImportError:
yield
return
# Defaults chosen so BLOCK_M*BLOCK_N stays ≤ 16384 entries (64 KiB FP32
# acc), comfortably fitting MI355X's register file. Override via env if
# a future compiler/kernel update relaxes the budget.
block_m = int(os.getenv("ATOM_TRITON_MOE_BLOCK_M", "64"))
block_n = int(os.getenv("ATOM_TRITON_MOE_BLOCK_N", "256"))
update_opt_flags_constraints({"block_m": block_m, "block_n": block_n})
try:
yield
finally:
reset_opt_flags_constraints()
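
Reviewer note: a quick sanity check on the tile-footprint numbers quoted in the docstring, plus an illustrative override of the env knobs added here (the override values are examples, not recommended defaults).

```python
import os

def acc_bytes(block_m: int, block_n: int) -> int:
    """FP32 accumulator footprint of one BLOCK_M x BLOCK_N matmul tile."""
    return block_m * block_n * 4  # 4 bytes per FP32 entry

assert acc_bytes(128, 512) == 256 * 1024  # cdna4 special case: alone exceeds the 160 KiB LDS budget
assert acc_bytes(64, 256) == 64 * 1024    # pinned defaults: well under budget

# If a future kernel still overflows, shrink the tile further before model start-up:
os.environ["ATOM_TRITON_MOE_BLOCK_M"] = "32"
os.environ["ATOM_TRITON_MOE_BLOCK_N"] = "128"  # 16 KiB FP32 accumulator
```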


def _swizzle_mxfp4(quant_tensor, scale):
"""weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel"""
assert has_triton_kernels()
@@ -241,16 +282,17 @@ def triton_kernel_fused_experts(
dtype=hidden_states.dtype,
)

matmul_ogs(
hidden_states,
w1,
w1_bias,
routing_data,
gather_indx=gather_indx,
precision_config=w13_precision_config,
gammas=gammas if apply_router_weight_on_input else None,
y=raw_intermediate,
)
with _amd_smem_safe_tile():
matmul_ogs(
hidden_states,
w1,
w1_bias,
routing_data,
gather_indx=gather_indx,
precision_config=w13_precision_config,
gammas=gammas if apply_router_weight_on_input else None,
y=raw_intermediate,
)

# Standard SiLU/SwiGLU activation: silu(gate) * up
# With optional swiglu_limit clamping (V4: limit=10.0)
@@ -262,16 +304,17 @@
up = up.clamp(-swiglu_limit, swiglu_limit)
intermediate_cache[0] = torch.nn.functional.silu(gate) * up

matmul_ogs(
intermediate_cache.view(M * topk, half_N),
w2,
w2_bias,
routing_data,
scatter_indx=scatter_indx,
precision_config=w2_precision_config,
gammas=None if apply_router_weight_on_input else gammas,
y=output_tensor,
)
with _amd_smem_safe_tile():
matmul_ogs(
intermediate_cache.view(M * topk, half_N),
w2,
w2_bias,
routing_data,
scatter_indx=scatter_indx,
precision_config=w2_precision_config,
gammas=None if apply_router_weight_on_input else gammas,
y=output_tensor,
)

output_tensor = output_tensor.view(M, K)
return output_tensor
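
Reviewer note: the activation step between the two `matmul_ogs` calls, as a toy standalone reference. Only the `up` clamp is visible in this excerpt, so `gate` is left unclamped here; `swiglu_limit=10.0` is the V4 value quoted in the hunk comment.

```python
import torch

def swiglu_with_limit(gate: torch.Tensor, up: torch.Tensor,
                      swiglu_limit: float | None = 10.0) -> torch.Tensor:
    # Mirrors the visible lines: clamp up, then silu(gate) * up.
    if swiglu_limit is not None:
        up = up.clamp(-swiglu_limit, swiglu_limit)
    return torch.nn.functional.silu(gate) * up

gate, up = torch.randn(4, 8), torch.randn(4, 8)
print(swiglu_with_limit(gate, up).shape)  # torch.Size([4, 8])
```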