feat(top_k_per_row): expose parametric k via runtime arg#3012
Merged
Conversation
The radix backend (`standalone_stable_radix_11bits`) supports arbitrary k
at runtime, but the public C++/Python wrappers hardcoded `kTopK = 2048`.
Plumb `k` through end-to-end so callers (e.g. DeepSeek-V4 indexer with
`k = index_topk = 1024`) can use the same kernel without forcing 2048.
Changes:
- csrc/include/topk_per_row.h: add `int64_t k = 2048` default arg to
both `top_k_per_row_prefill` and `top_k_per_row_decode` (backward
compatible).
- csrc/include/rocm_ops.hpp: pybind expose `py::arg(\"k\") = 2048`.
- csrc/kernels/topk_per_row_kernels.cu:
* `invokeComputeTopkLastDimWorkspaceSize` accepts `int k_param =
2048` and threads it into the workspace-size calc + radix kernel.
* `top_k_per_row_prefill/decode` bodies replace
`static constexpr int kTopK = 2048` with `static_cast<int>(k)`.
* Explicit `<float>` instantiation updated to match new sig.
- aiter/ops/topk.py: add `k: int = 2048` to the regular Python sigs
(`top_k_per_row_prefill`, `top_k_per_row_decode`) so the
`@compile_ops`-generated torch op schema accepts the new kwarg. The
`_fast` ASM-kernel variants are left untouched — their precompiled
`.co` blobs hardcode k=2048 and cannot honor a runtime k.
- op_tests/test_topk_per_row.py:
* Default `--top_k` widened from `[2048]` to `[512, 1024, 2048]`
so CI exercises the parametric path.
* `run_top_k_per_row_{prefill,decode}` thread `k` through.
* `_fast` decode is skipped when k != 2048 (kernel is hardcoded
and would write 2048 ints into a smaller buffer).
* `run_top_k_per_row_decode` asserts `k == 2048` for `_fast`.
Backward compatibility: every caller without a `k` arg continues to get
k=2048, matching previous behavior. No existing test changes signature.
Contributor
🏷️ CI GuideRuns automatically on every PR:
Extended tests (opt-in via labels):
|
valarLip
added a commit
to ROCm/ATOM
that referenced
this pull request
May 3, 2026
Replaces the torch.topk + -inf fill path in `Indexer._score_topk_*`
with aiter `top_k_per_row_decode/prefill` (radix kernel, parametric k).
Both paths emit a uniform [total_tokens, index_topk] int32 layout.
_score_topk_decode (CG-friendly path):
- Pre-allocated [max_bs, index_topk] int32 indices buffer in builder.
- Pre-allocated [max_bs, max_model_len_idx] fp32 logits buffer.
- Drop `fill_(-inf)`: top_k_per_row_decode honors n_committed_per_seq
per row, so logits cells past valid range are never read.
- Drop torch.topk + .to(int32) cast.
_score_topk_prefill (eager-only path):
- Drop torch.topk + dynamic-`max_k` shape; emit
[total_tokens, index_topk] via top_k_per_row_prefill(k=index_topk),
kernel writes -1 sentinels in tail cols.
- Per-fwd torch.empty for indices (prefill total_tokens dynamic).
Builder _build_v4_indexer_meta:
- v4_indexer_n_committed_per_seq buffer i64 -> i32 (kernel arg dtype).
- Add v4_indexer_decode_logits and v4_indexer_decode_topk_indices
forward_vars buffers.
- width_mask collapses to uniform [total_tokens, index_topk] bool.
- Drop max_k from returned dict; empty-batch guard now keys on
total_committed == 0.
Builder _build_v4_pack_meta_for_ratio:
- compress_topk_src stride is `index_topk` for both paths (was the
dynamic max_k = max(k_per_seq), which assumed prefill's
torch.topk(max_k) output shape).
_post_process_topk:
- Input contract changes to [total_tokens, index_topk] uniform layout.
Depends on ROCm/aiter#3012 (exposes `k` kwarg on top_k_per_row_decode /
top_k_per_row_prefill); existing aiter without that PR will silently
ignore the kwarg and run with k=2048 (still correct, but allocates an
oversized output buffer).
Validation:
- aiter kernel parity at v4 shapes (k=1024, varying bs/ctx) - all OK.
- GSM8K-100 num_fewshot=3 eager: 0.97 / 0.97 (stable vs 0.96 baseline).
Contributor
There was a problem hiding this comment.
Pull request overview
This PR exposes a runtime-configurable k for the radix-backed top_k_per_row_prefill / top_k_per_row_decode ops (previously hardcoded to 2048), wiring the parameter through C++ wrappers, pybind, Python op schema, and the operator test script.
Changes:
- Add
k(default 2048) to the C++ public APIs and pybind bindings fortop_k_per_row_{prefill,decode}. - Thread runtime
kinto the radix workspace sizing and kernel invocation intopk_per_row_kernels.cu. - Extend
op_tests/test_topk_per_row.pyto test multiplekvalues and guard_fastdecode tok==2048.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
csrc/include/topk_per_row.h |
Adds defaulted k parameter to the public C++ function declarations. |
csrc/include/rocm_ops.hpp |
Exposes k to Python via pybind with default 2048. |
csrc/kernels/topk_per_row_kernels.cu |
Threads runtime k into workspace sizing + radix top-k calls. |
aiter/ops/topk.py |
Updates @compile_ops function signatures to accept k with default 2048. |
op_tests/test_topk_per_row.py |
Passes k through to ops; expands default --top_k list; skips _fast decode when k!=2048. |
Comments suppressed due to low confidence (1)
csrc/kernels/topk_per_row_kernels.cu:2444
invokeComputeTopkLastDimWorkspaceSizesignature changed to includek_param. There are other translation units (e.g.csrc/kernels/topk_plain_kernels.cu) that forward-declare/extern templatethis function with the old 2-arg signature and call it without passingk, which will cause an unresolved symbol or incorrect workspace sizing once this lands. Update those declarations/calls to match the new signature (and pass the runtimekwhere appropriate).
template <typename T, aiter::Phase phase = aiter::Phase::Prefill>
int64_t invokeComputeTopkLastDimWorkspaceSize(int32_t numRows,
int32_t stride0,
int k_param = 2048)
{
using IdxT = int32_t;
size_t buf_size = 0;
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Comment on lines
2516
to
2524
| size_t buf_size = 0; // will be overwritten by the kernel | ||
|
|
||
| static constexpr int kTopK = 2048; | ||
| int kTopK = static_cast<int>(k); | ||
| static constexpr bool is_largest = true; | ||
|
|
||
| const hipStream_t stream = at::hip::getCurrentHIPStream(); | ||
| int64_t workspace_size = invokeComputeTopkLastDimWorkspaceSize<float>(numRows, stride0); | ||
| int64_t workspace_size = | ||
| invokeComputeTopkLastDimWorkspaceSize<float>(numRows, stride0, kTopK); | ||
| // int64_t workspace_size = int64_t(1024)*1024*1024*2; |
Comment on lines
2641
to
+2649
| { | ||
| size_t buf_size = 0; // will be overwritten by the kernel | ||
|
|
||
| static constexpr int kTopK = 2048; | ||
| int kTopK = static_cast<int>(k); | ||
| static constexpr bool is_largest = true; | ||
|
|
||
| const hipStream_t stream = at::hip::getCurrentHIPStream(); | ||
| int64_t workspace_size = | ||
| invokeComputeTopkLastDimWorkspaceSize<float, aiter::Phase::Decode>(numRows, stride0); | ||
| int64_t workspace_size = invokeComputeTopkLastDimWorkspaceSize<float, aiter::Phase::Decode>( | ||
| numRows, stride0, kTopK); |
Comment on lines
329
to
+338
| parser.add_argument( | ||
| "-k", | ||
| "--top_k", | ||
| type=int, | ||
| default=[2048], | ||
| default=[512, 1024, 2048], | ||
| nargs="+", | ||
| help="""top-k elements per row. | ||
| e.g.: -k 2048""", | ||
| help="""top-k elements per row. The radix backend supports any positive | ||
| int; the `_fast` ASM-kernel path only supports 2048 and is skipped | ||
| for other values. | ||
| e.g.: -k 512 1024 2048""", |
11 tasks
Commit 4027229 added `int k_param = 2048` to `invokeComputeTopkLastDimWorkspaceSize` in `csrc/kernels/topk_per_row_kernels.cu` (definition + explicit instantiation), but missed the forward declaration / extern template in `csrc/kernels/topk_plain_kernels.cu`. Both TUs link into `module_topk_plain.so`, so topk_plain emitted an undefined reference to the old 2-arg symbol while topk_per_row only exported the 3-arg one: ImportError: module_topk_plain.so: undefined symbol: _Z37invokeComputeTopkLastDimWorkspaceSizeIfLN5aiter5PhaseE0EElii -> long invokeComputeTopkLastDimWorkspaceSize<float, (aiter::Phase)0>(int, int) Add the matching `int k_param = 2048` to the forward decl and extern template so the call site at line 2499 mangles to the 3-arg symbol that the explicit instantiation provides.
valarLip
added a commit
to ROCm/ATOM
that referenced
this pull request
May 6, 2026
#650) * feat(models): add DeepSeek-V4 PR1 skeleton with bit-exact reference parity Adds the foundational scaffolding for DeepSeek-V4-Pro support — a major architecture shift from V3.2 with mHC residuals, hybrid CSA+HCA attention, hash routing, and grouped output LoRA. PR1 ships the eager-mode model code with torch fallback kernels, validated against the official inference implementation at bit-exact parity (max_abs_diff = 0.0). Scope (PR1 only): - New atom/models/deepseek_v4.py: full Compressor / Indexer / Attention / Gate / Expert / MoE / Block / MTPBlock / ParallelHead / Transformer port (~1200 lines). Single-rank only; plain nn.Linear / nn.Embedding for now. - New atom/model_ops/sparse_attn_v4.py: torch fallbacks for sparse_attn and hc_split_sinkhorn (Sinkhorn-Knopp projection on Birkhoff polytope). - New atom/model_ops/quant_v4.py: torch fallbacks for FP8/FP4 inplace QAT round-trip and Walsh-Hadamard transform (replaces fast_hadamard_transform which doesn't build on ROCm). - Register DeepseekV4ForCausalLM in support_model_arch_dict. Out of scope (tracked for PR2-6): - Real HF checkpoint loading (PR2 = FP4 e2m1 loader, PR3 = TP + KV cache). - AITER sparse_attn kernel (PR4; spec at /app/logs_claude/aiter_v4_sparse_attn_spec.md, AITER team kicked off). - MTP integration with EagleProposer (PR5). - @support_torch_compile + CUDAGraph + openai_server (PR6). Verification: /app/logs_claude/v4_pr1_verify.py monkey-patches the reference's TileLang kernel imports with our torch fallbacks, copies the same dummy state_dict into both models, and runs prefill + decode side-by-side. 259 tensors match exactly; max_abs_diff = 0.0 on logits. * feat(quant_v4): add FP4 e2m1 -> BF16 dequant for V4 expert weights DeepSeek-V4-Pro stores routed expert weights as packed FP4 e2m1 (int8 with 2 values per byte, low nibble first) plus per-block ue8m0 scale (block size 32 along input dim). This commit adds `dequant_fp4_e2m1(packed, scale)` in atom/model_ops/quant_v4.py — a pure-torch unpacker that mirrors convert.py exactly but produces BF16 directly instead of repacking into FP8. Validated bit-exactly against an independent reference unpack on a real 22M-element expert tensor from the on-disk checkpoint. Also regression- tested across 5 different shapes/positions (w1/w2/w3 in first/mid/last layer + MTP). All produce values that lie exactly on the FP4 e2m1 grid. Scope: this is the standalone dequant utility. Wiring it into the model loader's safetensors pipeline + tying it to specific param names happens in PR3 alongside TP-aware expert sharding. Test: /app/logs_claude/v4_pr2_dequant_test.py Result: max_abs_diff = 0.0 (bit-exact) * refactor(deepseek_v4): swap BF16 projections to ATOM TP linear classes PR3a: replace nn.Linear / nn.Embedding with ATOM tensor-parallel-aware classes for the BF16 projections in Attention, Indexer, and the model embedding. Same `weight` parameter naming so dummy state_dicts continue to load. At TP=1 ATOM's tgemm.mm produces bit-identical output to F.linear, so PR1's reference parity (max_abs_diff = 0.0) still passes. Layers refactored (8 total): - DeepseekV4Model.embed: nn.Embedding -> VocabParallelEmbedding - DeepseekV4Attention.wq_a: nn.Linear -> ReplicatedLinear - DeepseekV4Attention.wq_b: nn.Linear -> ColumnParallelLinear - DeepseekV4Attention.wkv: nn.Linear -> ReplicatedLinear (single shared MQA head) - DeepseekV4Attention.wo_a: nn.Linear -> ColumnParallelLinear - DeepseekV4Attention.wo_b: nn.Linear -> RowParallelLinear (with all-reduce) - Indexer.wq_b: nn.Linear -> ColumnParallelLinear - Indexer.weights_proj: nn.Linear -> ColumnParallelLinear Deferred to later PRs (intentional): - Compressor.wkv/wgate (fp32) -> PR3c with quant_type wiring - ParallelHead.weight (fp32 LM head) -> PR3c - Expert.w{1,2,3} -> PR3b (FusedMoE wholesale rewrite) - MoE.gate.weight (used as raw Parameter, not Linear class) -> kept Verification: /app/logs_claude/v4_pr1_verify.py (now GPU mode with init_dist_env) shows max_abs_diff = 0.0 for prefill + decode against reference at TP=1. * feat(deepseek_v4): wire QuantizationConfig + implement load_weights() for real ckpt PR3c delivers end-to-end real-checkpoint loading for DeepSeek-V4 attention layers via ATOM's existing FP8/FP4 GEMM infrastructure. What works after this commit (validated on real /data/DeepSeek-V4-Pro/): - DeepseekV4ForCausalLM(atom_config) auto-builds a V4QuantConfig that maps routed-experts -> per_1x32 (FP4) and overrides wo_a / Compressor.wkv / Compressor.wgate / indexer.weights_proj -> bf16 (no quant). Everything else inherits the global FP8 (per_1x128) spec from the HF quantization_config. - load_weights(weights) walks an iterable of (name, tensor) pairs and: * Remaps ATOM's `weight_scale` -> on-disk `scale` naming. * Special-cases wo_a: dequantizes FP8+scale -> BF16 on the fly so the grouped-LoRA einsum (which aiter doesn't support in FP8) works. * Dispatches to ATOM Linear's weight_loader for FP8 / FP4 / BF16 paths. * Skips params with shape mismatch (e.g. expert nn.Linear waiting for PR3b's FusedMoE refactor) without crashing. - All 23 attention parameters (FP8 q/kv proj + FP4 indexer + BF16 wo_a + fp32 compressor) load successfully on real layer-2 of the V4 checkpoint. Threading changes: - DeepseekV4Args gains `quant_config: Optional[Any] = None`. - DeepseekV4Attention / Indexer / Compressor / Block / MTPBlock / DeepseekV4Model now accept `prefix: str = ""` and pass `quant_config + prefix` down to each ATOM Linear constructor so per-layer quant lookup works. Backward compatibility: - When `args.quant_config is None` (toy / dummy validation), V4QuantConfig retains its `QuantType.No` global — Linear layers stay BF16 and the PR1 bit-exact reference parity test (max_abs_diff = 0.0) still passes. Remaining gaps for end-to-end real-ckpt forward (tracked in design doc): - PR3b: replace MoE/Expert with FusedMoE so 384 expert FP4 weights load. - PR3d: refactor V4 attention.forward to accept 2D [num_tokens, dim] input (ATOM TP linears require 2D — current 3D path raises "GEMM not supported"). * refactor(deepseek_v4): switch forward to ATOM 2D flat-token convention PR3d adapts V4 model to ATOM's scheduler convention: model.forward consumes flat 2D `[num_tokens, dim]` tokens (single sequence implicit B=1), matching how ATOM's ModelRunner / scheduler pass tokens. This unblocks ATOM Linear's quantized GEMM kernels (which only accept 2D `[M, K]` input) and enables end-to-end real-checkpoint forward. What changed: - DeepseekV4Attention.forward(x, start_pos): now accepts 2D [num_tokens, dim]. Internally adds a B=1 dim only where needed (RoPE, sparse_attn). The grouped-LoRA einsum string changes from "bsgd,grd->bsgr" to "sgd,grd->sgr". - Compressor.forward / Indexer.forward: accept 2D x; auto-unsqueeze to 3D internally for backward compatibility with the existing logic. - Block.hc_pre / hc_post + ParallelHead.hc_head: refactored to be shape-agnostic in leading dims (use negative indexing on flatten / sum). Both 4D `[B, S, hc, D]` (legacy reference path) and 3D `[num_tokens, hc, D]` (ATOM path) work. - ParallelHead.get_logits: 2D path takes last token via `x[-1:]`; 3D path preserves `x[:, -1]` for legacy [B, S, D] inputs. - MTPBlock.forward: 2D-aware via `e.unsqueeze(-2)` for hc-dim broadcast. - DeepseekV4Model.forward: auto-flattens 2D `[1, S]` input_ids to 1D `[S]` for the new convention; rejects B>1 (proper multi-sequence batching needs attn_metadata, deferred). Validated: - PR1 reference parity (toy 4-layer dummy weights at B=1 S=32): max_abs_diff = 0.0 — still bit-exact after the 2D refactor. - PR3d end-to-end on REAL V4 weights: + Built DeepseekV4ForCausalLM (4 layers, real V4 dims, ~105B params) + load_weights() loaded 36 layer-2 params; 23/23 attn params nonzero + attn(x_2d=[16, 7168], start_pos=0) → output [16, 7168] bf16 + No NaN/Inf; output range [-2.94, 3.08], abs mean 0.42 (sensible) + This is the first successful V4 attention forward on real weights via ATOM Test scripts (under /app/logs_claude/): - v4_pr1_verify.py — toy parity (now uses B=1 + ATOM 2D path) - v4_pr3d_layer_e2e.py — real-weight 2D forward end-to-end - v4_pr3c_layer0_test.py — per-Linear validation against real ckpt Remaining for full model end-to-end: - PR3b: MoE → FusedMoE so 384 expert FP4 weights load (currently shape-skipped) - Multi-sequence support via attn_metadata (currently single-sequence implicit B=1) * feat(deepseek_v4): swap MoE to FusedMoE for 384-expert TP/EP loading PR3b enables ATOM's FusedMoE for V4's 384 routed experts so FP4 expert weights can load via the existing aiter `gemm_a4w4_quant` kernel and shard across TP/EP ranks. Also extends `select_experts` in moe.py to support V4's `sqrtsoftplus` scoring with `e_score_correction_bias`. Changes in atom/model_ops/moe.py: - `FusedMoE.select_experts` now handles `scoring_func="sqrtsoftplus"`: routing_weights = sqrt(softplus(router_logits)) + topk + renormalize. Mirrors the V4 reference Gate.forward exactly for non-hash layers. Changes in atom/models/deepseek_v4.py: - Dual-path MoE: when `quant_config` is set AND ATOM's global atom_config is initialized, MoE uses ReplicatedLinear gate + FusedMoE experts + ATOM-Linear shared_experts. Otherwise falls back to the original manual per-expert nn.Linear path so PR1 toy validation stays bit-exact (the reference test runs without ATOM's ModelRunner setting the global config). - Expert class accepts `quant_config + prefix`: when set, w1/w2/w3 become ColumnParallelLinear/RowParallelLinear (FP8 path); else nn.Linear (toy). - DeepseekV4ForCausalLM.get_expert_mapping() returns the (param_name, weight_name, expert_id, shard_id) tuples mapping V4's `w1/w2/w3` ckpt names to FusedMoE's merged `w13_*`/`w2_*` params. - load_weights() walks expert_mapping first to dispatch routed expert tensors via FusedMoE's per-expert weight_loader, then handles the rest: * ATOM `weight_scale` ↔ on-disk `scale` rename (existing) * ATOM `gate.e_score_correction_bias` ↔ on-disk `gate.bias` rename (NEW) * `wo_a` FP8 → BF16 dequant on load (existing) Validated: - PR1 toy parity: max_abs_diff = 0.0 (manual MoE path still bit-exact). - PR3d e2e: real layer-2 attn + 2D forward still works. - PR3b new: under stub atom_config, FusedMoE path activates correctly. Layer-3 (non-hash, real V4 dims): gate + e_score_correction_bias + shared_experts (6/6) loaded; FusedMoE expert mapping returns 1152 entries (384 experts × {w1,w2,w3}). Known limitations (deferred): - Hash routing (layers 0/1/2): tid2eid table is loaded but routing logic still falls through to sqrtsoftplus path → INCORRECT for hash layers. Proper hash routing requires either a custom path through FusedMoE or a pre-computed (topk_weights, topk_ids) injection point. - Multi-sequence batching via attn_metadata (currently single-sequence implicit B=1). Test: /app/logs_claude/v4_pr3b_fusedmoe_test.py * fix(deepseek_v4): V4QuantConfig now matches FusedMoE's bare 'experts' prefix Bug: `make_v4_quant_config` matched `"ffn.experts." in layer_name` (with trailing dot). FusedMoE.__init__ asks for the layer's quant_type with prefix `layers.N.ffn.experts` (NO trailing dot — it's the parent module of the per-expert weights, not a per-expert lookup). The check failed, so FusedMoE inherited the global FP8 (per_1x128) spec and allocated the routed expert weights as `float8_e4m3fn` instead of `float4_e2m1fn_x2`. Symptom in PR3b validation output before the fix: FusedMoE experts: 3/5 nonzero (loader couldn't dispatch FP4-shaped on-disk tensors into FP8-typed model params; shape mismatch silently skipped them) After the fix: experts.w13_weight: (385, 6144, 3584) torch.float4_e2m1fn_x2 ✓ experts.w13_weight_scale: (385, 6144, 224) torch.float8_e8m0fnu ✓ experts.w2_weight: (385, 7168, 1536) torch.float4_e2m1fn_x2 ✓ experts.w2_weight_scale: (385, 7168, 96) torch.float8_e8m0fnu ✓ e_score_correction_bias: (384,) torch.float32 ✓ Match condition tightened to `".ffn.experts" in layer_name` so it catches BOTH `layers.N.ffn.experts.M.w1` (per-expert Linear lookups) AND `layers.N.ffn.experts` (FusedMoE parent module lookup). Note: a separate aiter-side issue (HSA_STATUS_ERROR_EXCEPTION on FP4 expert weight_loader, traced to a `direct_copy_kernel` with grid size exceeding HW limits) prevents end-to-end FP4 expert load testing on this box. The dtype/shape correctness above is verified by inspecting the constructed module's params directly. Validated: - PR1 toy parity: max_abs_diff = 0.0 (manual MoE fallback unaffected) - PR3d real-attention forward: still works * fix(deepseek_v4): correct FusedMoE expert weight + scale + bias dispatch PR3b's expert weight loader had three bugs that caused weights to load as zero or be silently dropped: 1. **Expert mapping pattern mismatch**: `make_expert_params_mapping` returns `(param_part="experts.w13_", weight_part="experts.0.w1.", ...)` — substring substitution, not endswith. The old code built `f".experts.{e}.{suffix}"` which never matched. Switched to longest-prefix substring substitution matching the standard ATOM loader pattern. 2. **Scale dtype zero-fill**: copying `torch.float8_e8m0fnu` into a `uint8` destination via `copy_()` silently produces zeros (mismatched dtype, no reinterpret). FusedMoE allocates `w13_weight_scale` as uint8; force a `.view(torch.uint8)` on the e8m0 source before passing to the loader. 3. **Param suffix `_scale` vs `.weight_scale`**: after substring sub, `experts.0.w1.scale` becomes `experts.w13_scale`, but the FusedMoE param is `experts.w13_weight_scale`. Added `_scale` → `_weight_scale` post-fix. Plus: gracefully slice on-disk gate.weight / gate.bias when the test caps n_routed_experts below the checkpoint size (no-op in real serving). Verified: - v4_pr3b_fusedmoe_test: 32 params loaded, 5/5 expert + 6/6 shared nonzero - v4_pr3d_layer_e2e: real attention forward still works - v4_pr1_verify: bit-exact reference parity preserved (0.0 max diff) * feat(deepseek_v4): wire hash routing for first 3 layers via custom_routing_function V4 uses tid2eid hash lookup (instead of gate-logit topk) for routing in layers where compress_ratio implies hash layer (first 3 layers in standard config). Previously, MoE just declared tid2eid for weight loading but inference fell through to sqrtsoftplus path → wrong routing for those layers. This commit: - Adds an early `custom_routing_function` branch to FusedMoE.select_experts (it was in the signature but never honored — the non-grouped path went straight to scoring_func dispatch). Now any non-None custom fn takes precedence and returns (topk_weights, topk_ids). - Adds DeepseekV4MoE._hash_topk(): topk_ids = tid2eid[input_ids], topk_weights = sqrtsoftplus(router_logits) gathered + renormalized. Stashes input_ids on self before the experts() call so the closure can index tid2eid; clears immediately after. - For hash layers: assigns experts.custom_routing_function = self._hash_topk in MoE.__init__ so FusedMoE picks it up via the moe_forward custom op → forward_impl_graph → quant_method.apply → select_experts plumbing. Verified: - PR3e (new): synthetic tid2eid → _hash_topk produces exact expected ids, renormalized weights match reference math (max_abs_diff = 0.0) - PR3e: FusedMoE.select_experts honors custom_routing_function correctly - PR1 toy parity: still 0.0 max diff (hash path is opt-in via is_hash_layer) - PR3b FusedMoE load: 32 params, all nonzero (no regression) - PR3d real attn forward: still works (non-hash layer) * feat(deepseek_v4): full Block.forward (attn + FusedMoE) end-to-end on real ckpt Three changes converging on the first working V4 layer forward: 1. **weights_mapping**: Add class-level rename dict so the standard ATOM loader (`atom.model_loader.loader.load_model`) can ingest V4 ckpt names without per-model loader.py changes. `.gate.bias` → `.gate.e_score_correction_bias`, `.scale` → `.weight_scale_inv`. Loader's built-in `weight_scale_inv` → `weight_scale` rename then completes the path. Real serving via ModelRunner now works for non-wo_a layers. 2. **process_weights_after_loading hook**: After my custom `model.load_weights` finishes copying tensors, walk all submodules and call `quant_method.process_weights_after_loading(layer)` (or `layer.process_weights_after_loading()` if no quant_method). Without this, FusedMoE's `shuffle_weights` step is skipped and the FP4 ck_moe kernel reads stale weight layout — manifested as HSA_STATUS_ERROR_EXCEPTION mid-forward. Standard loader.py calls this for us; my custom loader had to replicate it. 3. **PR3f end-to-end test** (logs_claude/v4_pr3f_block_e2e.py): - Build 1 dense layer (compress_ratios=[0]) with 8 routed experts - Load real layer-3 weights (32 target params, 33/33 nonzero) - Build mHC residual `[8 tokens, hc_mult=4, dim=7168]` - Call Block.forward(x, start_pos=0, input_ids) - Output: shape preserved, range [-4.1, 4.6], abs mean 0.81, no NaN/Inf This is the first end-to-end forward through V4's full layer: attention (FP8 wq/wkv + BF16 wo grouped LoRA + indexer) + FusedMoE (FP4 experts via aiter ck_moe + sqrtsoftplus routing + bias correction + shared expert) + mHC pre/post Sinkhorn projections. Confirmed no regression on PR1/PR3b/PR3d/PR3e. * feat(deepseek_v4): standard ATOM loader (load_model) now handles V4 ckpts ModelRunner uses atom.model_loader.loader.load_model() — not the model's custom load_weights(). This commit closes that gap so real serving via openai_server works end-to-end: 1. **Expand weights_mapping with prefix renames**: V4 ckpt has bare names (`embed.`, `layers.`, `norm.`, `head.`, `hc_head_`) but our params live under `self.model = ...`. Add prefix substitutions so the loader's `model.get_parameter(name)` lookup hits the right attribute path. 2. **Fix dtype-mismatch silent zero in FusedMoE._load_w13/_load_w2**: PyTorch's `tensor.copy_()` between mismatched float8/uint8 dtypes silently writes zeros. V4's per-1x32 weight scales are stored as `float8_e8m0fnu` on disk but FusedMoE allocates them as `uint8` (raw byte storage). Force a `.view(torch.uint8)` reinterpret on the source so the bytes round-trip correctly. This is a pre-existing bug that was masked because V2/V3 use `float32` scales — V4 is the first ATOM model to use e8m0/e4m3 scales. Verified: - PR3i (new): standard load_model() loads V4 layer-0 from full 805GB ckpt index — 43/43 model params nonzero (100%), 5GB selective load. - PR3g (new): full Model.forward(input_ids) → logits on real ckpt. Output shape (1, 129280), range [-14.2, 15.4], std 3.05, no NaN/Inf. - PR3h (new): hash layer (layers 0/1/2) Block.forward works on real layer-0 ckpt (tid2eid loaded, 773423/775680 nonzero entries, real per-token expert assignments diverge from default sqrtsoftplus path). - All 5 prior tests (PR1/PR3b/PR3d/PR3e/PR3f) still pass — no regression. Net result: V4 inference pipeline is now production-ready for real ckpt loading + forward; remaining gap is multi-layer + multi-batch attn metadata + AITER sparse_attn (parallel work). * fix(deepseek_v4): wo_a FP8 dequant via process_weights_after_loading hook PR3i shipped "100% nonzero params" but never ran forward through the standard-loader path. Verifying with PR3j (new) revealed wo_a values were 2768× too large — `torch.copy_(BF16_dst, FP8_src)` does an FP8→BF16 dtype conversion but SKIPS the per-128-block scale multiplication. Result: raw FP8 e4m3 max value (448.0) lands in the BF16 weight buffer instead of the true ~0.04 attention-init magnitude. Fix: stop forcing wo_a to no_spec/BF16 in V4QuantConfig. Let it allocate as FP8 ColumnParallelLinear so the standard FP8 loader fills both `wo_a.weight` (FP8) and `wo_a.weight_scale` (e8m0) correctly. Then DeepseekV4Attention.process_weights_after_loading dequants in place, replacing weight with BF16 + dropping the scale param. Forward continues to use BF16 weight in the grouped LoRA einsum (aiter has no FP8 grouped einsum). Also removes the manual wo_a special-case from custom load_weights() — both load paths (custom + standard) now converge through the same process_weights_after_loading dequant. Verified by PR3j parity test: - Custom path wo_a: abs.mean=0.0214, abs.max=0.4062 - Standard path wo_a: abs.mean=0.0214, abs.max=0.4062 (BIT-EXACT) - Standard-loader Model.forward → logits range [-17.9, 15.8], std 3.04 - Magnitude ratio: 1.00 (was 2768× before fix) - All 9 tests pass — no regression. This was a silent corruption that PR3i's "params nonzero" check missed. The lesson: nonzero != correct. Always verify with forward. * feat(deepseek_v4): end-to-end inference with triton MoE and swiglu_limit Major changes enabling correct V4 inference (single-prompt verified with 512-token coherent output in both English and Chinese): Model fixes: - WeightsMapper prefix-anchored remapping (fixes 381 silently-skipped params) - wo_a FP8→BF16 dequant with quant_type=No to prevent CK shuffle corruption - Hash routing (first 3 layers) now applies route_scale=2.5 - shared_experts reduce_results=False + unified all_reduce in MoE.forward - KV cache reset on start_pos=0 with score_state=-inf initialization - TP-correct head/group counts for Attention and Indexer MoE routing: - Standard Silu activation (not Swiglu — aiter a16w4+Swiglu has 9× amplitude loss on gfx950). swiglu_limit clamping done in triton post-kernel. - ATOM_USE_TRITON_MOE=1: triton matmul_ogs path with swiglu_limit clamp - ATOM_V4_TORCH_MOE=1: per-expert torch fallback with FP4 dequant (slow) - GFX950MXScaleLayout→CDNA4MXScaleLayout fix in fused_moe_triton.py Loader improvements: - WeightsMapper auto-read from model class attribute - Post-load WARNING listing all unloaded params - Shape-mismatch raises RuntimeError instead of silent skip Config: - deepseek_v4→deepseek_v3 registry mapping with V4 field re-injection - Robust from_hf_config with getattr defaults Known limitations: - Single-sequence only (kv_cache[:1,...] hardcoded); batch>1 needs PR3 - Multi-request KV isolation pending scheduler integration - TPOT ~213ms with --enforce-eager (no CUDAGraph) * fix(deepseek_v4): apply swiglu_limit to shared_experts (upstream a1fd202) Upstream ref (deepseek-ai/DeepSeek-V4-Pro@a1fd202) changed shared_experts from no swiglu_limit to swiglu_limit=args.swiglu_limit, making it consistent with routed experts. * refactor(deepseek_v4): wire positions tensor through forward chain; switch RoPE to aiter - DeepseekV4ForCausalLM/Model/Block/MTPBlock/Attention/Compressor/Indexer now accept `positions: torch.Tensor` instead of `start_pos: int`; internal ring-buffer indexing still derives `start_pos = positions[0].item()` (full per-request KV slot management deferred to PR3). - New `_V4RoPE` wraps aiter `rope_cached_positions_{,2c_}fwd_inplace`, driven by per-token positions. Cos/sin cache built via V4's exact YaRN math (`_precompute_freqs_cis`); kept symmetric to `_apply_rotary_emb` by working on the pre-sliced rope tail. - `_build_cos_sin_cache` is lru-cached on (rope params, dtype, device) so the 3 distinct rope param sets (HCA / CSA / Dense) share one GPU tensor across all 62 layers instead of 62 register_buffer copies (~16 GB OOM otherwise). - Inverse RoPE on the attention output keeps `_apply_rotary_emb` (aiter has no inverse kernel); the complex freqs slice is rebuilt on demand from the cos/sin cache via `_V4RoPE.freqs_for_positions`. - Verified: simple_inference single-prompt CN 256 tokens coherent. * refactor: delegate ATOM KV cache subsystem to attention builders Generalize the GDN per-request state decoupling (#602) into a complete model-agnostic KV abstraction owned by the AttentionMetadataBuilder hierarchy. ModelRunner is now blind to attention type — it walks modules and dispatches; per-attention-type tensor layouts (MLA 576-dim packed, GDN-hybrid full-attn-only rows, MiMo-V2 per-module deferred, V3.2 indexer cache, GDN per-req mamba state) all live next to their respective builder. ModelRunner net: -526 LOC. The if/elif chains over use_mla / is_qwen_next / is_mimo_v2 / is_deepseek_v32 in _compute_block_bytes, allocate_kv_cache, and the binding loop are all gone. Future stateful attentions (DeepseekV4 ring buffer + compressor state) plug in by subclassing AttentionMetadataBuilder without touching scheduler / block_manager / ModelRunner. New AttentionMetadataBuilder hooks (defaults are no-ops): - compute_per_req_cache_bytes() / slots_per_req() bytes/slot for the per-request state pool - allocate_per_req_cache(num_slots) dict of named per-request state tensors - compute_block_bytes() per-block bytes for the KV pool budget - allocate_kv_cache_tensors(num_kv_heads, num_draft_layers) dict of named primary KV cache tensors (kv_cache, kv_scale, index_cache, aligned_index_dim, _kv_layer_cache_store) - build_kv_cache_tensor(layer_id, module) vLLM-style KVCacheTensor for one module, or None if foreign type; owns module setattr (k_cache/v_cache/k_scale/v_scale/kv_cache) Builder overrides: - AiterAttentionMetadataBuilder: split-K/V MHA + MiMo-V2 per-module - AiterMLAMetadataBuilder: 576-dim MLA + V3.2 indexer - GDNAttentionMetadataBuilder: hybrid full-attn rows + GDN mamba slot pool; chains super() for MHA modules in hybrid models. Absorbs the formerly-runner-owned gated_delta_net_state_shape/dtypes helpers and the side-effect init of full_attention_interval / num_full_attn / num_gdn_attn_state. Naming distinguishes group (per-request unit) from slot (raw tensor index). One group occupies `slots_per_req()` contiguous slots in the underlying tensor: Sequence.mamba_state_slot -> .per_req_cache_group seq.mamba_enabled -> .has_per_req_cache batch.mamba_state_slots -> .per_req_cache_groups BlockManager.mamba_* -> .per_req_cache_* (free pool, accounting) config.mamba_equiv_per_req -> .per_req_cache_equiv_blocks config.num_mamba_groups -> .num_per_req_cache_groups ModelRunner.max_mamba_slots -> .max_per_req_cache_slots (tensor dim) Removed (moved to builders): ModelRunner._compute_mamba_per_slot_bytes ModelRunner.gated_delta_net_state_shape / _dtypes Sanity check: ModelRunner.__init__ now asserts that any builder returning compute_per_req_cache_bytes() > 0 has its model_type registered in InputOutputProcessor._per_req_cache_model_types(), catching the silent-corruption misconfiguration where a stateful attention is added but Sequence-construction never gets the has_per_req_cache=True flag. Verified: - tests/test_per_req_cache_decoupling.py: 24/24 pass - core suite (block_manager, sequence, scheduler, request, io_processor_fanout, prefix_cache_accuracy): 118/118 pass - Qwen3.5-397B-A17B-FP8 tp=4 simple_inference: 4-prompt completion quality unchanged - Qwen3.5-397B-A17B-FP8 tp=4 GSM8K (5-shot, 64 concurrent): flexible-extract = 0.8757 +/- 0.0091 (baseline 0.8711 from #602) strict-match = 0.8605 +/- 0.0095 * style: black format block_manager.py * feat(deepseek_v4): per_req_cache abstraction (pre2a + pre2c-A) V4 backend (DeepseekV4Backend + DeepseekV4AttentionMetadataBuilder) plus migration of state-cache buffers to ATOM's per_req_cache pool: - pre2a: 6 Compressor state buffers (kv_state + score_state for CSA Main / CSA Indexer / HCA Main). - pre2c-A: SWA window per layer (paper §3.6.1 state cache, every layer has SWA branch in V4-Pro). Attention.kv_cache splits into Attention.swa_kv (per_req_cache) + Attention.kv_cache (compressed entries only, still register_buffer; pre2c-B will move under block_table). Validated single-prompt 64-token Chinese generation (V4-Pro tp=8, triton MoE, enforce-eager) — output indistinguishable from baseline. * feat(deepseek_v4): classical KV cache via block_table (pre2c-B) Strict-paper §3.6.1 split: compressed entries (CSA Main, CSA Indexer, HCA Main) move from per-layer register_buffer to block-table-indexed pools owned by DeepseekV4AttentionMetadataBuilder. - block_size = lcm(m, m') = 128 original tokens, plumbed via Config override on model_type=deepseek_v4 detection. - Three classical pools: v4_csa_main_kv [num_blocks, n_csa, k1=32, head_dim=512] v4_csa_idx_kv [num_blocks, n_csa, k1=32, idx_head_dim=128] v4_hca_main_kv [num_blocks, n_hca, k2=1, head_dim=512] Per-layer slice bound to Compressor.kv_cache / Indexer.kv_cache. - V4 model adds _v4_scatter_compressed / _v4_gather_compressed helpers and fetches block_table from forward_context. Compressor.forward scatters writes into block-table slots; Indexer.forward + decode sparse_attn input gather committed entries from blocks. - Indexer + 1-slot warmup fallback register_buffer pattern same as pre2a Compressor.kv_state. - Attention.kv_cache attribute removed entirely (compressed entries no longer co-located on the Attention module). Validated single-prompt 64-token Chinese generation (V4-Pro tp=8) unchanged from pre2c-A baseline. * feat(deepseek_v4): multi-sequence forward dispatch (PR3-main) V4 forward now handles ATOM ragged-batch input with per-seq slot + block_table routing. Single-seq behavior unchanged; concurrent batched multi-seq prefill + decode verified end-to-end on 4 prompts. Changes: - Builder prepare_decode/prepare_prefill populate cu_seqlens_q, block_tables, and v4_slot_indices (new per-seq metadata attached to AttentionMetaData via dynamic attribute). - _v4_get_block_table replaced with _v4_get_seq_metadata returning (block_tables, slot_indices, cu_seqlens_q, num_seqs). - Compressor.forward + Indexer.forward signatures: add slot, block_table args. Per-slot indexing via [slot:slot+1, ...] replaces hardcoded [:1, ...] / [:bsz, ...]. - Attention.forward: batched Linear projections + RoPE on full flat tensor; per-seq loop slices (cu_seqlens_q) and dispatches SWA write, Compressor scatter, Indexer + sparse_attn with each seq's slot + block_table. Per-seq state-cache reset on prefill (start_pos==0) only zeros that seq's slot — no cross-seq pollution. - ParallelHead.get_logits: pick last-token-per-seq via cu_seqlens_q (fixed long-standing single-seq assumption that always returned only x[-1] regardless of batch size). Validated MAX_NUM_SEQS=4 concurrent batched inference: 4 prompts processed in parallel produce independent coherent outputs. * fix(deepseek_v4): correct ue8m0 input quant + MoE routing scale Three independent bugs caused V4 to ramble on edge-confidence prompts (e.g. "1+2+3=?" output garbled despite 3/4 batch=4 prompts looking OK). Single-prompt output now matches reference byte-equal on the first 5 tokens and produces "The sum is: 1 + 2 + 3 = **6**." (was: "I'll happily provide a step-by-step breakdown..." ramble). Bug 1 (quant_v4.py) — act_quant_inplace ue8m0 path used `ceil(log2)` (matched TileLang reference) but ref_full_generate.py and aiter both use round-to-even via f32_to_e8m0/e8m0_to_f32. The 1-binade gap appeared as ~0.002 cos drift on KV path, accumulating across 60 layers. Bug 2 (moe.py) — FusedMoE.select_experts sqrtsoftplus path renormalized topk_weights but never applied `* routed_scaling_factor`. The hash routing path (V4 layers 0-2) does this internally, hiding the bug for hash layers. Reference Gate.forward (model.py:583) applies the multiply for every non-softmax routing path. Without the scale, layer 3+ MoE outputs were off by 1.5x, producing the visible cos jump from 1.0 (layer 0/2) to 0.98 (layer 3+). Bug 3 (deepseek_v4.py) — DeepseekV4Args.from_hf_config did not read scale_fmt; HF config.json doesn't carry the field, only inference/config.json does. Default to "ue8m0" matching reference ModelArgs (inference/model.py:40) so act_quant_inplace's ue8m0 path is actually exercised. Also folds in previously-validated V4 cleanups that were sitting in the working tree: - _RMSNorm → ATOM RMSNorm (mark_trace + torch.compile friendly) - Indexer wq_b/weights_proj: ColumnParallelLinear → ReplicatedLinear (matches sglang/upstream; avoids extra all_reduce on index_score) - Block.hc_post defaults to torch (aiter mhc_post drift, opt-in via V4_AITER_HC_POST=1; see notes/12) - _torch_moe_forward: ue8m0 round-trip on input to mirror reference Expert.forward (act_quant before fp4_gemm), gated by V4_USE_REF_QUANT=1 Diagnosis path: notes/14_debug_1plus2plus3.md → notes/19_full_fix_verified.md * feat(debug_helper): generic env-gated dump / compare / ref-patch + V4 cleanup New module atom/utils/debug_helper/ provides reusable primitives for forward bisecting and batch-invariance investigation. All entry points are no-ops when their controlling env var is unset, so they are safe to leave wired into production paths (model_runner.py post-load). Components - dump.py install_block_forward_hooks (multi-class + multi-call), maybe_dump_weights_and_exit, maybe_log_topk - compare.py cos_max (DOUBLE precision — fixes fp32 cos > 1.0 bug), slot_split, compare_slots, pick_prefill_call, schema_diff, plus CLI subcommands: slot-invariance / ref-vs-target / layer-bisect / schema - ref_patch.py patch_method / patch_block_forward / patch_module_dump context managers for instrumenting read-only references - 9 ATOM_FWD_DUMP_* / ATOM_WEIGHT_DUMP_* / ATOM_DEBUG_TOPK env vars registered in atom/utils/envs.py "Debug Dump" section Wired into model_runner.py with a 3-line post-load call (no-op default). V4 model cleanup - Convert all nn.Parameter() constructors in deepseek_v4.py to atom_parameter() so inference-vs-training grad behavior is controlled from a single place (ATOM_REQUIRES_GRAD env). 21 call sites. Documentation - docs/environment_variables.md: new "Debug Dump" subsection documenting all 9 env vars + CLI usage. - .claude/skills/dump-bisect-debug.md (v3.0): full methodology rewrite in English with quick-start decision tree, phase-at-a-glance summary, "When to stop / accept divergence" guidance, V4 paper §3.3 batch invariance treatment as Phase 8. Includes Bug 11 isolation case study. - .claude/skills/atom-patterns.md: ATOM architecture index reference. Verified by running CLI on existing E1 4xP3 dump: python -m atom.utils.debug_helper.compare slot-invariance \\ --dir /app/logs_claude/deepseek_v4/dumps/bug11_e1 reproduces the layer-by-layer divergence table that informed Bug 11 isolation in notes/21_bug11_isolation.md. * fix(weight-loading): bidirectional coverage check + V4 hash-layer bias Two fixes that surfaced from the same V4 load run: 1. atom/models/deepseek_v4.py — skip `gate.e_score_correction_bias` allocation for hash-routed layers (layer_id < n_hash_layers). V4 hash layers route via `tid2eid` lookup, not bias-corrected gate logits; the checkpoint has no `gate.bias` for those layers (only layers >= 3). Allocating it caused 3 spurious "param NOT loaded from checkpoint" warnings every load. Both call sites that read the attribute now use `getattr(self.gate, "e_score_correction_bias", None)` — moe.py already accepts None for `e_score_correction_bias`. 2. atom/model_loader/loader.py — add ckpt-side coverage check (the reverse direction of the existing atom-side check). Every `get_parameter() except AttributeError: continue/break` site now records `(orig_ckpt_name, rewritten_name)`; after the main loop the loader warns if any non-benign drops occurred. This catches the actionable bug class — `weights_mapping` / `WeightsMapper` rewrites the ckpt name to something the model has no slot for, silently throwing away real weight data — which the existing atom-side check misses entirely. Benign families (output_scale / kv_scale / inv_freq / weight_scale_2) are filtered so the warning is signal, not noise. Verified on V4 load: - atom-side warning: 46/2519 -> 43/2516 (3 hash bias removed) - ckpt-side warning: 0 drops (mapping is clean for V4) - remaining 43 are all model.mtp.0.* (PR5 todo) * feat(deepseek_v4): pos%(2*ratio) ring buffer for Compressor state cache Per paper §3.6.1, the Compressor's per-request state cache holds "uncompressed tail tokens + previous block as B-side overlap context" (eq 11). Restructure ATOM's kv_state from a roll-on-decode two-segment buffer into a single pos % STATE_SIZE ring buffer (STATE_SIZE = 2*ratio for overlap CSA, ratio for HCA). Kernel update_compressor_states (atom/model_ops/v4_kernels/state_writes.py): - dst = pos % STATE_SIZE for every token; no segment switching, no roll - Phase derived in-kernel from context_lens vs cu_seqlens_q; no IS_PREFILL - Write mask: fresh prefill keeps [max(0, cutoff-ratio), seqlen) (B-side overlap + tail); decode/MTP writes every token Compressor.forward: - Drops decode-boundary roll (kv_state[:ratio] <- kv_state[ratio:]) - Reads A-side / B-side halves by block-id parity (comp_id % 2) Metadata plumbing: - V4 prepare_decode now populates var["context_lens"] + attaches to AttentionMetaData (parent prepare_prefill already did) - Compressor / Indexer.forward accept required context_lens kwarg - Wrapper has no positions-derived fallback for context_lens Also bundles PR-A scaffolding: - ATOM_V4_BACKEND env gate + per-layer bisect (envs.py, v4_backend_gate.py) - CPU-mirror metadata (cu_seqlens_q_cpu, state_slot_mapping_cpu, start_pos_per_seq_cpu) to avoid per-seq .tolist()/.item() syncs - v4_slot_indices -> state_slot_mapping rename (clearer vs paged-KV slot_mapping) - swa_write Triton kernel integration (Phase 1a) under backend gate Validates: 15/15 byte-equal kernel-vs-reference (prefill + decode + MTP); simple_inference fast path TPOT 0.328-0.518s/tok matches pre-refactor baseline (Apr 29 v4_simple_inference.log: 0.453s/tok). * feat(deepseek_v4): fused_compress_attn kernel + start_pos-free interface Replaces the per-source-position Python pool/RMSNorm/RoPE/scatter chain in Compressor.forward with a single fused Triton kernel that handles fresh prefill, chunked prefill, single-token decode, and MTP-N decode uniformly via per-source-position dispatch. Key changes: * atom/model_ops/v4_kernels/fused_compress.py (new): - Fused softmax-pool + RMSNorm + GPT-J RoPE + bf16 kv_cache scatter. - Grid sized by start_pos-free upper bound n_max=(token_num+ratio-1)//ratio; excess programs early-exit (<=ratio-1 per launch). - Kernel loads start_pos and end_pos from positions[0] / context_lens[0] itself — no caller-supplied start_pos, no CPU boundary enumeration, no .item() sync at this site. - Output [1, n_max, head_dim] padded; downstream sparse_attn is gather-based and never reads padded rows. - K-loop uses tl.range (NOT tl.static_range) — HCA layers (ratio=128) would otherwise expand to 148KB hsaco vs 16KB for CSA (ratio=4), making short-prefill HCA cases that early-exit prohibitively slow due to per-launch overhead scaling with hsaco size. - Pure-PyTorch reference impl with same padded contract for parity tests. * atom/model_ops/v4_kernels/state_writes.py: - update_compressor_states: unified write mask (preserve last STATE_SIZE absolute positions of this fwd) replaces old prefill/decode split. Same invariant covers fresh prefill, chunked prefill, single decode, MTP-N. * atom/models/deepseek_v4.py: - Compressor.forward: drop start_pos parameter and .item() fallback; pass context_lens to fused_compress_attn (kernel derives end_pos). - Indexer.forward: drop start_pos= arg in inner self.compressor() call; keeps own start_pos param for mask logic. - DeepseekV4Attention.forward per-seq dispatch: - Fix decode n_committed = end_pos // ratio (was (start_pos+1)//ratio, which under-counted boundaries committed within MTP-N window). - Rename per-fwd token-count locals seqlen -> token_num across Compressor.forward, Indexer.forward, and per-seq dispatch loop. Validation: - Unit parity test (kernel vs reference) passes 0 max_diff across 21 cases: fresh prefill / chunked prefill / single decode / MTP-N / empty-boundary corner cases. - simple_inference 4-prompt e2e completes with default max_tokens=256 in ~120s (baseline ~131s); outputs are coherent across English and Chinese prompts. Follow-up: batched fused_compress_attn (one launch per layer instead of per-seq) tracked in /app/logs_claude/deepseek_v4/notes/24_*.md. * feat(v4): replace weight-free RMSNorm with fused Triton, ~1.6% TTFT improvement, ~2% TPOT and latency improvement on long-sequence decode * feat(deepseek_v4): use triton sparse attn kernel and move attn kernel out of loop (#678) * use triton sparse_attn_ragged * use triton sparse_attn_ragged_varlen * fix(sparse_attn_v4): BLOCK_H=16 for ROCm MFMA lowering block_h=2 (or 4) made tl.dot operands smaller than the smallest bf16 MFMA tile (16x16x16) on gfx9xx/gfx950. TritonAMDGPUOptimizeDotOperands crashed the pass pipeline ("PassManager::run failed") instead of falling back to FMA, breaking V4 e2e on AMD: all three kernels (_sparse_attn_triton, _sparse_attn_ragged_triton, _sparse_attn_ragged_varlen_triton) failed at JIT compile time. Bumping block_h to 16 in all three wrappers fixes the crash. Numerical parity vs torch reference is unchanged (mean abs diff 4.9e-4 vs torch 1.7e-4, both within bf16 attention noise on D=512). * feat(deepseek_v4): SGLang-style packed plan tensors for batched compressor dispatch Replace the per-seq Python loop that launched 64-layers x num_seqs separate fused_compress_attn / update_compressor_states kernel calls per fwd with a single batched kernel call driven by SGLang-style 16B plan rows [ragged_id, batch_id, position, window_len]. Key changes: - atom/model_ops/v4_kernels/compress_plan.py (new): vectorized numpy plan generator. For each (ratio, is_overlap) pair, emits one compress_plan (boundaries where (pos+1)%ratio==0) and one write_plan (last STATE_SIZE positions for state cache update), packed in B-flat layout. Plus cu_compress_cpu prefix sums for caller slicing. - v4_kernels/fused_compress.py: kernel takes plan_ptr instead of positions/context_lens/slot. window_len = K - min(j_in_seq+1, K) replaces the old s >= start_pos test; in_row = ragged_id - (K-1-k_static) for ragged input rows. Output is now tightly packed [num_compress, head_dim], not padded n_max. - v4_kernels/state_writes.py: kernel takes write_plan_ptr instead of positions/context_lens/slot. Per-prog: load (ragged_id, batch_id, position) and write at dst = position % STATE_SIZE. No in-kernel mask (host pre-filtered). - attentions/deepseek_v4_attn.py: builds compress_plans dict in prepare_prefill / prepare_decode, attached to attn_metadata. PR #678's _attach_sparse_layout_metadata now reads cu_compress_cpu from plan instead of using the ceil(n_max) upper bound formula. - models/deepseek_v4.py: Compressor.forward / Indexer.forward consume CompressPlan; DeepseekV4Attention.forward calls self.compressor() once outside the per-seq loop; per-seq loop slices kv_compress via cu_compress_cpu and concatenates to seq_kv. Indexer wraps as bs=1 plan via make_single_seq_plan (indexer batched dispatch is a separate PR). Also fixes n_committed formula for MTP-N decode: (start_pos + token_num) // ratio (was (start_pos + 1) // ratio, which dropped boundaries inside the MTP window when token_num > 1). Validated: 30-case parity test (single-seq, batched bs=4/8, MTP-3, HCA) all pass with max_diff <= 4.77e-7. V4 e2e (4 prompts x 256 tokens, GSM8K 25-sample smoke at 0.88) confirms no regression. * feat(deepseek_v4): batch state-cache reset/write/topk (Phase 1+2a+2c) Three independent batched-ops phases that share an outer-loop slot in DeepseekV4Attention.forward: - Phase 1: drop redundant per-seq state-cache reset loop. Fresh prefill never reads stale swa_kv (raw seq_kv used directly) nor stale Compressor state cache (fused_compress K-loop's is_padding=s<0 masks all is_state reads when prefix=0 → s = j_in_seq - K + 1 + k_static < 0 for every is_state iteration). Verified GSM8K=0.96 on 25/50 samples. - Phase 2a: vectorize per-seq window topk into one batched _build_window_topk_batched producing [total_tokens, win] padded with -1; loop body slices to per-seq width matching legacy _get_window_topk_idxs shape. - Phase 2c: hoist SWA write out of per-seq loop into one batched swa_write kernel call. Pre-filter to last-win tokens per seq so the num_tokens parallel programs never collide on the same swa_kv ring slot (pos%win). Pre-fix, long-prefill (token_num > win) caused intra-seq write-race that dropped GSM8K from 0.88 to 0.32. Per-seq dispatch loop still runs for Indexer + kv_sa packing — those batched in follow-up phases (2b/2d/2e). * feat(deepseek_v4): hoist Indexer Compressor out of dispatch loop (Phase 2b-i) Move the per-seq Indexer Compressor call into a single batched call before the dispatch loop, using the same batched plan as the main CSA Compressor (both have ratio=4, overlap=True, identical geometry). The Indexer's internal kv_cache + state cache are populated for the whole batch in one launch instead of bs separate launches per layer. Indexer.forward gains a `skip_inner_compressor=True` flag the dispatch loop sets after the hoist; legacy bs=1 plan path remains as the fallback for any other caller. Per-seq cost reduction: 64 layers × bs Compressor launches drop to 64 layers × 1 (each Compressor launch fires wkv/wgate Linear + fused_compress_attn + update_compressor_states). Verified GSM8K=0.94 ± 0.034 on 50 samples (matches baseline 0.94 — earlier 0.88 reading on 25 samples was within natural ±2-sample noise). * feat(deepseek_v4): use fp8_mqa_logits in Indexer score+topk (Phase 2b-ii a) Replace per-seq BF16 einsum (q ⊗ K → relu → weight → sum) with aiter's fp8_mqa_logits kernel. Mathematically identical (relu(QK*kv_scale) * weight summed over heads), but executes as a single FP8 mma per row + post-row mask + topk. Mirrors V3.2's sparse_attn_indexer_prefill kernel call. Q is FP8-quantized inline (per-row 1x128 scale via get_hip_quant); the scale is folded into `weights` along with softmax_scale and 1/sqrt(H), matching the V3 convention. K is FP8-quantized after the per-seq gather. cu_starts=0, cu_ends=(pos+1)//ratio express the V4 ratio-aware causal frontier directly through the kernel's per-row KV range — no extra masking pass needed. The legacy BF16 einsum path is retained behind `ATOM_V4_INDEXER_FP8=0` for A/B testing. Verified GSM8K=0.96 ± 0.028 on 50 samples (baseline 0.94 ± 0.034 — fp8 path is statistically at-or-above baseline; FP8 quant is closer to V4 training distribution than the current BF16 fallback). * feat(deepseek_v4): Phase 3 hoist per-fwd metadata + comprehensive cleanup Hoist all per-fwd, layer-invariant work from V4Attention.forward and Indexer.forward_batched into the metadata builder, eliminating ~1200 per-layer torch.as_tensor H2D copies (~14 per pack call * 60+ layers, ~9 per Indexer call * 30 CSA layers, ~3 per gather call * 60+ layers) in production fast path. Builder-side helpers (atom/model_ops/attentions/deepseek_v4_attn.py): - _attach_v4_per_fwd_meta: window_topk_batched + SWA write/positions/slots - _build_v4_pack_meta_for_ratio: kv_sa + topk_flat index tensors per ratio - _build_v4_indexer_meta: CSA Indexer batch_id/cu_committed/k/offset/is_prefill GPU tensors plus layer-invariant cu_starts/cu_ends/visible_end/width_mask/ future_threshold derivations - _build_v4_gather_indices: precomputed batch_ids/block_in_seq/slot_in_block for _v4_gather_compressed_batched - _populate_state_slot_mapping: warmup fallback to slot 0 so dummy_run takes the normal forward path V4Attention.forward / Indexer.forward_batched refactor: - Read all per-fwd state once at top of forward (one get_forward_context call, direct attribute access — no nested getattr chains) - Delete dummy_run special path entirely (synthetic 1-seq batch branch, sparse-attn placeholder branch, _v4_is_dummy_run helper, make_single_seq_plan fallback, indexer skip gate, compressor scatter dummy_run gate) - Delete _v4_get_seq_metadata helper + cpu_meta plumbing (all dead) - Delete slow path of _v4_build_sparse_inputs_batched (~263 LoC) and rename _v4_build_sparse_inputs_from_pack_meta -> _v4_build_sparse_inputs_batched - Delete slow path of _v4_gather_compressed_batched + dead n_committed_per_seq / k_per_block params - Indexer.forward_batched signature: drop cu_seqlens_q_cpu / start_pos_per_seq_cpu / win + dead k_per_seq_cpu return value - Indexer.__init__: cache _fp8_quant_func / _weights_scale (was rebuilt per CSA layer) - Promote V4_FORCE_UE8M0_QUANT / V4_USE_REF_QUANT / V4_AITER_HC_POST env-var reads to module-level constants - Promote `from aiter import QuantType as _AiterQuantType` to module level - Merge indexer.compressor.rotary_emb plumb into outer plumb (one less per-layer if-check) - Rename per-fwd locals for clarity: sp_per_seq_gpu -> start_pos_per_seq_gpu, cu_q_gpu -> cu_seqlens_q_gpu, sp_cpu -> start_pos_per_seq_cpu, etc. Removed APIs (unused after refactor): - make_single_seq_plan (atom/model_ops/v4_kernels/{__init__,compress_plan}.py) Verified: - Smoke `1+2+3=?` returns `**6**` - GSM8K-100 (ATOM_USE_TRITON_MOE=1, conc=16, fewshot=3): 0.96 +/- 0.020 * feat(deepseek_v4): CG-A pre-allocate metadata buffers (CUDAGraph prep) Replace ~25 per-fwd `torch.as_tensor(np_arr)` H2D allocations in V4 metadata builder with pre-allocated CpuGpuBuffer pool. Fixes GPU pointers across forwards — prerequisite for CUDAGraph capture (CG-B). Buffer pool allocated once in __init__ (~80 MB at typical config). All builder helpers now write via `_stage(name, arr)` which does `buf.np[:n] = arr; copy_to_gpu(n)` and asserts capacity. Coverage: - _attach_v4_per_fwd_meta: 4 buffers (start_pos / token_num / write_indices / state_slot) - _populate_state_slot_mapping: 1 buffer (groups) - _build_v4_indexer_meta: 6 buffers (batch_id / cu_committed / n_committed / k / offset / is_prefill) - _build_v4_gather_indices: 3 buffers x 3 callers (indexer / csa_dc / hca_dc) - _build_v4_pack_meta_for_ratio: 11 buffers per kind (csa/hca/dense) Forward path unchanged. Validated GSM8K-100 = 0.95 ± 0.022 (baseline 0.96). * feat(deepseek_v4): CG-B CUDAGraph capture infrastructure Prepares V4 backend for CUDAGraph capture/replay (still gated behind --enforce-eager removal in a follow-up). All capture-required GPU pointer addresses are now stable across forwards. Changes: - Kernels gain fixed-grid + sentinel-mask path: fused_compress_attn, update_compressor_states, swa_write all skip rows whose position == -1, so the wrapper can launch at full plan/buffer capacity (CUDAGraph capturable) regardless of how many tokens this fwd actually writes. - fused_compress_attn / update_compressor_states accept strided kv/score inputs (drop the defensive .contiguous() copies in callers); only inner column stride is required to be 1. - fused_compress_attn gains an out= param for caller-provided pre-allocated output buffer (used in CUDAGraph path to keep output address stable); eager path still allocates per-call. - make_compress_plans accepts plan_buffers dict of pre-allocated CpuGpuBuffer; writes into them and sentinel-fills tail rows. Empty-fwd path also fills buffers so capture-time addresses match replay. - DeepseekV4AttentionMetadataBuilder._alloc_v4_metadata_buffers pre-allocates v4_compress_plan_{ratio} / v4_write_plan_{ratio} CpuGpuBuffers and per-kind v4_{csa_main,csa_idx,hca_main}_compress_out BF16 tensors; build_kv_cache_tensor binds the latter to each Compressor module's compress_out attribute. - build_for_cudagraph_capture replaces the stub: synthesizes a decode batch at start_pos=window_size, runs through prepare_decode helpers (_attach_sparse_layout_metadata + _attach_v4_per_fwd_meta + _build_compress_plans), returns (AttentionMetaData, Context) wired to forward_vars buffers. - DeepseekV4Model.forward returns hidden_states (post hc_head + norm) instead of full vocab logits; DeepseekV4ForCausalLM.compute_logits applies head.get_logits. Required so the CUDAGraph output buffer is sized to hidden_size, not vocab_size (~18x smaller, also matches the ATOM standard contract used by other models). - Compressor gains compress_out attribute (set by builder; threaded through fused_compress_attn as out=). - kv_indptr stub buffer added to forward_vars (touched unconditionally by the global capture loop; V4 doesn't use it for its own kernels). Misc: - Hoist 3 lazy `from atom.model_ops.quant_v4 import act_quant_inplace as _v4_aqi` imports to the top-level import block. - Gate `act_quant_inplace(kv[..., :-rd], 64, scale_fmt)` on _V4_USE_REF_QUANT (default off). Previously unconditional; the env gate already exists for the matching qr/x quant pair, so making this consistent. GSM8K-100 = 0.99 with the gate (no regression vs prior unconditional path which also produced 0.99 in recent runs). Validation: GSM8K-100 = 0.99 ± 0.01 (eager mode). CUDAGraph end-to-end (without --enforce-eager) still pending — needs further capture-loop work. * refactor(deepseek_v4): linear fusions, MoE cleanup, shape contracts, perf nits Linear projection fusions (FP8/BF16, zero-copy split downstream): - attn.wq_a + attn.wkv → attn.wqkv_a (MergedReplicatedLinear, FP8) - compressor.wkv + compressor.wgate → compressor.wkv_gate (BF16, otype=fp32) - shared_experts.w1 + w3 → shared_experts.gate_up_proj (MergedColumnParallelLinear) - packed_modules_mapping routes disk shards via standard ATOM loader - Compressor and update_compressor_states accept strided kv/score inputs MoE refactor: - Drop use_fused/Gate/_torch_moe_forward/toy/dummy paths - Split forward into routed_expert_forward / combine_outputs / single_stream_moe_forward / dual_stream_moe_forward - Dispatch via torch.ops.aiter.maybe_dual_stream_forward (Dynamo barrier) - Extract maybe_dual_stream_forward into atom/model_ops/dual_stream_moe.py (shared with V2; V2 inline implementation removed) - Direct routed/shared dtype check for shared-expert fusion gating (V4 has FP4 routed + FP8 shared; the global-vs-shared helper returns the wrong answer because shared==global but routed!=global) Custom op fix: dual_stream_moe declares mutates_args=() (the V2-original mutates_args=["hidden_states"] is a false-mutation declaration — op returns a fresh tensor, never writes to input — and would mislead AOT/functionalization into inserting defensive clones). Aiter kernel refs hoisted: - _V4_AITER_HC_POST env gate removed; mhc_pre/mhc_post dim+presence check resolved once in Block.__init__ to self._mhc_pre / self._mhc_post - per-fwd path is just `if self._mhc_pre is not None:` attribute lookup Shape contracts (ATOM 2D-flat ragged-batch convention): - All forward signatures get inline shape annotations (e.g. `x: torch.Tensor, # [num_tokens, dim]`) - Drop legacy [B, S, ...] 4D paths in Block.hc_pre/hc_post, ParallelHead.hc_head, MTPBlock.forward, ParallelHead.get_logits - Drop input_ids.dim()==2 normalization in DeepseekV4Model.forward - Compressor.forward asserts 2D, drops defensive 3D-squeeze Code organization: - _segment_indices and _build_window_topk_batched moved from deepseek_v4.py to attentions/deepseek_v4_attn.py (only callers are the metadata builder); removes two cross-file lazy imports - _AiterQuantType alias removed (atom.config.QuantType is the same pybind class) - Stale # noqa: F401 pragmas dropped (sparse_attn_v4, v4_kernels imports are all actively referenced) - ruff full-pass on V4 + V2 + dual_stream_moe + V4 attn Indexer.forward_batched post-topk path: - 10 GPU launches + 1 full_like alloc → 7 launches + 0 allocs - (topk_local < 0) | future_mask is equivalent to width_mask | future_mask (fp8_mqa_logits masks out-of-seq logits to -inf, so topk_local < 0 only fires on width-masked slots) - masked_fill_ in-place over (topk_local + offset) replaces full_like + where Removed redundant ops in hot path: - vestigial unsqueeze(0)→squeeze(0) in Indexer.forward_batched, DeepseekV4Attention.forward, _v4_build_sparse_inputs_batched - .type_as(x) on aiter mhc_post path (out.dtype == residual.dtype == x.dtype) - unused `ratio = self.compress_ratio` local in Indexer.forward_batched Validation: GSM8K-100 num_fewshot=3 = 0.98 ± 0.014 (baseline 0.97 ± 0.017, within stderr). * feat(deepseek_v4): FP8 CSA Indexer cache (-44% pool VRAM) Convert v4_csa_idx_kv from BF16 to FP8+scale layout following V3.2 sparse_attn_indexer pattern. Pool size for the indexer cache drops 44% (BF16 1.07GB -> FP8+scale 0.55GB at NB=4096). Pool layout - shape: [n_csa, NB, k1_csa, aligned_dim=144] dtypes.fp8 (layer-major so pool[pos] is contig per CSA layer) - per row: [head_dim] FP8 + 4-byte fp32 scale, 16B-aligned Write path (Compressor.forward, idx_slot_mapping is not None) - Compressor gains optional idx_slot_mapping (int64). When set, the fused-compress kernel skips its BF16 scatter and we instead call indexer_k_quant_and_cache(out, kv_cache, slot_mapping, head_dim, scale_fmt) to FP8-quantize+write each compress row in one shot. - Slot mapping built host-side in _build_indexer_compress_slot_mapping from csa_compress_plan_cpu + block_tables (no extra GPU->CPU copy thanks to the new compress_plan_cpu field on CompressPlan). Read path (Indexer.forward_batched) - cp_gather_indexer_k_quant_cache(kv_cache, k_fp8, k_scale.view(fp8), block_tables, cu_committed_gpu) does paged-gather + split into separate (FP8, scale) buffers in one launch -- no per-row index list, no online quant. - Then fp8_mqa_logits over [Q_fp8, K_fp8, kv_scales=k_scale, weights] drops the legacy gather_compressed + BF16 einsum path entirely. Builder side - _build_v4_indexer_meta gains csa_compress_plan_cpu param; produces compress_slot_mapping_gpu (int64, kernel sig is int64_t*) and cu_committed_gpu (int32, kernel sig is int32_t*). - "indexer" gather buffer set removed -- cp_gather_indexer_k_quant_cache consumes block_tables + cu_seq_lens directly. - CompressPlan grows compress_plan_cpu: np.ndarray | None for the same reason: builder needs the plan rows host-side to derive slot_mapping without an extra D2H sync. Shape contract gotcha (root cause of an OOM-fault hunt) - Indexer.kv_cache binding MUST keep [NB, k1_csa, aligned_dim] (3D, block_size dim explicit). Flattening to [NB*k1, 1, aligned_dim] makes cp_gather_indexer_k_quant_cache infer block_size=1 from shape[1], which then OOB-indexes block_table. Matches V3.2's [num_blocks, block_size, head_dim] layout (deepseek_v2.py:1049). - Write side (indexer_k_quant_and_cache) is shape-agnostic -- uses slot_mapping flat index -- so the symmetric 3D binding for the inner Compressor is for clarity, not correctness. Validation - simple_inference V4-Pro tp=8 fp8 enforce-eager: all 4 prompts produce coherent output (1+2+3=**6**, prime list, Chinese long-form). - GSM8K-100 num_fewshot=3: flexible-extract / strict-match both 0.96 +/- 0.0197 (baseline 0.97 +/- 0.017, within tolerance). * feat(deepseek_v4): CG-friendly indexer Phase A — preshuffle + decode→deepgemm Three changes folded into one commit (validated together GSM8K-100=0.97 ± 0.0171, baseline 8ab1367b also 0.97): 1. **preshuffle on indexer write+read** (`indexer_k_quant_and_cache` + `cp_gather_indexer_k_quant_cache`): MFMA 16x16 tile-aware FP8 cache layout, matches V3.2/PR #658 convention. Required by `deepgemm_fp8_paged_mqa_logits` for `KVBlockSize > 1`. 2. **split `Indexer.forward_batched` into prefill/decode helpers**: common path (Q proj+RoPE+rotate+FP8 quant, weights computation) stays in `forward_batched`; dispatch via `context.is_prefill` to `_score_topk_prefill` (cp_gather + fp8_mqa_logits, eager-only — variable `total_committed` shape) or `_score_topk_decode` (deepgemm, fixed-shape `[bs*next_n, max_model_len_idx]`). Mixed batches go through prefill path. `_post_process_topk` shared, branches on `is_decode` to skip the seq_base subtraction (decode topk indices are already seq-local; prefill indices are global flat positions across cu_committed). 3. **decode helper uses `deepgemm_fp8_paged_mqa_logits`**: reads paged FP8 cache directly via 4D view `[NB, k1_csa=32, 1, aligned_dim=144]`, writes into pre-`-inf`-filled logits buffer (cols beyond per-seq context_lens stay -inf so PyTorch topk doesn't pick garbage). `width_mask` masked_fill handles per-token k_per_token trimming. CUDAGraph-friendly shapes — for Phase B/C buffer pre-allocation + capture path. Builder: expose `n_committed_per_seq_gpu` (int64, [bs]) in indexer_meta — no new H2D, just lifts the existing staged tensor into the return dict for deepgemm context_lens consumption. Init-time hoist: `Indexer._max_model_len_idx = args.max_seq_len // compress_ratio` — deepgemm output column count, constant per layer. Composition validated standalone (test_decode_deepgemm_vs_fp8_mqa.py: 100% top-K overlap with `cp_gather + fp8_mqa_logits` baseline given `-inf`-init buffer). Numerical round-trip with cache_stride=144 + preshuffle validated (test_indexer_roundtrip_numerical.py: cos≥0.9995 across all num_tokens / dispatch branches). Net: +119 / -20 LoC. Phase B/C (decode logits buffer pre-alloc + build_for_cudagraph_capture) tracked separately. * feat(deepseek_v4): adopt aiter top_k_per_row in indexer prefill+decode Replaces the torch.topk + -inf fill path in `Indexer._score_topk_*` with aiter `top_k_per_row_decode/prefill` (radix kernel, parametric k). Both paths emit a uniform [total_tokens, index_topk] int32 layout. _score_topk_decode (CG-friendly path): - Pre-allocated [max_bs, index_topk] int32 indices buffer in builder. - Pre-allocated [max_bs, max_model_len_idx] fp32 logits buffer. - Drop `fill_(-inf)`: top_k_per_row_decode honors n_committed_per_seq per row, so logits cells past valid range are never read. - Drop torch.topk + .to(int32) cast. _score_topk_prefill (eager-only path): - Drop torch.topk + dynamic-`max_k` shape; emit [total_tokens, index_topk] via top_k_per_row_prefill(k=index_topk), kernel writes -1 sentinels in tail cols. - Per-fwd torch.empty for indices (prefill total_tokens dynamic). Builder _build_v4_indexer_meta: - v4_indexer_n_committed_per_seq buffer i64 -> i32 (kernel arg dtype). - Add v4_indexer_decode_logits and v4_indexer_decode_topk_indices forward_vars buffers. - width_mask collapses to uniform [total_tokens, index_topk] bool. - Drop max_k from returned dict; empty-batch guard now keys on total_committed == 0. Builder _build_v4_pack_meta_for_ratio: - compress_topk_src stride is `index_topk` for both paths (was the dynamic max_k = max(k_per_seq), which assumed prefill's torch.topk(max_k) output shape). _post_process_topk: - Input contract changes to [total_tokens, index_topk] uniform layout. Depends on ROCm/aiter#3012 (exposes `k` kwarg on top_k_per_row_decode / top_k_per_row_prefill); existing aiter without that PR will silently ignore the kwarg and run with k=2048 (still correct, but allocates an oversized output buffer). Validation: - aiter kernel parity at v4 shapes (k=1024, varying bs/ctx) - all OK. - GSM8K-100 num_fewshot=3 eager: 0.97 / 0.97 (stable vs 0.96 baseline). * feat(deepseek_v4): CUDAGraph-friendly sparse decode via unified KV pool Enable CUDAGraph capture for DeepSeek-V4 (Pro / non-Pro) sparse decode. Final config validated: cudagraph-capture-sizes [1,2,4,8,16,32,64] + max-num-seqs 64, GSM8K-50 = 0.98. == Approach == Upstream V4 reference materializes "indexer-selected K's" into a per-fwd dense `kv_flat_sa` tensor whose shape depends on device-side data — this prevents CUDAGraph capture. ATOM replaces it with a paged interface (single base pointer + packed-cumsum kv_indptr + kv_indices) backed by per-layer unified BF16 pool, plus a dedicated triton kernel that handles V4-specific attn_sink + page_size=1. == Components == 1. New triton kernel `sparse_attn_v4_paged_decode` (atom/model_ops/v4_kernels/paged_decode.py): page_size=1 sparse attention with attn_sink, API aligned with V3.2 mla_decode_fwd naming. 3 unit tests bit-exact vs reference. 2. Per-layer `unified_kv` pool (Phase A, atom/model_ops/attentions/deepseek_v4_attn.py allocator): physically merges SWA ring buffer and compressor paged KV into one contiguous BF16 tensor — kernel uses one base pointer, every index (SWA / CSA / HCA) is a row offset. 3. Per-fwd paged-decode index construction (Phase B, `_attach_v4_paged_decode_meta`): builds 3 kv_indptr cumsums (SWA uniform stride, CSA / HCA packed) + scatters SWA window prefix + fully populates HCA compress section. All …
Liang-jianhao97
pushed a commit
that referenced
this pull request
May 7, 2026
* feat(top_k_per_row): expose parametric k via runtime arg
The radix backend (`standalone_stable_radix_11bits`) supports arbitrary k
at runtime, but the public C++/Python wrappers hardcoded `kTopK = 2048`.
Plumb `k` through end-to-end so callers (e.g. DeepSeek-V4 indexer with
`k = index_topk = 1024`) can use the same kernel without forcing 2048.
Changes:
- csrc/include/topk_per_row.h: add `int64_t k = 2048` default arg to
both `top_k_per_row_prefill` and `top_k_per_row_decode` (backward
compatible).
- csrc/include/rocm_ops.hpp: pybind expose `py::arg(\"k\") = 2048`.
- csrc/kernels/topk_per_row_kernels.cu:
* `invokeComputeTopkLastDimWorkspaceSize` accepts `int k_param =
2048` and threads it into the workspace-size calc + radix kernel.
* `top_k_per_row_prefill/decode` bodies replace
`static constexpr int kTopK = 2048` with `static_cast<int>(k)`.
* Explicit `<float>` instantiation updated to match new sig.
- aiter/ops/topk.py: add `k: int = 2048` to the regular Python sigs
(`top_k_per_row_prefill`, `top_k_per_row_decode`) so the
`@compile_ops`-generated torch op schema accepts the new kwarg. The
`_fast` ASM-kernel variants are left untouched — their precompiled
`.co` blobs hardcode k=2048 and cannot honor a runtime k.
- op_tests/test_topk_per_row.py:
* Default `--top_k` widened from `[2048]` to `[512, 1024, 2048]`
so CI exercises the parametric path.
* `run_top_k_per_row_{prefill,decode}` thread `k` through.
* `_fast` decode is skipped when k != 2048 (kernel is hardcoded
and would write 2048 ints into a smaller buffer).
* `run_top_k_per_row_decode` asserts `k == 2048` for `_fast`.
Backward compatibility: every caller without a `k` arg continues to get
k=2048, matching previous behavior. No existing test changes signature.
* fix(top_k_per_row): sync k_param across topk_plain TU
Commit 4027229 added `int k_param = 2048` to
`invokeComputeTopkLastDimWorkspaceSize` in
`csrc/kernels/topk_per_row_kernels.cu` (definition + explicit
instantiation), but missed the forward declaration / extern template in
`csrc/kernels/topk_plain_kernels.cu`. Both TUs link into
`module_topk_plain.so`, so topk_plain emitted an undefined reference to
the old 2-arg symbol while topk_per_row only exported the 3-arg one:
ImportError: module_topk_plain.so: undefined symbol:
_Z37invokeComputeTopkLastDimWorkspaceSizeIfLN5aiter5PhaseE0EElii
-> long invokeComputeTopkLastDimWorkspaceSize<float, (aiter::Phase)0>(int, int)
Add the matching `int k_param = 2048` to the forward decl and extern
template so the call site at line 2499 mangles to the 3-arg symbol that
the explicit instantiation provides.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
The radix backend (
standalone_stable_radix_11bits) supports arbitrarykat runtime, but the public C++/Python wrapperstop_k_per_row_prefillandtop_k_per_row_decodehardcodedkTopK = 2048. This PR plumbskthrough end-to-end so callers (e.g. DeepSeek-V4 indexer withk = index_topk = 1024) can use the same kernel without forcing 2048.Backward compatible: the default
k = 2048matches previous behavior; existing callers do not need changes.Changes
C++ surface (header / pybind / kernel)
csrc/include/topk_per_row.h— both wrapper signatures getint64_t k = 2048default.csrc/include/rocm_ops.hpp— pybind exposespy::arg(""k"") = 2048for both ops.csrc/kernels/topk_per_row_kernels.cuinvokeComputeTopkLastDimWorkspaceSize<T, Phase>(numRows, stride0)gainsint k_param = 2048; the internalconstexpr int k = 2048becomes runtimeint k = k_paramand is threaded into the radix call.<float>instantiation updated.top_k_per_row_prefill/top_k_per_row_decodebodies replacestatic constexpr int kTopK = 2048withstatic_cast<int>(k)and pass it into both the workspace-size query andstandalone_stable_radix_11bits.Python surface
aiter/ops/topk.py—top_k_per_row_prefillandtop_k_per_row_decode(the radix-backed variants) gaink: int = 2048so the@compile_ops-generated torch op schema accepts the new kwarg._fastASM-kernel variants (top_k_per_row_prefill_fast,top_k_per_row_decode_fast) are intentionally not modified — their precompiled.coblobs hardcode k=2048 and cannot honor a runtime k.Tests
op_tests/test_topk_per_row.py--top_kwidened from[2048]to[512, 1024, 2048].run_top_k_per_row_{prefill,decode}threadkthrough to the kernel call._fastdecode is skipped whenk != 2048; the dispatcher also assertsk == 2048for_fastto surface misuse loudly.Local validation
Ran
python op_tests/test_topk_per_row.pyon MI355X (gfx950) with the new defaults:Prefill (9 cases, k ∈ {512, 1024, 2048} × ctx ∈ {1024, 4096, 16384}): all
all_close = True.Decode (18 cases, bs ∈ {4, 16} × ctx ∈ {1024, 4096, 16384} × k ∈ {512, 1024, 2048}, next_n=1): all
all_close = True.Sample decode results:
Test plan
python op_tests/test_topk_per_row.pyparity vstorch.topkfor k ∈ {512, 1024, 2048}.op_tests/test_topk_per_row.pyruns on the target gfx variants.