diff --git a/atom/model_ops/attentions/deepseek_v4_attn.py b/atom/model_ops/attentions/deepseek_v4_attn.py
index 184bfafbe..80a668748 100644
--- a/atom/model_ops/attentions/deepseek_v4_attn.py
+++ b/atom/model_ops/attentions/deepseek_v4_attn.py
@@ -762,6 +762,21 @@ def _build_v4_indexer_meta(
             np.cumsum(n_committed_per_seq, dtype=np.int64),
         ]
     )
+    # Empty-batch guard: when no seq has committed K yet
+    # (`cu_committed_cpu[-1] == 0`, e.g. a fresh prefill whose prompt is
+    # shorter than the CSA `ratio`), `cp_gather_indexer_k_quant_cache`
+    # would launch with grid.x = 0 and fail with HIP "invalid
+    # configuration argument". Bump the last cumsum to 1 so the kernel
+    # sees a single dummy row to gather (charged to the last seq's
+    # first cache block). Downstream readers
+    # (`fp8_mqa_logits` + `top_k_per_row_prefill`) honor per-token
+    # `cu_starts`/`cu_ends` derived from `cu_committed_gpu[:-1]` and
+    # `n_committed_per_seq`, both of which remain 0, so the dummy row
+    # is never read and the output is `-1` sentinels everywhere,
+    # matching the all-empty semantics. This is pure host-side scalar
+    # arithmetic on a value already host-synced two lines up; no new
+    # CG/torch.compile graph branch is introduced.
+    cu_committed_cpu[-1] = max(int(cu_committed_cpu[-1]), 1)
     total_committed = int(cu_committed_cpu[-1])
 
     # FP8 write-side slot_mapping (independent of `total_committed` —
@@ -772,15 +787,6 @@
         csa_compress_plan_cpu, scheduled_bs, k_per_block, ratio
     )
 
-    # All-empty batch: forward_batched short-circuits on
-    # `total_committed == 0` and returns -1; the FP8 read side is unused.
-    if total_committed == 0:
-        return {
-            "total_committed": 0,
-            "cu_committed_gpu": None,
-            "compress_slot_mapping_gpu": compress_slot_mapping_gpu,
-        }
-
     # batch_id_per_token + n_committed_csa: reuse the shared GPU
     # tensors set in `_attach_v4_per_fwd_meta` (which MUST run before
     # this helper — see prepare_decode/prefill ordering). int64
diff --git a/atom/model_ops/fused_moe_triton.py b/atom/model_ops/fused_moe_triton.py
index 6fae30520..96aafcfe9 100644
--- a/atom/model_ops/fused_moe_triton.py
+++ b/atom/model_ops/fused_moe_triton.py
@@ -18,7 +18,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 import torch
+from contextlib import contextmanager
 from typing import Any
 import logging
 from math import prod
@@ -41,6 +43,45 @@
 )
 
 
+@contextmanager
+def _amd_smem_safe_tile():
+    """Cap matmul_ogs tile size on AMD CDNA4 to fit MI355X's 160 KiB LDS.
+
+    triton_kernels' AMD opt_flags has a special case,
+    `if cdna4 and block_m == 128: block_n = 512`, which makes BLOCK_M*BLOCK_N
+    = 64K FP32 entries, large enough that triton 3.6+ spills the
+    accumulator into LDS and overflows the 160 KiB budget (observed 269 KiB
+    on V4-Pro FP8 MoE). triton 3.5 happened to keep more of the accumulator
+    in registers and slipped under the limit, hence the version-dependent OOM.
+
+    Pin block_m/block_n via ATOM_TRITON_MOE_BLOCK_M/ATOM_TRITON_MOE_BLOCK_N
+    (defaults 64/256) so BLOCK_M*BLOCK_N stays at 16K. The default block_n
+    from compute_block_nk is already capped at 256 except for that single
+    cdna4 branch, so this only sidesteps the bad path on gfx950.
+    """
+    if get_gfx() != "gfx950" or not has_triton_kernels():
+        yield
+        return
+    try:
+        from triton_kernels.matmul_ogs_details.opt_flags import (
+            update_opt_flags_constraints,
+            reset_opt_flags_constraints,
+        )
+    except ImportError:
+        yield
+        return
+    # Defaults chosen so BLOCK_M*BLOCK_N stays ≤ 16384 entries (64 KiB FP32
+    # acc), comfortably fitting MI355X's register file. Override via env if
+    # a future compiler/kernel update relaxes the budget.
+    block_m = int(os.getenv("ATOM_TRITON_MOE_BLOCK_M", "64"))
+    block_n = int(os.getenv("ATOM_TRITON_MOE_BLOCK_N", "256"))
+    update_opt_flags_constraints({"block_m": block_m, "block_n": block_n})
+    try:
+        yield
+    finally:
+        reset_opt_flags_constraints()
+
+
 def _swizzle_mxfp4(quant_tensor, scale):
     """weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel"""
     assert has_triton_kernels()
@@ -241,16 +282,17 @@ def triton_kernel_fused_experts(
         dtype=hidden_states.dtype,
     )
 
-    matmul_ogs(
-        hidden_states,
-        w1,
-        w1_bias,
-        routing_data,
-        gather_indx=gather_indx,
-        precision_config=w13_precision_config,
-        gammas=gammas if apply_router_weight_on_input else None,
-        y=raw_intermediate,
-    )
+    with _amd_smem_safe_tile():
+        matmul_ogs(
+            hidden_states,
+            w1,
+            w1_bias,
+            routing_data,
+            gather_indx=gather_indx,
+            precision_config=w13_precision_config,
+            gammas=gammas if apply_router_weight_on_input else None,
+            y=raw_intermediate,
+        )
 
     # Standard SiLU/SwiGLU activation: silu(gate) * up
     # With optional swiglu_limit clamping (V4: limit=10.0)
@@ -262,16 +304,17 @@
         up = up.clamp(-swiglu_limit, swiglu_limit)
     intermediate_cache[0] = torch.nn.functional.silu(gate) * up
 
-    matmul_ogs(
-        intermediate_cache.view(M * topk, half_N),
-        w2,
-        w2_bias,
-        routing_data,
-        scatter_indx=scatter_indx,
-        precision_config=w2_precision_config,
-        gammas=None if apply_router_weight_on_input else gammas,
-        y=output_tensor,
-    )
+    with _amd_smem_safe_tile():
+        matmul_ogs(
+            intermediate_cache.view(M * topk, half_N),
+            w2,
+            w2_bias,
+            routing_data,
+            scatter_indx=scatter_indx,
+            precision_config=w2_precision_config,
+            gammas=None if apply_router_weight_on_input else gammas,
+            y=output_tensor,
+        )
 
     output_tensor = output_tensor.view(M, K)
     return output_tensor
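
Reviewer note: a standalone sketch of the empty-batch guard's semantics, for anyone checking the "dummy row is never read" claim. Illustrative only and assumes nothing beyond NumPy; `gather_rows` and `per_row_topk` are hypothetical stand-ins for `cp_gather_indexer_k_quant_cache` and the `fp8_mqa_logits`/`top_k_per_row_prefill` readers, not the real kernels.

import numpy as np

def build_cu_committed(n_committed_per_seq):
    cu = np.concatenate([
        np.zeros(1, dtype=np.int64),
        np.cumsum(n_committed_per_seq, dtype=np.int64),
    ])
    # The guard from the diff: never let the total row count (the gather
    # kernel's grid.x) reach 0; a HIP launch with grid.x == 0 fails.
    cu[-1] = max(int(cu[-1]), 1)
    return cu

def gather_rows(cache, total):
    # Hypothetical stand-in for the gather kernel: does `total` rows of work.
    return cache[:total]

def per_row_topk(n_committed_per_seq):
    # Readers honor the per-seq counts, which stay 0 for an empty batch,
    # so the dummy row is never consumed and -1 sentinels come out.
    return [list(range(n)) if n > 0 else [-1] for n in n_committed_per_seq]

n_committed = [0, 0]                      # fresh prefill, nothing committed
cu = build_cu_committed(n_committed)      # -> [0, 0, 1]: one dummy row
rows = gather_rows(np.zeros((8, 4)), int(cu[-1]))  # grid.x == 1, never 0
print(per_row_topk(n_committed))          # [[-1], [-1]]: all-empty semantics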
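And a back-of-envelope check of the tile arithmetic behind `_amd_smem_safe_tile`, assuming a fully spilled FP32 accumulator of BLOCK_M*BLOCK_N entries is the dominant LDS consumer (the observed 269 KiB plausibly stacks other buffers on top):

LDS_BUDGET = 160 * 1024                     # MI355X LDS, per the docstring
FP32 = 4

def acc_bytes(block_m, block_n):
    # FP32 accumulator: one entry per output element of the tile.
    return block_m * block_n * FP32

# The cdna4 special case (block_m=128 -> block_n=512): 64K entries.
assert acc_bytes(128, 512) == 256 * 1024    # alone already > 160 KiB
# The pinned tile (64, 256): 16K entries.
assert acc_bytes(64, 256) == 64 * 1024      # fits with room to spare
assert acc_bytes(64, 256) < LDS_BUDGET < acc_bytes(128, 512)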