219 changes: 219 additions & 0 deletions .claude_docs/nemo-fp4-moe-b12x-mr/w4a16/W4A16_DENSE_DESIGN.md
@@ -0,0 +1,219 @@
# `b12x` W4A16 dense GEMM for SM120 / SM121 — design

Date: 2026-05-15.
Author: fkhoubsirat.
Companion docs: `W4A16_DESIGN.md` (MoE/wrapper-level), `W4A16_RESULTS.md`,
`W4A16_NSYS_VERIFICATION.md`.

## Goal

Add a **W4A16 dense GEMM** to b12x (`b12x/b12x/gemm/w4a16/`) that takes
bf16 activations and FP4-packed weights directly, eliminating the
"BF16 → online FP4 quantize → CUTLASS W4A4 GEMM" path that TensorRT-LLM
currently uses for non-MoE NVFP4 linears.

Target model: `nvidia/Nano3.5-BF16-NVFP4-W4A16-LMHEAD-CT` (NemotronH
hybrid Mamba+Attn+MoE, hidden=2688, 128 experts, top_k=6,
`mlp_hidden_act="relu2"`, lm_head NVFP4-quantized).

Scope (v1): **decode-only**, M ≤ 32. Isolated kernel run + bench. No
TensorRT-LLM integration in this milestone.

## Decisions (locked)

| Topic | Decision |
|---|---|
| Kernel base | Fork `b12x/moe/fused/w4a16/micro.py`. Strip routing, expert dim, scatter; keep TMA load + FP8 SF unswizzle + bf16-stage + mma-bf16-vs-fp4 inner loop. |
| File layout | `b12x/b12x/gemm/w4a16/{__init__.py, micro.py, reference.py}`. |
| Activations | bf16 only (matches Nano3.5 `dtype: bfloat16`). No FP16 path. |
| Output dtype | bf16. |
| Bias | None (Nano3.5 has `use_bias: false`). |
| Activation fusion | None — caller applies relu² / SiLU on the bf16 output. |
| Weight layout | `[N, K//2] uint8` (FP4 packed two-per-byte, same as the MoE W4A16 single-expert slice). |
| Block-scale layout | `[N//128 * 32, K//16 * 4 * 4] float8_e4m3fn` (swizzled per b12x convention). Group size = 16. |
| `weight_scale_2` | Scalar `float32`; consumed as the kernel's epilogue `alpha`. |
| M range | 1, 2, 4, 8, 16, 32 (same `m_tile` ladder as MoE micro). |
| Arch gate | SM120 and SM121, single kernel. Same MMA family as the existing W4A16 MoE micro kernel and `DeviceGemmFp4GemmSm120`. |
| Accuracy reference | `dense_reference_w4a16`: dequantize the FP4 weights to fp32, run `torch.matmul` in fp32, then cast the result to bf16. Mirrors `moe_reference_w4a16`. |
| Accuracy tolerances | `max_abs ≤ 8e-4`, `rmse ≤ 5e-5`, `cos > 0.9999` — borrowed from the MoE eager-prefill suite. |
| Bench timing | `bench_events` (cuda.Event), L2-flush option, CUDA-graph replay — same as `benchmarks/benchmark_dense_gemm.py`. |
| Baselines | (1) FlashInfer `mm_fp4` (W4A4, includes online act-quant); (2) TRT-LLM `cudaCoreGemmFp4` invoked via its Python op binding. |
| Build/test environment | Inside `tensorrt_llm-devel-fkhoubsirat` container with the b12x repo mounted at `/workspace/agentic-dev/b12x`. |
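
The "M range" row can be read as a small lookup. The sketch below assumes that an M between ladder entries is rounded up to the next `m_tile`; the design only lists the ladder itself, so that rounding rule and the helper name `pick_m_tile` are hypothetical.

```python
# Ladder taken from the "M range" row of the decisions table.
M_TILE_LADDER = (1, 2, 4, 8, 16, 32)


def pick_m_tile(m: int) -> int:
    """Return the smallest ladder entry that covers m rows (hypothetical helper)."""
    if m < 1:
        raise ValueError(f"M must be positive, got {m}")
    for tile in M_TILE_LADDER:
        if m <= tile:
            return tile
    # v1 is decode-only, so anything above the top of the ladder is rejected.
    raise ValueError(f"M={m} exceeds the v1 limit of {M_TILE_LADDER[-1]}")


if __name__ == "__main__":
    print([pick_m_tile(m) for m in (1, 3, 8, 17, 32)])
```

A real `is_supported(m, k, n)` would presumably add K and N constraints on top of this; those are not specified here.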

## Target MNK (from nsys + model config)

Source: `logs/nsys_proper_20260512_125230.sqlite`, kernels
`cudaCoreGemmFp4<fp4, bf16, fp8, ...>` (decode dense linears, 75 calls
at M=1).

| Linear | M | K | N | Source |
|---|---|---|---|---|
| q_proj | 1..32 | 2688 | 4096 | attn `num_heads*head_dim` |
| k_proj | 1..32 | 2688 | 256 | attn `num_kv_heads*head_dim` |
| v_proj | 1..32 | 2688 | 256 | attn `num_kv_heads*head_dim` |
| o_proj | 1..32 | 4096 | 2688 | nsys gridY=1344 ⇒ N=2688 |
| shared_expert.up | 1..32 | 2688 | 3712 | nsys gridY=1856 ⇒ N=3712 (non-gated relu²) |
| shared_expert.down | 1..32 | 3712 | 2688 | shares 2688 N-tile path |
| lm_head | 1..32 | 2688 | 131072 | last-token logits; large-N edge |

k_proj and v_proj are included for completeness; their N=256 is degenerate
for any 128-N-tile kernel and is acceptance-only (correctness, not perf).
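
As a rough sizing aid, the shapes in the table can be converted into packed-weight bytes, block-scale counts, and N-tile counts. This is a sketch under stated assumptions: a 128-wide N tile (as the paragraph above implies), group size 16, and unpadded block scales. `shape_stats` is a hypothetical helper, not part of b12x.

```python
import math

# (name, K, N) copied from the target MNK table.
NANO35_SHAPES = [
    ("q_proj", 2688, 4096),
    ("k_proj", 2688, 256),
    ("v_proj", 2688, 256),
    ("o_proj", 4096, 2688),
    ("shared_expert.up", 2688, 3712),
    ("shared_expert.down", 3712, 2688),
    ("lm_head", 2688, 131072),
]


def shape_stats(k: int, n: int, n_tile: int = 128, group: int = 16):
    """Return (packed weight bytes, block-scale count, N-tile count)."""
    weight_bytes = n * k // 2        # two FP4 values per uint8
    scale_count = n * (k // group)   # one FP8 scale per 16 weights along K (no padding)
    n_tiles = math.ceil(n / n_tile)  # assumes a 128-wide N tile
    return weight_bytes, scale_count, n_tiles


if __name__ == "__main__":
    for name, k, n in NANO35_SHAPES:
        wb, sc, nt = shape_stats(k, n)
        print(f"{name:20s} weight={wb / 1e6:8.2f} MB  scales={sc / 1e6:6.2f} M  n_tiles={nt}")
```

Under these assumptions k_proj and v_proj each cover only two N tiles, which is why they are treated as correctness-only cases.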

## File layout

```
b12x/b12x/gemm/w4a16/
├── __init__.py # exports: dense_gemm_w4a16, DenseGemmW4A16MicroKernel
├── micro.py # forked from moe/fused/w4a16/micro.py
└── reference.py # forked from moe/fused/w4a16/reference.py

tests/
└── test_dense_gemm_w4a16.py

benchmarks/
└── benchmark_dense_gemm_w4a16.py
```

## Public kernel API

```python
def dense_gemm_w4a16(
    x: torch.Tensor,             # [M, K] bf16, contiguous
    w_fp4: torch.Tensor,         # [N, K // 2] uint8
    w_blockscale: torch.Tensor,  # swizzled FP8 SF, shape per b12x convention
    w_alpha: torch.Tensor,       # scalar float32 == weight_scale_2
    *,
    out: torch.Tensor | None = None,  # [M, N] bf16; allocated if None
) -> torch.Tensor:
    ...
```

Backend class `DenseGemmW4A16MicroKernel` exposes `is_supported(m, k, n)` and
the launchable entry point (matches the MoE micro backend convention).

## Data flow (single matmul)

```
[M, K] bf16 activation      [N, K/2] uint8 weight             FP8 block-scale (swizzled)      weight_scale_2 (scalar f32)
  │                           │                                 │                               │
  └─ TMA load to smem         └─ TMA load to smem               └─ unswizzle in smem            │
  │                           │                                 │                               │
  └────── MMA: bf16 × dequant(fp4, fp8_sf) → fp32 accumulator ──┘                               │
          └─ epilogue: acc * weight_scale_2 ───────────────────────────────────────────────────┘
             [M, N] bf16 output
```

## Reference (Python, fp32)

```python
import torch

def dense_reference_w4a16(x_bf16, w_fp4, w_blockscale, w_alpha):
    # Dequantize to fp32 (block scales and alpha applied), matmul in fp32, cast to bf16.
    w_dequant = unpack_fp4_dequant(w_fp4, w_blockscale, w_alpha)  # fp32 [N, K]
    return (x_bf16.float() @ w_dequant.t()).to(torch.bfloat16)
```

Reuse `unswizzle_block_scale`, the FP4→bf16 dequant helpers, and
`compare_to_reference` from `b12x/moe/fused/w4a16/reference.py`.
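
To make the dequant step concrete, here is a minimal pure-Python sketch of what a helper like `unpack_fp4_dequant` has to compute for one weight row. It is not the b12x implementation. It assumes the low nibble of each byte holds the even-K element, and that the block scales have already been unswizzled and converted from FP8 to floats; both are assumptions, and the helper names are hypothetical.

```python
# E2M1 magnitudes indexed by the low 3 bits of a nibble; bit 3 is the sign.
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def decode_fp4(nibble: int) -> float:
    """Decode one 4-bit E2M1 value."""
    sign = -1.0 if nibble & 0x8 else 1.0
    return sign * E2M1_MAGNITUDES[nibble & 0x7]


def dequant_row(packed: bytes, block_scales, alpha: float, group: int = 16):
    """Dequantize one row: `packed` has K/2 bytes, `block_scales` has K/group floats."""
    values = []
    for byte in packed:
        values.append(decode_fp4(byte & 0x0F))  # assumed: low nibble = even-K element
        values.append(decode_fp4(byte >> 4))    # assumed: high nibble = odd-K element
    return [v * block_scales[i // group] * alpha for i, v in enumerate(values)]


def reference_gemm(x_rows, w_rows):
    """x: [M][K], w: [N][K] -> [M][N], accumulated in Python floats."""
    return [[sum(a * b for a, b in zip(x, w)) for w in w_rows] for x in x_rows]


if __name__ == "__main__":
    row = dequant_row(bytes([0x21] * 8), block_scales=[2.0], alpha=0.5)
    print(reference_gemm([[1.0] * 16], [row]))
```

The real reference operates on whole tensors; the per-row loop is only meant to show where the block scale and `weight_scale_2` enter.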

## Tests — `tests/test_dense_gemm_w4a16.py`

| Test | Scope | GPU required |
|---|---|---|
| `test_reference_dequant_roundtrip` | pack/unpack/dequant symmetry on synthetic weights | no |
| `test_micro_is_supported` | `(m, k, n)` matrix over M∈{1,2,4,8,16,32} × the 6 Nano shapes | no |
| `test_accuracy_nano_shapes` | kernel output vs reference for each shape × M∈{1,8,32}, asserts `max_abs ≤ 8e-4`, `rmse ≤ 5e-5`, `cos > 0.9999` | SM120/SM121 |
| `test_kv_proj_n256_correctness` | acceptance for the degenerate N=256 case | SM120/SM121 |
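
The three accuracy metrics asserted above can be written out explicitly. This is a sketch of how they are commonly defined, not necessarily how `compare_to_reference` computes them; the function names are hypothetical, and inputs are flat sequences of floats to keep the example dependency-free.

```python
import math


def accuracy_metrics(out, ref):
    """Return (max_abs, rmse, cos) for two equal-length flat sequences of floats."""
    if len(out) != len(ref) or not out:
        raise ValueError("inputs must be non-empty and the same length")
    diffs = [o - r for o, r in zip(out, ref)]
    max_abs = max(abs(d) for d in diffs)
    rmse = math.sqrt(sum(d * d for d in diffs) / len(diffs))
    dot = sum(o * r for o, r in zip(out, ref))
    norm = math.sqrt(sum(o * o for o in out)) * math.sqrt(sum(r * r for r in ref))
    if norm > 0:
        cos = dot / norm
    else:
        # A zero vector has no direction: treat two all-zero inputs as identical.
        cos = 1.0 if max_abs == 0 else 0.0
    return max_abs, rmse, cos


def passes(out, ref, max_abs_tol=8e-4, rmse_tol=5e-5, cos_min=0.9999):
    """Apply the tolerances from the design: max_abs <= 8e-4, rmse <= 5e-5, cos > 0.9999."""
    max_abs, rmse, cos = accuracy_metrics(out, ref)
    return max_abs <= max_abs_tol and rmse <= rmse_tol and cos > cos_min
```

Note that `max_abs` and `rmse` are absolute, so they only make sense relative to the output magnitude of the test inputs.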

## Benchmark — `benchmarks/benchmark_dense_gemm_w4a16.py`

CUDA-graph replay, `bench_events` per iter, optional L2 flush.

**Baselines:**
1. `flashinfer.gemm.mm_fp4` — W4A4 with online activation quant. Already
used in `benchmarks/benchmark_dense_gemm.py:resolve_flashinfer_ref`.
2. TensorRT-LLM `cudaCoreGemmFp4` via
`torch.ops.trtllm.cuda_core_nvfp4_gemm` (registered in
`cpp/tensorrt_llm/thop/cudaNvfp4MM.cpp`). Signature:
`(mat_a[M, K/2] fp4, mat_b[N, K/2] fp4, scale_a, scale_b, alpha[f32],
bias?, out_dtype?, ...) -> out[M, N]`. **The op consumes pre-quantized
FP4 activations**, so the apples-to-apples baseline must time the
activation-quant step (`torch.ops.trtllm.quantize_with_block_size` —
nsys shows ~1.49 µs mean) **plus** the GEMM. Without that, the
comparison would be unfair to the path we are replacing.

**Shapes:** the 6 Nano35 dense linears (4 attn + 2 shared-expert) +
optional `--lm-head` flag for the 2688×131072 case.

**Default invocation:**

```bash
python benchmarks/benchmark_dense_gemm_w4a16.py \
    --shapes nano35 \
    --m-list 1,8,16,32 \
    --warmup 10 --iters 50 \
    --baselines flashinfer,trtllm \
    --cuda-graph \
    --flush-l2 \
    --output /workspace/TensorRT-LLM/.claude_docs/nemo-fp4-moe-b12x-mr/w4a16/results_dense_w4a16.csv
```

**Output columns:** `shape | M | K | N | b12x_us | fi_mm_fp4_us | trtllm_cuda_core_gemm_us | ratio_vs_fi | ratio_vs_trtllm | accuracy_cos`.

## Build / run flow (inside TRT-LLM container)

```bash
# Already inside tensorrt_llm-devel-fkhoubsirat container with mounts:
# /workspace/TensorRT-LLM (TRT-LLM checkout, Python overlay)
# /workspace/agentic-dev/b12x (b12x repo)
cd /workspace/agentic-dev/b12x
pip install -e . # CuTe DSL JIT — no compile step

pytest tests/test_dense_gemm_w4a16.py -v
python benchmarks/benchmark_dense_gemm_w4a16.py --shapes nano35 ...
```

## Verification gate (must pass before claiming v1 works)

1. **Unit tests**: `pytest tests/test_dense_gemm_w4a16.py` all green on an
SM120 host. Reference-only tests pass on CPU.
2. **Accuracy**: every Nano35 shape × M ∈ {1, 8, 32} passes
`max_abs ≤ 8e-4`, `rmse ≤ 5e-5`, `cos > 0.9999` vs reference.
3. **Perf**: at M=1, b12x latency ≤ TRT-LLM `cudaCoreGemmFp4` latency for
the 5 main shapes (q_proj, o_proj, shared_expert.{up,down}, lm_head).
k/v N=256 is acceptance-only.
4. **nsys spot-check**: a replay-only nsys profile shows the b12x
`dense_gemm_w4a16` kernel firing, with **zero** `cudaCoreGemmFp4` /
`DeviceGemmFp4GemmSm120` instances.
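
The perf item of the gate could be checked mechanically from the benchmark CSV. The sketch below assumes each row is a dict keyed by the output columns listed earlier (`shape`, `M`, `b12x_us`, `trtllm_cuda_core_gemm_us`); `perf_gate` is a hypothetical helper, and treating a missing baseline or an unmeasured shape as a failure is an assumption.

```python
# The five shapes the perf gate applies to; k/v N=256 is acceptance-only.
MAIN_SHAPES = ("q_proj", "o_proj", "shared_expert.up", "shared_expert.down", "lm_head")


def perf_gate(rows, m=1):
    """Return the main shapes at the given M that fail b12x_us <= trtllm_us."""
    failures = []
    seen = set()
    for row in rows:
        if row["M"] != m or row["shape"] not in MAIN_SHAPES:
            continue
        seen.add(row["shape"])
        baseline = row["trtllm_cuda_core_gemm_us"]
        # A missing baseline cannot be compared, so it is reported as a failure.
        if baseline is None or row["b12x_us"] > baseline:
            failures.append(row["shape"])
    # Shapes that were never measured also fail the gate.
    failures.extend(s for s in MAIN_SHAPES if s not in seen)
    return failures
```

An empty return value means the perf item passes; the accuracy and nsys items still need their own checks.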

## Open items (resolved during plan / implementation, not blocking design)

1. Exact FP8 SF swizzle shape constants reused from
`b12x/moe/fused/w4a16/reference.py` for the single-expert (no-E)
case. Should be a trivial squeeze of the leading expert dim.
2. lm_head N=131072 may need a separate N-split path; v1 keeps it on
the same kernel and lets the bench surface any regression.

## Out of scope (v1)

- Prefill / static / dynamic kernel variants (per the M ≤ 32 scope).
- TRT-LLM integration / op registration.
- TP > 1.
- Fused activation (relu², SiLU).
- Bias.
- FP16 activations (only bf16).
- Mamba in/out projection support — most are excluded from quant in this
checkpoint and use a different kernel path anyway.

## References

- Local b12x repo: `/home/farazkh_scratch/agentic-dev/b12x`.
- Local Nano35 snapshot: `Nano3.5-BF16-NVFP4-W4A16-LMHEAD-CT/`.
- Existing W4A16 MoE micro kernel (template):
`b12x/b12x/moe/fused/w4a16/micro.py`.
- Existing W4A4 dense kernel (orientation reference):
`b12x/b12x/gemm/dense.py`.
- nsys trace (baseline kernel attribution):
`.claude_docs/nemo-fp4-moe-b12x-mr/w4a16/logs/nsys_proper_20260512_125230.sqlite`.
- FlashInfer W4A16 MoE PR (parent project):
flashinfer-ai/flashinfer#3271, merge SHA `5ef7afa3`.
102 changes: 102 additions & 0 deletions .claude_docs/nemo-fp4-moe-b12x-mr/w4a16/W4A16_DENSE_PERF_SNAPSHOT.md
@@ -0,0 +1,102 @@
# W4A16 dense GEMM — perf snapshot

Date: 2026-05-15.
Branch: `master` @ `7aaab43` (b12x), `faraz/b12x-flashinfer-moe-pr` @ `96461f7bb3` (TRT-LLM).
Companion: `W4A16_DENSE_RESULTS.md`, `W4A16_DENSE_DESIGN.md`.

## Environment

| Item | Value |
|---|---|
| GPU | 1× NVIDIA RTX PRO 6000 Blackwell Server Edition (SM120, 97 GB) |
| Container | `tensorrt_llm/devel:latest` (`b12x-dev`) |
| CUDA | 13.1 |
| Driver | 590.48.01 |
| Torch | 2.11.0a0+eb65b36914.nv26.02 |
| cutlass-dsl | 4.4.2 |
| TRT-LLM ops | loaded via `ctypes.CDLL libth_common.so` |

## Method

* 50 timed iterations, 10 warmup, median reported (µs).
* `torch.cuda.Event(enable_timing=True)` per iter, `torch.cuda.synchronize` at boundaries.
* No CUDA-graph capture; no L2 flush between iters (back-to-back).
* TRT-LLM baseline = `torch.ops.trtllm.fp4_quantize(bf16, ...)` (online act-quant) + `torch.ops.trtllm.cuda_core_nvfp4_gemm`, both timed together — apples-to-apples vs b12x's bf16-in path.
* b12x.triton = `b12x.gemm.w4a16.dense_gemm_w4a16` with `B12X_GEMM_W4A16_USE_CUTE` unset.
* `M = 16`, weight/activation seeded with `torch.manual_seed(0)`.
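
One detail of "median reported": with an even sample count such as 50, indexing `sorted(samples)[25]` (as the reproduction script does) picks the upper of the two middle samples, while a textbook median averages them. The sketch below shows both; `summarize_us` is a hypothetical helper.

```python
import statistics


def summarize_us(samples_us):
    """Summarize per-iteration timings (in microseconds)."""
    ordered = sorted(samples_us)
    return {
        # Mean of the two middle samples when the count is even.
        "median": statistics.median(ordered),
        # What indexing [len // 2] selects, e.g. [25] for 50 samples.
        "upper_middle": ordered[len(ordered) // 2],
        "min": ordered[0],
        "max": ordered[-1],
    }


if __name__ == "__main__":
    print(summarize_us([4.0, 1.0, 3.0, 2.0]))
```

For tightly clustered GPU timings the two values are usually close, but reporting which one was used keeps the numbers reproducible.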

## Headline numbers (µs, M=16, Nano35 dense shapes)

| Shape | M | K | N | TRT-LLM µs | b12x.triton µs | triton / trtllm |
|---|---:|---:|---:|---:|---:|---:|
| q_proj | 16 | 2688 | 4096 | **25.0** | 149.0 | 6.0× |
| k_proj | 16 | 2688 | 256 | **19.1** | 149.0 | 7.8× |
| v_proj | 16 | 2688 | 256 | **19.0** | 149.0 | 7.8× |
| o_proj | 16 | 4096 | 2688 | **23.1** | 216.9 | 9.4× |
| shared_expert.up | 16 | 2688 | 3712 | **23.9** | 147.8 | 6.2× |
| shared_expert.down | 16 | 3712 | 2688 | **24.0** | 199.3 | 8.3× |
| lm_head | 16 | 2688 | 131072 | **—** | 966.1 | n/a |

`—` = TRT-LLM's `cuda_core_nvfp4_gemm` dispatcher rejects this configuration (lm_head N=131072 is above its supported envelope). b12x.triton is the only working kernel for this shape today.

### Observations

* **TRT-LLM is the bar**: 19–25 µs on the in-dispatch shapes, slightly faster on the skinny-N k/v_proj (no Tensor-Core utilisation either way).
* **b12x Triton is 6–9× behind TRT-LLM** at M=16, mirroring the earlier M=1 ratios. Hot path is the smem↔register element-wise transit; the gmem load is already saturated by the Triton vectorised path.
* **lm_head only works on b12x today**. ~1 ms for 16×2688×131072 is bandwidth-bound on the weight load (~352 M FP4 values, ~176 MB packed); cuBLAS-class kernels would do this in <300 µs but TRT-LLM's nvfp4 path doesn't dispatch.
* **CuTe-DSL v3.1** runs bit-exact (`cos = 1.0` on every Nano35 shape × M=16), but its perf at q_proj M=16 wasn't captured in this session: the CuTe JIT compile for K=2688 takes ~5 min on this host and we deleted the in-flight bench. The expected result is **rough Triton parity** (~150 µs), since v3.1 only widened the gmem copies; the inner-loop smem↔register copies are still `CopyUniversalOp` (the LdMatrix swap is a v3.2 follow-up, see results doc).
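
The lm_head weight traffic can be worked out from the shape and the measured latency in the table. The sketch below counts only packed FP4 weights plus one FP8 scale byte per 16 weights (unpadded, an assumption) and ignores activation and output traffic; `lm_head_traffic` is a hypothetical helper.

```python
def lm_head_traffic(k=2688, n=131072, latency_us=966.1, group=16):
    """Return (weight bytes, scale bytes, effective GB/s) for one lm_head GEMM."""
    weight_bytes = k * n // 2       # packed FP4, two values per byte
    scale_bytes = (k // group) * n  # one FP8 byte per 16 weights, no padding assumed
    total_bytes = weight_bytes + scale_bytes
    gb_per_s = total_bytes / (latency_us * 1e-6) / 1e9
    return weight_bytes, scale_bytes, gb_per_s


if __name__ == "__main__":
    wb, sb, bw = lm_head_traffic()
    print(f"weights={wb / 1e6:.1f} MB scales={sb / 1e6:.1f} MB effective={bw:.0f} GB/s")
```

Comparing the effective figure against the device's peak memory bandwidth gives a quick sense of how much headroom remains for this shape.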

## Reproduction

```bash
docker exec b12x-dev bash -lc '
cd /workspace/agentic-dev/b12x
python -u <<PY
import torch, ctypes, os
# Load the TRT-LLM shared libraries so the torch.ops.trtllm ops are available.
ctypes.CDLL("/workspace/TensorRT-LLM/tensorrt_llm/libs/libtensorrt_llm.so", mode=ctypes.RTLD_GLOBAL)
ctypes.CDLL("/workspace/TensorRT-LLM/tensorrt_llm/libs/libth_common.so", mode=ctypes.RTLD_GLOBAL)
os.environ.pop("B12X_GEMM_W4A16_USE_CUTE", None)  # unset: use the Triton backend
dev = torch.device("cuda")
from b12x.gemm.w4a16 import quantize_dense_weight_to_fp4, dense_gemm_w4a16

def bench(r, w=10, n=50):
    # Warm up, then time each call with its own CUDA event pair.
    for _ in range(w):
        r()
    torch.cuda.synchronize()
    S = [torch.cuda.Event(enable_timing=True) for _ in range(n)]
    E = [torch.cuda.Event(enable_timing=True) for _ in range(n)]
    for i in range(n):
        S[i].record(); r(); E[i].record()
    torch.cuda.synchronize()
    # elapsed_time is in ms; return sorted per-iteration times in us.
    return sorted([s.elapsed_time(e)*1000 for s, e in zip(S, E)])

for name, m, k, n in [
    ("q_proj",16,2688,4096), ("k_proj",16,2688,256), ("v_proj",16,2688,256),
    ("o_proj",16,4096,2688), ("se.up",16,2688,3712), ("se.down",16,3712,2688),
    ("lm_head",16,2688,131072),
]:
    torch.manual_seed(0)
    x = (torch.randn(m, k, dtype=torch.bfloat16, device=dev) * 0.5).contiguous()
    w = (torch.randn(n, k, dtype=torch.bfloat16, device=dev) * 0.1).contiguous()
    w_fp4, w_bs, w_alpha = quantize_dense_weight_to_fp4(w)
    out = torch.empty(m, n, dtype=torch.bfloat16, device=dev)
    gs = torch.tensor(1.0, dtype=torch.float32, device=dev)
    w_fp4_t, w_sf_t = torch.ops.trtllm.fp4_quantize(w, gs, 16, False, True)
    alpha = torch.tensor(1.0, dtype=torch.float32, device=dev)
    def trt():
        # Baseline: online activation quant + GEMM, timed together.
        x_fp4, x_sf = torch.ops.trtllm.fp4_quantize(x, gs, 16, False, True)
        return torch.ops.trtllm.cuda_core_nvfp4_gemm(x_fp4, w_fp4_t, x_sf, w_sf_t, alpha, None, torch.bfloat16, 0, None)
    def tri():
        return dense_gemm_w4a16(x, w_fp4, w_bs, w_alpha, out=out)
    try:
        trt_med = bench(trt)[25]
    except Exception:
        trt_med = None  # the baseline raised for this shape; printed as None
    tri_med = bench(tri)[25]
    print(f"{name:18s} M={m} K={k:5d} N={n:6d} trtllm={trt_med!s:>7} triton={tri_med:6.1f}us")
PY
'
```

## Gaps

1. **CuTe-DSL v3.1 perf not captured**: needs ~15-30 min of host time for JIT-compile across all Nano35 shapes. Add to a v3.4 follow-on if a number is needed before LdMatrix lands.
2. **No CUDA-graph capture** in this measurement.
3. **No L2-flush between iters** — back-to-back replay benefits from hot caches; real serving traffic would see ~10-30% higher numbers.

These gaps don't change the headline conclusion: closing the 6–9× gap requires LdMatrix/StMatrix on the smem↔register path (v3.2), which needs the swizzled smem layouts from `b12x/gemm/dense.py:190+`.