[TRITON] Conv Kernels First Commit to AITER #2886
Conversation
    padding=(0, 0),
    dilation=(1, 1),
    activation="none",
    out_dtype=torch.float16,
Maybe we can have a `None` default here?
If `None`, the out dtype can be the same as the input dtype.
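A minimal sketch of that suggestion, assuming the Python wrapper resolves `out_dtype` before launching the kernel; the helper name is illustrative, not part of this PR:

```python
# Illustrative helper: fall back to the input's dtype when out_dtype is None.
import torch

def _resolve_out_dtype(x: torch.Tensor, out_dtype=None) -> torch.dtype:
    return x.dtype if out_dtype is None else out_dtype
```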
Others can chime in here, but typically we don't do autotuning at runtime. It adds a certain randomness and can behave unexpectedly as part of a CUDA graph: the same inputs can lead to a different kernel config being picked if the time delta between them is relatively small. However, assuming it gets handled correctly by the kernel consumer, I don't see much of a problem given the config size is small (<8 configs per kernel). I assume the config list is selected for RDNA4 based on what the README contains? Maybe we can note that in the helper. Also, we have certain artifacts like README.md and DESIGN.md that come with PNGs embedded in them; I don't have strong opinions about where to store that info.
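For context, the deterministic alternative being described here amounts to a fixed per-architecture config table consulted at launch time; the arch string, buckets, and values below are purely illustrative:

```python
# Purely illustrative: a fixed config table keyed by architecture and a coarse
# shape bucket, so identical inputs always get the identical kernel config
# (which also keeps CUDA graph capture/replay deterministic).
_STATIC_CONFIGS = {
    ("gfx1201", "large_channels"): {"BLOCK_M": 128, "BLOCK_N": 128, "num_warps": 8},
    ("gfx1201", "small_channels"): {"BLOCK_M": 64,  "BLOCK_N": 64,  "num_warps": 4},
}

def pick_config(arch: str, channels: int) -> dict:
    bucket = "large_channels" if channels >= 512 else "small_channels"
    return _STATIC_CONFIGS[(arch, bucket)]
```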
Hi @cagrikymk. To be honest with you, I didn't review this PR in depth, but I see some red flags at first glance:
Added some labels to this PR to make sure CI runs on all supported CDNA architectures.
Thanks for taking the time to review this PR. Both points are addressed in the latest revision. On runtime autotune, I'd appreciate a bit more guidance: looking at the existing Triton kernels in AITER, several already use `@triton.autotune` for runtime config selection, so the conv kernels followed that pattern.
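For reference, the `@triton.autotune` pattern mentioned here looks roughly like the sketch below; the configs, key fields, and kernel signature are placeholders rather than the PR's actual values:

```python
# Placeholder sketch of Triton's runtime autotuner: the first launch for a new
# (C, K) key times the listed configs and caches the winner for reuse.
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 64,  "BLOCK_N": 64},  num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64},  num_warps=8),
    ],
    key=["C", "K"],  # re-tune only when these problem sizes change
)
@triton.jit
def _conv_kernel(x_ptr, w_ptr, y_ptr, C, K,
                 BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pass  # kernel body omitted in this sketch
```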
Thanks for helping to review this PR.
Regarding Triton autotuning:
Regarding tests:
There's a conv reorg going on in #3048; please take a look at that PR and try to follow a similar structure.
Adds a Triton conv2d library targeted at AMD RDNA GPUs, plus a
correctness + benchmark harness that compares against PyTorch / MIOpen.
Motivation
PyTorch on AMD goes through MIOpen, whose hand-tuned solvers cover some
dtype × layout × architecture combinations well and others poorly — bf16 in
particular falls back to direct/GEMM solvers on RDNA4 that are noticeably
slower at large channel counts. Most modern checkpoints (LLMs, diffusion VAEs)
ship in bf16, so the gap matters.
This PR takes the opposite approach: a single set of Triton kernels that runs
fp16 and bf16 through the same code path, supports NCHW and NHWC, and gets reasonable performance across
the full matrix without per-architecture hand tuning.
What's added
Library (`aiter/ops/triton/conv/`):
- `conv2d.py`: public API + shape-driven router
- `_launch.py`: grid setup + `_select_3x3_method` heuristic
- `_prepack.py`: LRU-cached weight/input repack (a rough sketch of this caching pattern follows the Docs list below)
- `_utils.py`: shape math, dtype/activation enums, tolerance model

Kernels (`aiter/ops/triton/_triton_kernels/conv/`), five families, selected by conditions such as `R == 1, S == 1` or `C ≥ 512`, `K ≥ 512` with enough output tiles.

Test/bench harness (`op_tests/triton_tests/conv/`):
- `cli.py`: `--test-mode {edge,random,stability,activations,models,all}`
- `suite.py`: correctness checking + bench accumulation + result tables
- `bench.py`: timing + `precompute_miopen_solvers` (subprocess + `MIOPEN_LOG_LEVEL=6` to label each PyTorch baseline row with the MIOpen solver it picked)
- `test_edge.py` / `test_fuzz.py` / `test_models.py`: shape sources
- `test_pytest.py`: parametrized over fp16/bf16 × nchw/nhwc
- `_registry.py`: single source of truth for kernel methods (used by CLI, suite, comparison tables, tolerance dispatch)

Bench shim (`op_tests/op_benchmarks/triton/bench_conv2d.py`): a convenience entry that injects `--benchmark --test-mode models`.

Docs:
- `aiter/ops/triton/conv/README.md`: quick start, headline results, constraints, reproducing instructions
- `aiter/ops/triton/conv/DESIGN.md`: architecture, per-kernel deep-dive, full Winograd F(4,3) derivation (G/Bᵀ/Aᵀ matrices, 361× amplification analysis, why Winograd is disabled for C < 4), the routing heuristic, memory layouts, numerical model, extension guide
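A rough sketch of the caching pattern described for `_prepack.py`; the cache key, capacity, and repack step here are assumptions for illustration, not the PR's actual implementation:

```python
# Illustrative LRU-cached weight repack; key layout, capacity, and the repack
# step are assumptions, not the PR's actual _prepack.py.
from collections import OrderedDict
import torch

_CACHE: "OrderedDict[tuple, torch.Tensor]" = OrderedDict()
_CAPACITY = 32

def prepack_weight(w: torch.Tensor, layout: str) -> torch.Tensor:
    # Key on storage pointer + metadata so the same weight tensor hits the
    # cache on every forward pass without re-packing.
    key = (w.data_ptr(), tuple(w.shape), w.dtype, layout)
    if key in _CACHE:
        _CACHE.move_to_end(key)       # refresh LRU order
        return _CACHE[key]
    packed = (w.contiguous(memory_format=torch.channels_last)
              if layout == "nhwc" else w.contiguous())
    _CACHE[key] = packed
    if len(_CACHE) > _CAPACITY:
        _CACHE.popitem(last=False)    # evict the least recently used entry
    return packed
```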
Performance
See `aiter/ops/triton/conv/README.md#headline-results` for the full chart set (resnet50 / SD3.5 VAE / FLUX.2 VAE × fp16/bf16 × nchw/nhwc × multiple batch sizes, on RDNA4).
Constraints
- `groups` must equal 1; depthwise / grouped convolutions are not yet implemented. The test harness skips grouped layers and prints a banner showing how many were skipped (so coverage % is visible).
- `padding_mode` must be `"zeros"`. The pad amount is unrestricted; only the pad value is constrained, so `"reflect"`, `"replicate"`, and `"circular"` are out of scope.
- Inputs must be `fp16` or `bf16`.
Testing
All run on ROCm 7.2 / PyTorch 2.9.1 / Triton 3.7 (commit 23f4e522d):
- `cli --test-mode all --layout both --dtype fp16`
- `cli --test-mode all --layout both --dtype bf16`
- `pytest test_pytest.py -k test_no_bias` (× fp16/bf16 × nchw/nhwc)
- `bench_conv2d --model-name resnet50 --num-layers 5 --layout both` (fp16, bf16)

Per-method correctness: each kernel family is exercised across 12 edge-case shapes, 200 random shapes, 4 fused activations (none/relu/relu6/gelu), and the real per-layer shapes captured by hooking ResNet-50 / SD3.5 VAE / FLUX.2 VAE forwards.
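The fp16/bf16 × nchw/nhwc matrix mentioned for `test_pytest.py` corresponds to a parametrization along these lines; the names and the reference check are illustrative and do not call the PR's kernels:

```python
# Illustrative parametrization over the fp16/bf16 × nchw/nhwc matrix; the test
# body is a placeholder rather than the PR's actual comparison code.
import pytest
import torch
import torch.nn.functional as F

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
@pytest.mark.parametrize("layout", ["nchw", "nhwc"])
def test_no_bias(dtype, layout):
    x = torch.randn(2, 64, 32, 32, device="cuda", dtype=dtype)
    w = torch.randn(64, 64, 3, 3, device="cuda", dtype=dtype)
    if layout == "nhwc":
        x = x.to(memory_format=torch.channels_last)
        w = w.to(memory_format=torch.channels_last)
    ref = F.conv2d(x, w, bias=None, padding=1)
    # The Triton output would be compared against `ref` here using the
    # dtype-aware tolerances from _utils.py.
    assert ref.shape == (2, 64, 32, 32)
```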
How to use