diff --git a/aiter/ops/triton/_triton_kernels/conv/__init__.py b/aiter/ops/triton/_triton_kernels/conv/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/aiter/ops/triton/_triton_kernels/conv/conv_1x1.py b/aiter/ops/triton/_triton_kernels/conv/conv_1x1.py
new file mode 100644
index 0000000000..d4ed10578f
--- /dev/null
+++ b/aiter/ops/triton/_triton_kernels/conv/conv_1x1.py
@@ -0,0 +1,163 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
+
+from __future__ import annotations
+import triton
+import triton.language as tl
+
+from .helpers import _tanh, AUTOTUNE_1x1_CONFIGS
+
+
+@triton.autotune(
+ configs=AUTOTUNE_1x1_CONFIGS,
+ key=["M_total", "K_out", "C"],
+ reset_to_zero=["Y"],
+ warmup=50,
+ rep=200,
+ cache_results=True,
+)
+@triton.jit
+def _conv2d_1x1_kernel(
+ X,
+ W,
+ BIAS,
+ Y,
+ N: tl.constexpr,
+ C: tl.constexpr,
+ H: tl.constexpr,
+ W_in: tl.constexpr,
+ K_out: tl.constexpr,
+ P: tl.constexpr,
+ Q: tl.constexpr,
+ stride_h: tl.constexpr,
+ stride_w: tl.constexpr,
+ pad_h: tl.constexpr,
+ pad_w: tl.constexpr,
+ M_total: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_SIZE_M: tl.constexpr,
+ HAS_BIAS: tl.constexpr,
+ ACT_TYPE: tl.constexpr,
+ LAYOUT: tl.constexpr,
+):
+ """
+ Specialized 1x1 convolution kernel.
+ - No R*S loop (R=S=1)
+ - Direct channel reduction
+ - Simplified pointer arithmetic
+ LAYOUT: 0=NCHW, 1=NHWC
+ """
+ # W is always [K_out, C] contiguous
+ stride_w_k: tl.constexpr = C
+ stride_w_c: tl.constexpr = 1
+ if LAYOUT == 0:
+ # NCHW: X[N, C, H, W_in], Y[N, K_out, P, Q]
+ stride_x_n: tl.constexpr = C * H * W_in
+ stride_x_c: tl.constexpr = H * W_in
+ stride_x_h: tl.constexpr = W_in
+ stride_x_w: tl.constexpr = 1
+ stride_y_n: tl.constexpr = K_out * P * Q
+ stride_y_k: tl.constexpr = P * Q
+ stride_y_p: tl.constexpr = Q
+ stride_y_q: tl.constexpr = 1
+ else:
+ # NHWC: X[N, H, W_in, C], Y[N, P, Q, K_out]
+ stride_x_n: tl.constexpr = H * W_in * C
+ stride_x_c: tl.constexpr = 1
+ stride_x_h: tl.constexpr = W_in * C
+ stride_x_w: tl.constexpr = C
+ stride_y_n: tl.constexpr = P * Q * K_out
+ stride_y_k: tl.constexpr = 1
+ stride_y_p: tl.constexpr = Q * K_out
+ stride_y_q: tl.constexpr = K_out
+
+ pid = tl.program_id(axis=0)
+
+    # GEMM view: M = N * P * Q (flattened output positions), N dim = K_out (output channels)
+ num_pid_m = tl.cdiv(M_total, BLOCK_M)
+ num_pid_n = tl.cdiv(K_out, BLOCK_N)
+
+    # Grouped L2-cache swizzle: programs are grouped along M so a group's
+    # M-tiles sweep every K_out tile while their X rows stay resident in L2
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * GROUP_SIZE_M
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+ pid_n = (pid % num_pid_in_group) // group_size_m
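+    # Worked example of the swizzle (illustrative): with num_pid_m=4,
+    # num_pid_n=2, GROUP_SIZE_M=2, pids 0..3 map to
+    # (pid_m, pid_n) = (0,0), (1,0), (0,1), (1,1): the two M-tiles of the
+    # group visit every K_out tile before the next group starts.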
+
+ if pid_m >= num_pid_m:
+ return
+
+ # Compute output tile indices
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ offs_k = tl.arange(0, BLOCK_K)
+
+ # Decode (n, p, q) from linear index
+ n_idx = offs_m // (P * Q)
+ pq = offs_m % (P * Q)
+ p_idx = pq // Q
+ q_idx = pq % Q
+
+ # Valid output mask
+ m_mask = offs_m < M_total
+ n_mask = offs_n < K_out
+
+ ih = p_idx * stride_h - pad_h
+ iw = q_idx * stride_w - pad_w
+
+ # Check spatial bounds
+ spatial_valid = (ih >= 0) & (ih < H) & (iw >= 0) & (iw < W_in) & (n_idx < N)
+
+ # Base pointers for this output tile
+ x_base = X + n_idx * stride_x_n + ih * stride_x_h + iw * stride_x_w # [BLOCK_M]
+ w_base = W + offs_n * stride_w_k # [BLOCK_N]
+
+ # Accumulator
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ # Channel reduction loop
+ for k0 in range(0, C, BLOCK_K):
+ k_offs = k0 + offs_k
+ k_mask = k_offs < C
+
+ # Load input: X[n, c, ih, iw] -> shape [BLOCK_M, BLOCK_K]
+ x_ptrs = x_base[:, None] + k_offs[None, :] * stride_x_c
+ x_mask = spatial_valid[:, None] & k_mask[None, :]
+ x_tile = tl.load(x_ptrs, mask=x_mask, other=0.0)
+
+ # Load weight: W[k_out, c] -> shape [BLOCK_K, BLOCK_N]
+ w_ptrs = w_base[None, :] + k_offs[:, None] * stride_w_c
+ w_mask = k_mask[:, None] & n_mask[None, :]
+ w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
+
+ # Accumulate: [BLOCK_M, BLOCK_K] @ [BLOCK_K, BLOCK_N]
+ acc += tl.dot(x_tile, w_tile, out_dtype=tl.float32)
+
+ # Bias
+ if HAS_BIAS:
+ b = tl.load(BIAS + offs_n, mask=n_mask, other=0.0)
+ acc += b[None, :]
+
+ # Activation
+ if ACT_TYPE == 1: # ReLU
+ acc = tl.maximum(acc, 0)
+ elif ACT_TYPE == 2: # ReLU6
+ acc = tl.minimum(tl.maximum(acc, 0), 6)
+ elif ACT_TYPE == 3: # GELU
+ acc = (
+ 0.5 * acc * (1.0 + _tanh(0.7978845608 * (acc + 0.044715 * acc * acc * acc)))
+ )
+
+ # Store output: Y[n, k, p, q]
+ y_ptrs = (
+ Y
+ + n_idx[:, None] * stride_y_n
+ + offs_n[None, :] * stride_y_k
+ + p_idx[:, None] * stride_y_p
+ + q_idx[:, None] * stride_y_q
+ )
+ y_mask = m_mask[:, None] & n_mask[None, :]
+ tl.store(y_ptrs, acc, mask=y_mask)
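+
+
+# Illustrative host-side launch (a sketch with assumed tensor names, not part
+# of this module's API). The kernel is an implicit GEMM over
+# M_total = N * P * Q rows; block sizes come from the autotuner:
+#
+#     M_total = N * P * Q
+#     grid = lambda META: (
+#         triton.cdiv(M_total, META["BLOCK_M"])
+#         * triton.cdiv(K_out, META["BLOCK_N"]),
+#     )
+#     _conv2d_1x1_kernel[grid](
+#         x, w.reshape(K_out, C), bias, y,
+#         N=N, C=C, H=H, W_in=W, K_out=K_out, P=P, Q=Q,
+#         stride_h=1, stride_w=1, pad_h=0, pad_w=0,
+#         M_total=M_total, HAS_BIAS=bias is not None, ACT_TYPE=0, LAYOUT=1,
+#     )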
diff --git a/aiter/ops/triton/_triton_kernels/conv/conv_3x3.py b/aiter/ops/triton/_triton_kernels/conv/conv_3x3.py
new file mode 100644
index 0000000000..650028c0c5
--- /dev/null
+++ b/aiter/ops/triton/_triton_kernels/conv/conv_3x3.py
@@ -0,0 +1,312 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
+
+from __future__ import annotations
+import triton
+import triton.language as tl
+from .helpers import (
+ _tanh,
+ AUTOTUNE_3x3_NHWC_CONFIGS,
+ AUTOTUNE_3x3_CBLOCKED_CONFIGS,
+)
+
+
+@triton.autotune(
+ configs=AUTOTUNE_3x3_NHWC_CONFIGS,
+ key=["M_total", "K_out", "C_pad"],
+ reset_to_zero=["Y"],
+ warmup=50,
+ rep=200,
+ cache_results=True,
+)
+@triton.jit
+def _conv2d_3x3_nhwc_kernel(
+ X,
+ W3,
+ BIAS,
+ Y,
+ N: tl.constexpr,
+ C: tl.constexpr,
+ H: tl.constexpr,
+ W_in: tl.constexpr,
+ K_out: tl.constexpr,
+ P: tl.constexpr,
+ Q: tl.constexpr,
+ C_pad: tl.constexpr,
+ stride_h: tl.constexpr,
+ stride_w: tl.constexpr,
+ pad_h: tl.constexpr,
+ pad_w: tl.constexpr,
+ dil_h: tl.constexpr,
+ dil_w: tl.constexpr,
+ M_total: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_C: tl.constexpr,
+ GROUP_SIZE_M: tl.constexpr,
+ HAS_BIAS: tl.constexpr,
+ ACT_TYPE: tl.constexpr,
+):
+ """Specialized 3x3 NHWC kernel: stride_x_c=1 and stride_y_k=1 hardcoded
+ so the compiler can emit coalesced vector loads/stores."""
+ # X layout: [N, H, W_in, C] contiguous NHWC (stride_x_c=1 hardcoded in load logic)
+ stride_x_w: tl.constexpr = C
+ stride_x_h: tl.constexpr = W_in * C
+ stride_x_n: tl.constexpr = H * W_in * C
+ # W3 layout: [K_out, 9, C_pad] contiguous
+ stride_w3_c: tl.constexpr = 1
+ stride_w3_rs: tl.constexpr = C_pad
+ stride_w3_kout: tl.constexpr = 9 * C_pad
+ # Y layout: [N, P, Q, K_out] contiguous NHWC (stride_y_k=1 hardcoded in store logic)
+ stride_y_q: tl.constexpr = K_out
+ stride_y_p: tl.constexpr = Q * K_out
+ stride_y_n: tl.constexpr = P * Q * K_out
+
+ pid = tl.program_id(axis=0)
+
+ num_pid_m = tl.cdiv(M_total, BLOCK_M)
+ num_pid_n = tl.cdiv(K_out, BLOCK_N)
+
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * GROUP_SIZE_M
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+ pid_n = (pid % num_pid_in_group) // group_size_m
+
+ if pid_m >= num_pid_m:
+ return
+
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ kout_mask = offs_n < K_out
+
+ # Decode (n, p, q) from linear index
+ n_idx = offs_m[:, None] // (P * Q)
+ pq = offs_m[:, None] % (P * Q)
+ p_idx = pq // Q
+ q_idx = pq % Q
+ n_valid = n_idx < N
+
+ # Precompute base positions
+ base_oh = p_idx * stride_h - pad_h
+ base_ow = q_idx * stride_w - pad_w
+ stride_dh = dil_h * stride_x_h
+ stride_dw = dil_w * stride_x_w
+ x_base = X + n_idx * stride_x_n + base_oh * stride_x_h + base_ow * stride_x_w
+
+ # Weight base: W3[K_out, 9, C_pad]
+ w_base = W3 + offs_n[None, :] * stride_w3_kout
+
+ Y_ptrs = (
+ Y
+ + n_idx * stride_y_n
+ + offs_n[None, :]
+ + p_idx * stride_y_p
+ + q_idx * stride_y_q
+ )
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ offs_c = tl.arange(0, BLOCK_C)
+
+ for r in tl.static_range(3):
+ oh = base_oh + r * dil_h
+ valid_oh = n_valid & (oh >= 0) & (oh < H)
+ x_off_r = r * stride_dh
+ for s in tl.static_range(3):
+ rs_idx = r * 3 + s
+ ow = base_ow + s * dil_w
+ valid = valid_oh & (ow >= 0) & (ow < W_in)
+
+ for c0 in range(0, C_pad, BLOCK_C):
+ c_offs = c0 + offs_c
+ c_mask = c_offs < C
+
+ x_ptrs = x_base + c_offs[None, :] + x_off_r + s * stride_dw
+ w_ptrs = w_base + rs_idx * stride_w3_rs + c_offs[:, None] * stride_w3_c
+
+ x_tile = tl.load(x_ptrs, mask=valid & c_mask[None, :], other=0.0)
+ w_tile = tl.load(
+ w_ptrs, mask=c_mask[:, None] & kout_mask[None, :], other=0.0
+ )
+ acc += tl.dot(x_tile, w_tile, out_dtype=tl.float32)
+
+ # Epilogue: bias + activation + store
+ if HAS_BIAS:
+ b = tl.load(BIAS + offs_n, mask=offs_n < K_out, other=0.0)
+ acc += b[None, :]
+
+ if ACT_TYPE == 1:
+ acc = tl.maximum(acc, 0)
+ elif ACT_TYPE == 2:
+ acc = tl.minimum(tl.maximum(acc, 0), 6)
+ elif ACT_TYPE == 3:
+ acc = (
+ 0.5 * acc * (1.0 + _tanh(0.7978845608 * (acc + 0.044715 * acc * acc * acc)))
+ )
+
+ tl.store(
+ Y_ptrs,
+ acc,
+ mask=(n_valid & (p_idx < P) & (q_idx < Q) & kout_mask[None, :]),
+ )
+
+
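+# Note on the W3 layout shared by both 3x3 kernels (the host-side packing is
+# an assumption, sketched here for clarity): W3 is [K_out, 9, C_pad] with the
+# 3x3 taps flattened as rs = r * 3 + s and channels zero-padded to C_pad. From
+# a PyTorch weight w of shape [K_out, C, 3, 3], one way to build it:
+#
+#     w3 = torch.zeros(K_out, 9, C_pad, dtype=w.dtype, device=w.device)
+#     w3[:, :, :C] = w.permute(0, 2, 3, 1).reshape(K_out, 9, C)
+#
+# The zero-padded channels contribute nothing to the tl.dot accumulations.
+
+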
+@triton.autotune(
+ configs=AUTOTUNE_3x3_CBLOCKED_CONFIGS,
+ key=["M_total", "K_out", "C_pad"],
+ reset_to_zero=["Y"],
+ warmup=50,
+ rep=200,
+ cache_results=True,
+)
+@triton.jit
+def _conv2d_3x3_cblocked_kernel(
+ X,
+ W3,
+ BIAS,
+ Y,
+ N: tl.constexpr,
+ C: tl.constexpr,
+ H: tl.constexpr,
+ W_in: tl.constexpr,
+ K_out: tl.constexpr,
+ P: tl.constexpr,
+ Q: tl.constexpr,
+ C_pad: tl.constexpr,
+ Cb: tl.constexpr,
+ stride_h: tl.constexpr,
+ stride_w: tl.constexpr,
+ pad_h: tl.constexpr,
+ pad_w: tl.constexpr,
+ dil_h: tl.constexpr,
+ dil_w: tl.constexpr,
+ M_total: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_C: tl.constexpr,
+ GROUP_SIZE_M: tl.constexpr,
+ HAS_BIAS: tl.constexpr,
+ ACT_TYPE: tl.constexpr,
+):
+ """Specialized 3x3 kernel for channel-blocked [N, C_blocks, H, W, Cb] input.
+ stride_c_local=1 is hardcoded so the compiler emits coalesced vector loads."""
+ # X layout: [N, C_blocks, H, W_in, Cb] where C_blocks = C_pad // Cb
+ stride_x_w: tl.constexpr = Cb
+ stride_x_h: tl.constexpr = W_in * Cb
+ stride_x_cblock: tl.constexpr = H * W_in * Cb
+ stride_x_n: tl.constexpr = (C_pad // Cb) * H * W_in * Cb
+ # W3 layout: [K_out, 9, C_pad] contiguous
+ stride_w3_c: tl.constexpr = 1
+ stride_w3_rs: tl.constexpr = C_pad
+ stride_w3_kout: tl.constexpr = 9 * C_pad
+ # Y layout: [N, K_out, P, Q] contiguous NCHW
+ stride_y_q: tl.constexpr = 1
+ stride_y_p: tl.constexpr = Q
+ stride_y_k: tl.constexpr = P * Q
+ stride_y_n: tl.constexpr = K_out * P * Q
+
+ pid = tl.program_id(axis=0)
+
+ num_pid_m = tl.cdiv(M_total, BLOCK_M)
+ num_pid_n = tl.cdiv(K_out, BLOCK_N)
+
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * GROUP_SIZE_M
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+ pid_n = (pid % num_pid_in_group) // group_size_m
+
+ if pid_m >= num_pid_m:
+ return
+
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ kout_mask = offs_n < K_out
+
+ # Decode (n, p, q) from linear index
+ n_idx = offs_m[:, None] // (P * Q)
+ pq = offs_m[:, None] % (P * Q)
+ p_idx = pq // Q
+ q_idx = pq % Q
+ n_valid = n_idx < N
+
+ # Precompute base positions
+ base_oh = p_idx * stride_h - pad_h
+ base_ow = q_idx * stride_w - pad_w
+ stride_dh = dil_h * stride_x_h
+ stride_dw = dil_w * stride_x_w
+
+ # x_base for channel-blocked layout: X[n, cblock, h, w, c_local]
+ # base pointer accounts for n, h, w
+ x_base = X + n_idx * stride_x_n + base_oh * stride_x_h + base_ow * stride_x_w
+
+ # Weight base: W3[K_out, 9, C_pad]
+ w_base = W3 + offs_n[None, :] * stride_w3_kout
+
+ # Y pointers
+ Y_ptrs = (
+ Y
+ + n_idx * stride_y_n
+ + offs_n[None, :] * stride_y_k
+ + p_idx * stride_y_p
+ + q_idx * stride_y_q
+ )
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ offs_c = tl.arange(0, BLOCK_C)
+
+ for r in tl.static_range(3):
+ oh = base_oh + r * dil_h
+ valid_oh = n_valid & (oh >= 0) & (oh < H)
+ x_off_r = r * stride_dh
+ for s in tl.static_range(3):
+ rs_idx = r * 3 + s
+ ow = base_ow + s * dil_w
+ valid = valid_oh & (ow >= 0) & (ow < W_in)
+
+ for c0 in range(0, C_pad, BLOCK_C):
+ c_offs = c0 + offs_c
+ c_mask = c_offs < C
+
+ # Compute cblock index and local offset within block
+ cblock_idx = c_offs // Cb
+ c_local = c_offs % Cb
+
+ x_ptrs = (
+ x_base
+ + cblock_idx[None, :] * stride_x_cblock
+ + c_local[None, :]
+ + x_off_r
+ + s * stride_dw
+ )
+ w_ptrs = w_base + rs_idx * stride_w3_rs + c_offs[:, None] * stride_w3_c
+
+ x_tile = tl.load(x_ptrs, mask=valid & c_mask[None, :], other=0.0)
+ w_tile = tl.load(
+ w_ptrs, mask=c_mask[:, None] & kout_mask[None, :], other=0.0
+ )
+ acc += tl.dot(x_tile, w_tile, out_dtype=tl.float32)
+
+ # Epilogue: bias + activation + store
+ if HAS_BIAS:
+ b = tl.load(BIAS + offs_n, mask=offs_n < K_out, other=0.0)
+ acc += b[None, :]
+
+ if ACT_TYPE == 1:
+ acc = tl.maximum(acc, 0)
+ elif ACT_TYPE == 2:
+ acc = tl.minimum(tl.maximum(acc, 0), 6)
+ elif ACT_TYPE == 3:
+ acc = (
+ 0.5 * acc * (1.0 + _tanh(0.7978845608 * (acc + 0.044715 * acc * acc * acc)))
+ )
+
+ tl.store(
+ Y_ptrs,
+ acc,
+ mask=(n_valid & (p_idx < P) & (q_idx < Q) & kout_mask[None, :]),
+ )
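+
+
+# Illustrative host-side packing for the channel-blocked input (an assumed
+# helper, not part of this module): starting from NCHW x of shape [N, C, H, W],
+# zero-pad channels to C_pad and tile them into blocks of Cb:
+#
+#     xc = torch.zeros(N, C_pad, H, W, dtype=x.dtype, device=x.device)
+#     xc[:, :C] = x
+#     xb = xc.view(N, C_pad // Cb, Cb, H, W).permute(0, 1, 3, 4, 2).contiguous()
+#
+# Channel c then lives at block c // Cb, local offset c % Cb, matching the
+# cblock_idx / c_local addressing used in the kernel above.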
diff --git a/aiter/ops/triton/_triton_kernels/conv/conv_3x3_winograd_f4x3.py b/aiter/ops/triton/_triton_kernels/conv/conv_3x3_winograd_f4x3.py
new file mode 100644
index 0000000000..728d4bad59
--- /dev/null
+++ b/aiter/ops/triton/_triton_kernels/conv/conv_3x3_winograd_f4x3.py
@@ -0,0 +1,1480 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
+
+from __future__ import annotations
+import triton
+import triton.language as tl
+from .helpers import (
+ _tanh,
+ AUTOTUNE_WINO4_INPUT_CONFIGS,
+ AUTOTUNE_WINO_GEMM_CONFIGS,
+ AUTOTUNE_WINO4_OUTPUT_CONFIGS,
+ AUTOTUNE_FUSED_F4X3_CONFIGS,
+)
+
+
+@triton.autotune(
+ configs=AUTOTUNE_WINO4_INPUT_CONFIGS,
+ key=["T", "C_pad"],
+ cache_results=True,
+)
+@triton.jit
+def _winograd_f4x3_input_transform_kernel(
+ X,
+ V,
+ N: tl.constexpr,
+ C: tl.constexpr,
+ C_pad: tl.constexpr,
+ H: tl.constexpr,
+ W_in: tl.constexpr,
+ tile_H: tl.constexpr,
+ tile_W: tl.constexpr,
+ T: tl.constexpr,
+ pad_h: tl.constexpr,
+ pad_w: tl.constexpr,
+ BLOCK_C: tl.constexpr,
+ INPUT_DTYPE: tl.constexpr = tl.float16,
+ LAYOUT: tl.constexpr = 0,
+):
+ # X layout: LAYOUT=0 NCHW, LAYOUT=1 NHWC
+ if LAYOUT == 0:
+ stride_x_w: tl.constexpr = 1
+ stride_x_h: tl.constexpr = W_in
+ stride_x_c: tl.constexpr = H * W_in
+ stride_x_n: tl.constexpr = C * H * W_in
+ else:
+ stride_x_w: tl.constexpr = C
+ stride_x_h: tl.constexpr = W_in * C
+ stride_x_c: tl.constexpr = 1
+ stride_x_n: tl.constexpr = H * W_in * C
+ # V layout: [36, T, C_pad] contiguous
+ stride_v_c: tl.constexpr = 1
+ stride_v_tile: tl.constexpr = C_pad
+ stride_v_alpha: tl.constexpr = T * C_pad
+
+ tile_idx = tl.program_id(0)
+ c_block = tl.program_id(1)
+
+ n = tile_idx // (tile_H * tile_W)
+ rem = tile_idx % (tile_H * tile_W)
+ th = rem // tile_W
+ tw = rem % tile_W
+
+ h_start = th * 4 - pad_h
+ w_start = tw * 4 - pad_w
+
+ offs_c = c_block * BLOCK_C + tl.arange(0, BLOCK_C)
+ c_mask = offs_c < C
+
+ base = X + n * stride_x_n + offs_c * stride_x_c
+ n_valid = n < N
+
+ # Load 6x6 patch
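+    # (kept as 36 named registers: Triton tiles cannot be indexed element-wise,
+    # so the 6x6 patch is fully unrolled)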
+ d00 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d01 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d02 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d03 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d04 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d05 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d10 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d11 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d12 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d13 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d14 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d15 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d20 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d21 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d22 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d23 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d24 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d25 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d30 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d31 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d32 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d33 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d34 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d35 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d40 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d41 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d42 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d43 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d44 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d45 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d50 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d51 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d52 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d53 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d54 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d55 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+
+ for r in tl.static_range(6):
+ h = h_start + r
+ h_valid = n_valid & (h >= 0) & (h < H)
+ for s in tl.static_range(6):
+ w = w_start + s
+ valid = h_valid & (w >= 0) & (w < W_in)
+ ptr = base + h * stride_x_h + w * stride_x_w
+ val = tl.load(ptr, mask=valid & c_mask, other=0.0)
+ if r == 0:
+ if s == 0:
+ d00 = val
+ elif s == 1:
+ d01 = val
+ elif s == 2:
+ d02 = val
+ elif s == 3:
+ d03 = val
+ elif s == 4:
+ d04 = val
+ else:
+ d05 = val
+ elif r == 1:
+ if s == 0:
+ d10 = val
+ elif s == 1:
+ d11 = val
+ elif s == 2:
+ d12 = val
+ elif s == 3:
+ d13 = val
+ elif s == 4:
+ d14 = val
+ else:
+ d15 = val
+ elif r == 2:
+ if s == 0:
+ d20 = val
+ elif s == 1:
+ d21 = val
+ elif s == 2:
+ d22 = val
+ elif s == 3:
+ d23 = val
+ elif s == 4:
+ d24 = val
+ else:
+ d25 = val
+ elif r == 3:
+ if s == 0:
+ d30 = val
+ elif s == 1:
+ d31 = val
+ elif s == 2:
+ d32 = val
+ elif s == 3:
+ d33 = val
+ elif s == 4:
+ d34 = val
+ else:
+ d35 = val
+ elif r == 4:
+ if s == 0:
+ d40 = val
+ elif s == 1:
+ d41 = val
+ elif s == 2:
+ d42 = val
+ elif s == 3:
+ d43 = val
+ elif s == 4:
+ d44 = val
+ else:
+ d45 = val
+ else:
+ if s == 0:
+ d50 = val
+ elif s == 1:
+ d51 = val
+ elif s == 2:
+ d52 = val
+ elif s == 3:
+ d53 = val
+ elif s == 4:
+ d54 = val
+ else:
+ d55 = val
+
+ # B^T column transform (6x6):
+ # B^T = [[ 4, 0, -5, 0, 1, 0],
+ # [ 0, -4, -4, 1, 1, 0],
+ # [ 0, 4, -4, -1, 1, 0],
+ # [ 0, -2, -1, 2, 1, 0],
+ # [ 0, 2, -1, -2, 1, 0],
+ # [ 0, 4, 0, -5, 0, 1]]
+ # Apply to each column (compute t[row][col] from d[row][col])
+    # Do the transform arithmetic in float32: the +/-4 and +/-5 multipliers can
+    # overflow or lose precision in fp16
+ d00f = d00.to(tl.float32)
+ d01f = d01.to(tl.float32)
+ d02f = d02.to(tl.float32)
+ d03f = d03.to(tl.float32)
+ d04f = d04.to(tl.float32)
+ d05f = d05.to(tl.float32)
+ d10f = d10.to(tl.float32)
+ d11f = d11.to(tl.float32)
+ d12f = d12.to(tl.float32)
+ d13f = d13.to(tl.float32)
+ d14f = d14.to(tl.float32)
+ d15f = d15.to(tl.float32)
+ d20f = d20.to(tl.float32)
+ d21f = d21.to(tl.float32)
+ d22f = d22.to(tl.float32)
+ d23f = d23.to(tl.float32)
+ d24f = d24.to(tl.float32)
+ d25f = d25.to(tl.float32)
+ d30f = d30.to(tl.float32)
+ d31f = d31.to(tl.float32)
+ d32f = d32.to(tl.float32)
+ d33f = d33.to(tl.float32)
+ d34f = d34.to(tl.float32)
+ d35f = d35.to(tl.float32)
+ d40f = d40.to(tl.float32)
+ d41f = d41.to(tl.float32)
+ d42f = d42.to(tl.float32)
+ d43f = d43.to(tl.float32)
+ d44f = d44.to(tl.float32)
+ d45f = d45.to(tl.float32)
+ d50f = d50.to(tl.float32)
+ d51f = d51.to(tl.float32)
+ d52f = d52.to(tl.float32)
+ d53f = d53.to(tl.float32)
+ d54f = d54.to(tl.float32)
+ d55f = d55.to(tl.float32)
+
+ # Column transform: for each column s, t[row][s] = B^T @ d[:,s]
+ t00 = 4 * d00f - 5 * d20f + d40f
+ t01 = 4 * d01f - 5 * d21f + d41f
+ t02 = 4 * d02f - 5 * d22f + d42f
+ t03 = 4 * d03f - 5 * d23f + d43f
+ t04 = 4 * d04f - 5 * d24f + d44f
+ t05 = 4 * d05f - 5 * d25f + d45f
+
+ t10 = -4 * d10f - 4 * d20f + d30f + d40f
+ t11 = -4 * d11f - 4 * d21f + d31f + d41f
+ t12 = -4 * d12f - 4 * d22f + d32f + d42f
+ t13 = -4 * d13f - 4 * d23f + d33f + d43f
+ t14 = -4 * d14f - 4 * d24f + d34f + d44f
+ t15 = -4 * d15f - 4 * d25f + d35f + d45f
+
+ t20 = 4 * d10f - 4 * d20f - d30f + d40f
+ t21 = 4 * d11f - 4 * d21f - d31f + d41f
+ t22 = 4 * d12f - 4 * d22f - d32f + d42f
+ t23 = 4 * d13f - 4 * d23f - d33f + d43f
+ t24 = 4 * d14f - 4 * d24f - d34f + d44f
+ t25 = 4 * d15f - 4 * d25f - d35f + d45f
+
+ t30 = -2 * d10f - d20f + 2 * d30f + d40f
+ t31 = -2 * d11f - d21f + 2 * d31f + d41f
+ t32 = -2 * d12f - d22f + 2 * d32f + d42f
+ t33 = -2 * d13f - d23f + 2 * d33f + d43f
+ t34 = -2 * d14f - d24f + 2 * d34f + d44f
+ t35 = -2 * d15f - d25f + 2 * d35f + d45f
+
+ t40 = 2 * d10f - d20f - 2 * d30f + d40f
+ t41 = 2 * d11f - d21f - 2 * d31f + d41f
+ t42 = 2 * d12f - d22f - 2 * d32f + d42f
+ t43 = 2 * d13f - d23f - 2 * d33f + d43f
+ t44 = 2 * d14f - d24f - 2 * d34f + d44f
+ t45 = 2 * d15f - d25f - 2 * d35f + d45f
+
+ t50 = 4 * d10f - 5 * d30f + d50f
+ t51 = 4 * d11f - 5 * d31f + d51f
+ t52 = 4 * d12f - 5 * d32f + d52f
+ t53 = 4 * d13f - 5 * d33f + d53f
+ t54 = 4 * d14f - 5 * d34f + d54f
+ t55 = 4 * d15f - 5 * d35f + d55f
+
+ # Row transform: v[r][col] = B^T applied to row t[r][:]
+ v00 = 4 * t00 - 5 * t02 + t04
+ v01 = -4 * t01 - 4 * t02 + t03 + t04
+ v02 = 4 * t01 - 4 * t02 - t03 + t04
+ v03 = -2 * t01 - t02 + 2 * t03 + t04
+ v04 = 2 * t01 - t02 - 2 * t03 + t04
+ v05 = 4 * t01 - 5 * t03 + t05
+
+ v10 = 4 * t10 - 5 * t12 + t14
+ v11 = -4 * t11 - 4 * t12 + t13 + t14
+ v12 = 4 * t11 - 4 * t12 - t13 + t14
+ v13 = -2 * t11 - t12 + 2 * t13 + t14
+ v14 = 2 * t11 - t12 - 2 * t13 + t14
+ v15 = 4 * t11 - 5 * t13 + t15
+
+ v20 = 4 * t20 - 5 * t22 + t24
+ v21 = -4 * t21 - 4 * t22 + t23 + t24
+ v22 = 4 * t21 - 4 * t22 - t23 + t24
+ v23 = -2 * t21 - t22 + 2 * t23 + t24
+ v24 = 2 * t21 - t22 - 2 * t23 + t24
+ v25 = 4 * t21 - 5 * t23 + t25
+
+ v30 = 4 * t30 - 5 * t32 + t34
+ v31 = -4 * t31 - 4 * t32 + t33 + t34
+ v32 = 4 * t31 - 4 * t32 - t33 + t34
+ v33 = -2 * t31 - t32 + 2 * t33 + t34
+ v34 = 2 * t31 - t32 - 2 * t33 + t34
+ v35 = 4 * t31 - 5 * t33 + t35
+
+ v40 = 4 * t40 - 5 * t42 + t44
+ v41 = -4 * t41 - 4 * t42 + t43 + t44
+ v42 = 4 * t41 - 4 * t42 - t43 + t44
+ v43 = -2 * t41 - t42 + 2 * t43 + t44
+ v44 = 2 * t41 - t42 - 2 * t43 + t44
+ v45 = 4 * t41 - 5 * t43 + t45
+
+ v50 = 4 * t50 - 5 * t52 + t54
+ v51 = -4 * t51 - 4 * t52 + t53 + t54
+ v52 = 4 * t51 - 4 * t52 - t53 + t54
+ v53 = -2 * t51 - t52 + 2 * t53 + t54
+ v54 = 2 * t51 - t52 - 2 * t53 + t54
+ v55 = 4 * t51 - 5 * t53 + t55
+
+ v_base = V + tile_idx * stride_v_tile + offs_c * stride_v_c
+ c_store_mask = offs_c < C_pad
+
+ tl.store(v_base + 0 * stride_v_alpha, v00, mask=c_store_mask)
+ tl.store(v_base + 1 * stride_v_alpha, v01, mask=c_store_mask)
+ tl.store(v_base + 2 * stride_v_alpha, v02, mask=c_store_mask)
+ tl.store(v_base + 3 * stride_v_alpha, v03, mask=c_store_mask)
+ tl.store(v_base + 4 * stride_v_alpha, v04, mask=c_store_mask)
+ tl.store(v_base + 5 * stride_v_alpha, v05, mask=c_store_mask)
+ tl.store(v_base + 6 * stride_v_alpha, v10, mask=c_store_mask)
+ tl.store(v_base + 7 * stride_v_alpha, v11, mask=c_store_mask)
+ tl.store(v_base + 8 * stride_v_alpha, v12, mask=c_store_mask)
+ tl.store(v_base + 9 * stride_v_alpha, v13, mask=c_store_mask)
+ tl.store(v_base + 10 * stride_v_alpha, v14, mask=c_store_mask)
+ tl.store(v_base + 11 * stride_v_alpha, v15, mask=c_store_mask)
+ tl.store(v_base + 12 * stride_v_alpha, v20, mask=c_store_mask)
+ tl.store(v_base + 13 * stride_v_alpha, v21, mask=c_store_mask)
+ tl.store(v_base + 14 * stride_v_alpha, v22, mask=c_store_mask)
+ tl.store(v_base + 15 * stride_v_alpha, v23, mask=c_store_mask)
+ tl.store(v_base + 16 * stride_v_alpha, v24, mask=c_store_mask)
+ tl.store(v_base + 17 * stride_v_alpha, v25, mask=c_store_mask)
+ tl.store(v_base + 18 * stride_v_alpha, v30, mask=c_store_mask)
+ tl.store(v_base + 19 * stride_v_alpha, v31, mask=c_store_mask)
+ tl.store(v_base + 20 * stride_v_alpha, v32, mask=c_store_mask)
+ tl.store(v_base + 21 * stride_v_alpha, v33, mask=c_store_mask)
+ tl.store(v_base + 22 * stride_v_alpha, v34, mask=c_store_mask)
+ tl.store(v_base + 23 * stride_v_alpha, v35, mask=c_store_mask)
+ tl.store(v_base + 24 * stride_v_alpha, v40, mask=c_store_mask)
+ tl.store(v_base + 25 * stride_v_alpha, v41, mask=c_store_mask)
+ tl.store(v_base + 26 * stride_v_alpha, v42, mask=c_store_mask)
+ tl.store(v_base + 27 * stride_v_alpha, v43, mask=c_store_mask)
+ tl.store(v_base + 28 * stride_v_alpha, v44, mask=c_store_mask)
+ tl.store(v_base + 29 * stride_v_alpha, v45, mask=c_store_mask)
+ tl.store(v_base + 30 * stride_v_alpha, v50, mask=c_store_mask)
+ tl.store(v_base + 31 * stride_v_alpha, v51, mask=c_store_mask)
+ tl.store(v_base + 32 * stride_v_alpha, v52, mask=c_store_mask)
+ tl.store(v_base + 33 * stride_v_alpha, v53, mask=c_store_mask)
+ tl.store(v_base + 34 * stride_v_alpha, v54, mask=c_store_mask)
+ tl.store(v_base + 35 * stride_v_alpha, v55, mask=c_store_mask)
+
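+# Reference for the transform above (a NumPy sketch, illustrative only): the
+# element-wise v00..v55 equal B^T @ d @ B per channel:
+#
+#     BT = np.array([[4, 0, -5, 0, 1, 0],
+#                    [0, -4, -4, 1, 1, 0],
+#                    [0, 4, -4, -1, 1, 0],
+#                    [0, -2, -1, 2, 1, 0],
+#                    [0, 2, -1, -2, 1, 0],
+#                    [0, 4, 0, -5, 0, 1]], dtype=np.float32)
+#     v = BT @ d @ BT.T  # d is one 6x6 input patch
+#
+# Launch grid (assumed): (N * tile_H * tile_W, cdiv(C_pad, BLOCK_C)).
+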
+
+@triton.autotune(
+ configs=AUTOTUNE_WINO4_INPUT_CONFIGS,
+ key=["T", "C_pad"],
+ cache_results=True,
+)
+@triton.jit
+def _winograd_f4x3_cblocked_input_transform_kernel(
+ X,
+ V,
+ N: tl.constexpr,
+ C: tl.constexpr,
+ C_pad: tl.constexpr,
+ H: tl.constexpr,
+ W_in: tl.constexpr,
+ tile_H: tl.constexpr,
+ tile_W: tl.constexpr,
+ T: tl.constexpr,
+ pad_h: tl.constexpr,
+ pad_w: tl.constexpr,
+ Cb: tl.constexpr,
+ BLOCK_C: tl.constexpr,
+ INPUT_DTYPE: tl.constexpr = tl.float16,
+):
+ # X layout: [N, C_blocks, H, W_in, Cb] where C_blocks = C_pad // Cb
+ stride_x_w: tl.constexpr = Cb
+ stride_x_h: tl.constexpr = W_in * Cb
+ stride_x_cblock: tl.constexpr = H * W_in * Cb
+ stride_x_n: tl.constexpr = (C_pad // Cb) * H * W_in * Cb
+ # V layout: [36, T, C_pad] contiguous
+ stride_v_c: tl.constexpr = 1
+ stride_v_tile: tl.constexpr = C_pad
+ stride_v_alpha: tl.constexpr = T * C_pad
+
+ tile_idx = tl.program_id(0)
+ c_block = tl.program_id(1)
+
+ n = tile_idx // (tile_H * tile_W)
+ rem = tile_idx % (tile_H * tile_W)
+ th = rem // tile_W
+ tw = rem % tile_W
+
+ h_start = th * 4 - pad_h
+ w_start = tw * 4 - pad_w
+
+ offs_c = c_block * BLOCK_C + tl.arange(0, BLOCK_C)
+ c_mask = offs_c < C
+
+ # NCHWc addressing: cblock_idx = offs_c // Cb, c_local = offs_c % Cb
+ cblock_idx = offs_c // Cb
+ c_local = offs_c % Cb
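+    # e.g. with Cb=8 (illustrative), channel 19 sits in block 2, local offset 3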
+ base = X + n * stride_x_n + cblock_idx * stride_x_cblock + c_local
+ n_valid = n < N
+
+    # Load 6x6 patch (36 values per channel)
+ d00 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d01 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d02 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d03 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d04 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d05 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d10 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d11 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d12 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d13 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d14 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d15 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d20 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d21 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d22 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d23 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d24 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d25 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d30 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d31 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d32 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d33 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d34 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d35 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d40 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d41 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d42 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d43 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d44 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d45 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d50 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d51 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d52 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d53 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d54 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+ d55 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE)
+
+ for r in tl.static_range(6):
+ h = h_start + r
+ h_valid = n_valid & (h >= 0) & (h < H)
+ for s in tl.static_range(6):
+ w = w_start + s
+ valid = h_valid & (w >= 0) & (w < W_in)
+ ptr = base + h * stride_x_h + w * stride_x_w
+ val = tl.load(ptr, mask=valid & c_mask, other=0.0)
+ if r == 0:
+ if s == 0:
+ d00 = val
+ elif s == 1:
+ d01 = val
+ elif s == 2:
+ d02 = val
+ elif s == 3:
+ d03 = val
+ elif s == 4:
+ d04 = val
+ else:
+ d05 = val
+ elif r == 1:
+ if s == 0:
+ d10 = val
+ elif s == 1:
+ d11 = val
+ elif s == 2:
+ d12 = val
+ elif s == 3:
+ d13 = val
+ elif s == 4:
+ d14 = val
+ else:
+ d15 = val
+ elif r == 2:
+ if s == 0:
+ d20 = val
+ elif s == 1:
+ d21 = val
+ elif s == 2:
+ d22 = val
+ elif s == 3:
+ d23 = val
+ elif s == 4:
+ d24 = val
+ else:
+ d25 = val
+ elif r == 3:
+ if s == 0:
+ d30 = val
+ elif s == 1:
+ d31 = val
+ elif s == 2:
+ d32 = val
+ elif s == 3:
+ d33 = val
+ elif s == 4:
+ d34 = val
+ else:
+ d35 = val
+ elif r == 4:
+ if s == 0:
+ d40 = val
+ elif s == 1:
+ d41 = val
+ elif s == 2:
+ d42 = val
+ elif s == 3:
+ d43 = val
+ elif s == 4:
+ d44 = val
+ else:
+ d45 = val
+ else:
+ if s == 0:
+ d50 = val
+ elif s == 1:
+ d51 = val
+ elif s == 2:
+ d52 = val
+ elif s == 3:
+ d53 = val
+ elif s == 4:
+ d54 = val
+ else:
+ d55 = val
+
+ d00f = d00.to(tl.float32)
+ d01f = d01.to(tl.float32)
+ d02f = d02.to(tl.float32)
+ d03f = d03.to(tl.float32)
+ d04f = d04.to(tl.float32)
+ d05f = d05.to(tl.float32)
+ d10f = d10.to(tl.float32)
+ d11f = d11.to(tl.float32)
+ d12f = d12.to(tl.float32)
+ d13f = d13.to(tl.float32)
+ d14f = d14.to(tl.float32)
+ d15f = d15.to(tl.float32)
+ d20f = d20.to(tl.float32)
+ d21f = d21.to(tl.float32)
+ d22f = d22.to(tl.float32)
+ d23f = d23.to(tl.float32)
+ d24f = d24.to(tl.float32)
+ d25f = d25.to(tl.float32)
+ d30f = d30.to(tl.float32)
+ d31f = d31.to(tl.float32)
+ d32f = d32.to(tl.float32)
+ d33f = d33.to(tl.float32)
+ d34f = d34.to(tl.float32)
+ d35f = d35.to(tl.float32)
+ d40f = d40.to(tl.float32)
+ d41f = d41.to(tl.float32)
+ d42f = d42.to(tl.float32)
+ d43f = d43.to(tl.float32)
+ d44f = d44.to(tl.float32)
+ d45f = d45.to(tl.float32)
+ d50f = d50.to(tl.float32)
+ d51f = d51.to(tl.float32)
+ d52f = d52.to(tl.float32)
+ d53f = d53.to(tl.float32)
+ d54f = d54.to(tl.float32)
+ d55f = d55.to(tl.float32)
+
+ t00 = 4 * d00f - 5 * d20f + d40f
+ t01 = 4 * d01f - 5 * d21f + d41f
+ t02 = 4 * d02f - 5 * d22f + d42f
+ t03 = 4 * d03f - 5 * d23f + d43f
+ t04 = 4 * d04f - 5 * d24f + d44f
+ t05 = 4 * d05f - 5 * d25f + d45f
+
+ t10 = -4 * d10f - 4 * d20f + d30f + d40f
+ t11 = -4 * d11f - 4 * d21f + d31f + d41f
+ t12 = -4 * d12f - 4 * d22f + d32f + d42f
+ t13 = -4 * d13f - 4 * d23f + d33f + d43f
+ t14 = -4 * d14f - 4 * d24f + d34f + d44f
+ t15 = -4 * d15f - 4 * d25f + d35f + d45f
+
+ t20 = 4 * d10f - 4 * d20f - d30f + d40f
+ t21 = 4 * d11f - 4 * d21f - d31f + d41f
+ t22 = 4 * d12f - 4 * d22f - d32f + d42f
+ t23 = 4 * d13f - 4 * d23f - d33f + d43f
+ t24 = 4 * d14f - 4 * d24f - d34f + d44f
+ t25 = 4 * d15f - 4 * d25f - d35f + d45f
+
+ t30 = -2 * d10f - d20f + 2 * d30f + d40f
+ t31 = -2 * d11f - d21f + 2 * d31f + d41f
+ t32 = -2 * d12f - d22f + 2 * d32f + d42f
+ t33 = -2 * d13f - d23f + 2 * d33f + d43f
+ t34 = -2 * d14f - d24f + 2 * d34f + d44f
+ t35 = -2 * d15f - d25f + 2 * d35f + d45f
+
+ t40 = 2 * d10f - d20f - 2 * d30f + d40f
+ t41 = 2 * d11f - d21f - 2 * d31f + d41f
+ t42 = 2 * d12f - d22f - 2 * d32f + d42f
+ t43 = 2 * d13f - d23f - 2 * d33f + d43f
+ t44 = 2 * d14f - d24f - 2 * d34f + d44f
+ t45 = 2 * d15f - d25f - 2 * d35f + d45f
+
+ t50 = 4 * d10f - 5 * d30f + d50f
+ t51 = 4 * d11f - 5 * d31f + d51f
+ t52 = 4 * d12f - 5 * d32f + d52f
+ t53 = 4 * d13f - 5 * d33f + d53f
+ t54 = 4 * d14f - 5 * d34f + d54f
+ t55 = 4 * d15f - 5 * d35f + d55f
+
+ v00 = 4 * t00 - 5 * t02 + t04
+ v01 = -4 * t01 - 4 * t02 + t03 + t04
+ v02 = 4 * t01 - 4 * t02 - t03 + t04
+ v03 = -2 * t01 - t02 + 2 * t03 + t04
+ v04 = 2 * t01 - t02 - 2 * t03 + t04
+ v05 = 4 * t01 - 5 * t03 + t05
+
+ v10 = 4 * t10 - 5 * t12 + t14
+ v11 = -4 * t11 - 4 * t12 + t13 + t14
+ v12 = 4 * t11 - 4 * t12 - t13 + t14
+ v13 = -2 * t11 - t12 + 2 * t13 + t14
+ v14 = 2 * t11 - t12 - 2 * t13 + t14
+ v15 = 4 * t11 - 5 * t13 + t15
+
+ v20 = 4 * t20 - 5 * t22 + t24
+ v21 = -4 * t21 - 4 * t22 + t23 + t24
+ v22 = 4 * t21 - 4 * t22 - t23 + t24
+ v23 = -2 * t21 - t22 + 2 * t23 + t24
+ v24 = 2 * t21 - t22 - 2 * t23 + t24
+ v25 = 4 * t21 - 5 * t23 + t25
+
+ v30 = 4 * t30 - 5 * t32 + t34
+ v31 = -4 * t31 - 4 * t32 + t33 + t34
+ v32 = 4 * t31 - 4 * t32 - t33 + t34
+ v33 = -2 * t31 - t32 + 2 * t33 + t34
+ v34 = 2 * t31 - t32 - 2 * t33 + t34
+ v35 = 4 * t31 - 5 * t33 + t35
+
+ v40 = 4 * t40 - 5 * t42 + t44
+ v41 = -4 * t41 - 4 * t42 + t43 + t44
+ v42 = 4 * t41 - 4 * t42 - t43 + t44
+ v43 = -2 * t41 - t42 + 2 * t43 + t44
+ v44 = 2 * t41 - t42 - 2 * t43 + t44
+ v45 = 4 * t41 - 5 * t43 + t45
+
+ v50 = 4 * t50 - 5 * t52 + t54
+ v51 = -4 * t51 - 4 * t52 + t53 + t54
+ v52 = 4 * t51 - 4 * t52 - t53 + t54
+ v53 = -2 * t51 - t52 + 2 * t53 + t54
+ v54 = 2 * t51 - t52 - 2 * t53 + t54
+ v55 = 4 * t51 - 5 * t53 + t55
+
+ v_base = V + tile_idx * stride_v_tile + offs_c * stride_v_c
+ c_store_mask = offs_c < C_pad
+
+ tl.store(v_base + 0 * stride_v_alpha, v00, mask=c_store_mask)
+ tl.store(v_base + 1 * stride_v_alpha, v01, mask=c_store_mask)
+ tl.store(v_base + 2 * stride_v_alpha, v02, mask=c_store_mask)
+ tl.store(v_base + 3 * stride_v_alpha, v03, mask=c_store_mask)
+ tl.store(v_base + 4 * stride_v_alpha, v04, mask=c_store_mask)
+ tl.store(v_base + 5 * stride_v_alpha, v05, mask=c_store_mask)
+ tl.store(v_base + 6 * stride_v_alpha, v10, mask=c_store_mask)
+ tl.store(v_base + 7 * stride_v_alpha, v11, mask=c_store_mask)
+ tl.store(v_base + 8 * stride_v_alpha, v12, mask=c_store_mask)
+ tl.store(v_base + 9 * stride_v_alpha, v13, mask=c_store_mask)
+ tl.store(v_base + 10 * stride_v_alpha, v14, mask=c_store_mask)
+ tl.store(v_base + 11 * stride_v_alpha, v15, mask=c_store_mask)
+ tl.store(v_base + 12 * stride_v_alpha, v20, mask=c_store_mask)
+ tl.store(v_base + 13 * stride_v_alpha, v21, mask=c_store_mask)
+ tl.store(v_base + 14 * stride_v_alpha, v22, mask=c_store_mask)
+ tl.store(v_base + 15 * stride_v_alpha, v23, mask=c_store_mask)
+ tl.store(v_base + 16 * stride_v_alpha, v24, mask=c_store_mask)
+ tl.store(v_base + 17 * stride_v_alpha, v25, mask=c_store_mask)
+ tl.store(v_base + 18 * stride_v_alpha, v30, mask=c_store_mask)
+ tl.store(v_base + 19 * stride_v_alpha, v31, mask=c_store_mask)
+ tl.store(v_base + 20 * stride_v_alpha, v32, mask=c_store_mask)
+ tl.store(v_base + 21 * stride_v_alpha, v33, mask=c_store_mask)
+ tl.store(v_base + 22 * stride_v_alpha, v34, mask=c_store_mask)
+ tl.store(v_base + 23 * stride_v_alpha, v35, mask=c_store_mask)
+ tl.store(v_base + 24 * stride_v_alpha, v40, mask=c_store_mask)
+ tl.store(v_base + 25 * stride_v_alpha, v41, mask=c_store_mask)
+ tl.store(v_base + 26 * stride_v_alpha, v42, mask=c_store_mask)
+ tl.store(v_base + 27 * stride_v_alpha, v43, mask=c_store_mask)
+ tl.store(v_base + 28 * stride_v_alpha, v44, mask=c_store_mask)
+ tl.store(v_base + 29 * stride_v_alpha, v45, mask=c_store_mask)
+ tl.store(v_base + 30 * stride_v_alpha, v50, mask=c_store_mask)
+ tl.store(v_base + 31 * stride_v_alpha, v51, mask=c_store_mask)
+ tl.store(v_base + 32 * stride_v_alpha, v52, mask=c_store_mask)
+ tl.store(v_base + 33 * stride_v_alpha, v53, mask=c_store_mask)
+ tl.store(v_base + 34 * stride_v_alpha, v54, mask=c_store_mask)
+ tl.store(v_base + 35 * stride_v_alpha, v55, mask=c_store_mask)
+
+
+@triton.autotune(
+ configs=AUTOTUNE_WINO_GEMM_CONFIGS,
+ key=["T", "K_out", "C_pad"],
+ cache_results=True,
+)
+@triton.jit
+def _winograd_f4x3_batched_gemm_kernel(
+ V,
+ U,
+ M_out,
+ T: tl.constexpr,
+ K_out: tl.constexpr,
+ C_pad: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_SIZE_M: tl.constexpr,
+):
+ """Batched GEMM: M[alpha] = V[alpha] @ U[alpha]^T, alpha in [0..36)"""
+ # V layout: [36, T, C_pad] contiguous
+ stride_v_c: tl.constexpr = 1
+ stride_v_tile: tl.constexpr = C_pad
+ stride_v_alpha: tl.constexpr = T * C_pad
+ # U layout: [36, K_out, C_pad] contiguous
+ stride_u_c: tl.constexpr = 1
+ stride_u_k: tl.constexpr = C_pad
+ stride_u_alpha: tl.constexpr = K_out * C_pad
+ # M layout: [36, T, K_out] contiguous
+ stride_m_k: tl.constexpr = 1
+ stride_m_tile: tl.constexpr = K_out
+ stride_m_alpha: tl.constexpr = T * K_out
+
+ pid = tl.program_id(0)
+ alpha = tl.program_id(1)
+
+ num_pid_m = tl.cdiv(T, BLOCK_M)
+ num_pid_n = tl.cdiv(K_out, BLOCK_N)
+
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * GROUP_SIZE_M
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+ pid_n = (pid % num_pid_in_group) // group_size_m
+
+ if pid_m >= num_pid_m or pid_n >= num_pid_n:
+ return
+
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ offs_k = tl.arange(0, BLOCK_K)
+
+ v_base = V + alpha * stride_v_alpha
+ u_base = U + alpha * stride_u_alpha
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k0 in range(0, C_pad, BLOCK_K):
+ k_offs = k0 + offs_k
+
+ v_ptrs = v_base + offs_m[:, None] * stride_v_tile + k_offs[None, :] * stride_v_c
+ v_mask = (offs_m[:, None] < T) & (k_offs[None, :] < C_pad)
+ v_tile = tl.load(v_ptrs, mask=v_mask, other=0.0)
+
+ u_ptrs = u_base + offs_n[:, None] * stride_u_k + k_offs[None, :] * stride_u_c
+ u_mask = (offs_n[:, None] < K_out) & (k_offs[None, :] < C_pad)
+ u_tile = tl.load(u_ptrs, mask=u_mask, other=0.0)
+
+ acc += tl.dot(v_tile, tl.trans(u_tile), out_dtype=tl.float32)
+
+ m_ptrs = (
+ M_out
+ + alpha * stride_m_alpha
+ + offs_m[:, None] * stride_m_tile
+ + offs_n[None, :] * stride_m_k
+ )
+ m_mask = (offs_m[:, None] < T) & (offs_n[None, :] < K_out)
+ tl.store(m_ptrs, acc, mask=m_mask)
+
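+# Illustrative launch (host-side names are assumptions): one program per
+# (GEMM tile, alpha), i.e. 36 independent [T, C_pad] x [C_pad, K_out] GEMMs:
+#
+#     grid = lambda META: (
+#         triton.cdiv(T, META["BLOCK_M"]) * triton.cdiv(K_out, META["BLOCK_N"]),
+#         36,
+#     )
+#     _winograd_f4x3_batched_gemm_kernel[grid](V, U, M, T=T, K_out=K_out,
+#                                              C_pad=C_pad)
+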
+
+@triton.autotune(
+ configs=AUTOTUNE_WINO4_OUTPUT_CONFIGS,
+ key=["T", "K_out"],
+ cache_results=True,
+)
+@triton.jit
+def _winograd_f4x3_output_transform_kernel(
+ M_in,
+ BIAS,
+ Y,
+ N: tl.constexpr,
+ K_out: tl.constexpr,
+ P: tl.constexpr,
+ Q: tl.constexpr,
+ tile_H: tl.constexpr,
+ tile_W: tl.constexpr,
+ T: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ HAS_BIAS: tl.constexpr,
+ ACT_TYPE: tl.constexpr,
+ LAYOUT: tl.constexpr = 0,
+):
+ # M layout: [36, T, K_out] contiguous
+ stride_m_k: tl.constexpr = 1
+ stride_m_tile: tl.constexpr = K_out
+ stride_m_alpha: tl.constexpr = T * K_out
+ # Y layout: LAYOUT=0 NCHW, LAYOUT=1 NHWC
+ if LAYOUT == 0:
+ stride_y_q: tl.constexpr = 1
+ stride_y_p: tl.constexpr = Q
+ stride_y_k: tl.constexpr = P * Q
+ stride_y_n: tl.constexpr = K_out * P * Q
+ else:
+ stride_y_q: tl.constexpr = K_out
+ stride_y_p: tl.constexpr = Q * K_out
+ stride_y_k: tl.constexpr = 1
+ stride_y_n: tl.constexpr = P * Q * K_out
+
+ tile_idx = tl.program_id(0)
+ k_block = tl.program_id(1)
+
+ n = tile_idx // (tile_H * tile_W)
+ rem = tile_idx % (tile_H * tile_W)
+ th = rem // tile_W
+ tw = rem % tile_W
+
+ p_start = th * 4
+ q_start = tw * 4
+
+ offs_k = k_block * BLOCK_K + tl.arange(0, BLOCK_K)
+ k_mask = offs_k < K_out
+
+ # Load 36 values from M[alpha, tile_idx, k]
+ m_base = M_in + tile_idx * stride_m_tile + offs_k * stride_m_k
+
+ m00 = tl.load(m_base + 0 * stride_m_alpha, mask=k_mask, other=0.0)
+ m01 = tl.load(m_base + 1 * stride_m_alpha, mask=k_mask, other=0.0)
+ m02 = tl.load(m_base + 2 * stride_m_alpha, mask=k_mask, other=0.0)
+ m03 = tl.load(m_base + 3 * stride_m_alpha, mask=k_mask, other=0.0)
+ m04 = tl.load(m_base + 4 * stride_m_alpha, mask=k_mask, other=0.0)
+ m05 = tl.load(m_base + 5 * stride_m_alpha, mask=k_mask, other=0.0)
+ m10 = tl.load(m_base + 6 * stride_m_alpha, mask=k_mask, other=0.0)
+ m11 = tl.load(m_base + 7 * stride_m_alpha, mask=k_mask, other=0.0)
+ m12 = tl.load(m_base + 8 * stride_m_alpha, mask=k_mask, other=0.0)
+ m13 = tl.load(m_base + 9 * stride_m_alpha, mask=k_mask, other=0.0)
+ m14 = tl.load(m_base + 10 * stride_m_alpha, mask=k_mask, other=0.0)
+ m15 = tl.load(m_base + 11 * stride_m_alpha, mask=k_mask, other=0.0)
+ m20 = tl.load(m_base + 12 * stride_m_alpha, mask=k_mask, other=0.0)
+ m21 = tl.load(m_base + 13 * stride_m_alpha, mask=k_mask, other=0.0)
+ m22 = tl.load(m_base + 14 * stride_m_alpha, mask=k_mask, other=0.0)
+ m23 = tl.load(m_base + 15 * stride_m_alpha, mask=k_mask, other=0.0)
+ m24 = tl.load(m_base + 16 * stride_m_alpha, mask=k_mask, other=0.0)
+ m25 = tl.load(m_base + 17 * stride_m_alpha, mask=k_mask, other=0.0)
+ m30 = tl.load(m_base + 18 * stride_m_alpha, mask=k_mask, other=0.0)
+ m31 = tl.load(m_base + 19 * stride_m_alpha, mask=k_mask, other=0.0)
+ m32 = tl.load(m_base + 20 * stride_m_alpha, mask=k_mask, other=0.0)
+ m33 = tl.load(m_base + 21 * stride_m_alpha, mask=k_mask, other=0.0)
+ m34 = tl.load(m_base + 22 * stride_m_alpha, mask=k_mask, other=0.0)
+ m35 = tl.load(m_base + 23 * stride_m_alpha, mask=k_mask, other=0.0)
+ m40 = tl.load(m_base + 24 * stride_m_alpha, mask=k_mask, other=0.0)
+ m41 = tl.load(m_base + 25 * stride_m_alpha, mask=k_mask, other=0.0)
+ m42 = tl.load(m_base + 26 * stride_m_alpha, mask=k_mask, other=0.0)
+ m43 = tl.load(m_base + 27 * stride_m_alpha, mask=k_mask, other=0.0)
+ m44 = tl.load(m_base + 28 * stride_m_alpha, mask=k_mask, other=0.0)
+ m45 = tl.load(m_base + 29 * stride_m_alpha, mask=k_mask, other=0.0)
+ m50 = tl.load(m_base + 30 * stride_m_alpha, mask=k_mask, other=0.0)
+ m51 = tl.load(m_base + 31 * stride_m_alpha, mask=k_mask, other=0.0)
+ m52 = tl.load(m_base + 32 * stride_m_alpha, mask=k_mask, other=0.0)
+ m53 = tl.load(m_base + 33 * stride_m_alpha, mask=k_mask, other=0.0)
+ m54 = tl.load(m_base + 34 * stride_m_alpha, mask=k_mask, other=0.0)
+ m55 = tl.load(m_base + 35 * stride_m_alpha, mask=k_mask, other=0.0)
+
+ # A^T column transform (4x6):
+ # A^T = [[ 1, 1, 1, 1, 1, 0],
+ # [ 0, 1, -1, 2, -2, 0],
+ # [ 0, 1, 1, 4, 4, 0],
+ # [ 0, 1, -1, 8, -8, 1]]
+ # Apply to each column s: s_col = A^T @ m_col
+
+ # Column 0
+ s00 = m00 + m10 + m20 + m30 + m40
+ s10 = m10 - m20 + 2 * m30 - 2 * m40
+ s20 = m10 + m20 + 4 * m30 + 4 * m40
+ s30 = m10 - m20 + 8 * m30 - 8 * m40 + m50
+ # Column 1
+ s01 = m01 + m11 + m21 + m31 + m41
+ s11 = m11 - m21 + 2 * m31 - 2 * m41
+ s21 = m11 + m21 + 4 * m31 + 4 * m41
+ s31 = m11 - m21 + 8 * m31 - 8 * m41 + m51
+ # Column 2
+ s02 = m02 + m12 + m22 + m32 + m42
+ s12 = m12 - m22 + 2 * m32 - 2 * m42
+ s22 = m12 + m22 + 4 * m32 + 4 * m42
+ s32 = m12 - m22 + 8 * m32 - 8 * m42 + m52
+ # Column 3
+ s03 = m03 + m13 + m23 + m33 + m43
+ s13 = m13 - m23 + 2 * m33 - 2 * m43
+ s23 = m13 + m23 + 4 * m33 + 4 * m43
+ s33 = m13 - m23 + 8 * m33 - 8 * m43 + m53
+ # Column 4
+ s04 = m04 + m14 + m24 + m34 + m44
+ s14 = m14 - m24 + 2 * m34 - 2 * m44
+ s24 = m14 + m24 + 4 * m34 + 4 * m44
+ s34 = m14 - m24 + 8 * m34 - 8 * m44 + m54
+ # Column 5
+ s05 = m05 + m15 + m25 + m35 + m45
+ s15 = m15 - m25 + 2 * m35 - 2 * m45
+ s25 = m15 + m25 + 4 * m35 + 4 * m45
+ s35 = m15 - m25 + 8 * m35 - 8 * m45 + m55
+
+ # A^T row transform
+ y00 = s00 + s01 + s02 + s03 + s04
+ y01 = s01 - s02 + 2 * s03 - 2 * s04
+ y02 = s01 + s02 + 4 * s03 + 4 * s04
+ y03 = s01 - s02 + 8 * s03 - 8 * s04 + s05
+
+ y10 = s10 + s11 + s12 + s13 + s14
+ y11 = s11 - s12 + 2 * s13 - 2 * s14
+ y12 = s11 + s12 + 4 * s13 + 4 * s14
+ y13 = s11 - s12 + 8 * s13 - 8 * s14 + s15
+
+ y20 = s20 + s21 + s22 + s23 + s24
+ y21 = s21 - s22 + 2 * s23 - 2 * s24
+ y22 = s21 + s22 + 4 * s23 + 4 * s24
+ y23 = s21 - s22 + 8 * s23 - 8 * s24 + s25
+
+ y30 = s30 + s31 + s32 + s33 + s34
+ y31 = s31 - s32 + 2 * s33 - 2 * s34
+ y32 = s31 + s32 + 4 * s33 + 4 * s34
+ y33 = s31 - s32 + 8 * s33 - 8 * s34 + s35
+
+ # Bias
+ if HAS_BIAS:
+ bias = tl.load(BIAS + offs_k, mask=k_mask, other=0.0)
+ y00 += bias
+ y01 += bias
+ y02 += bias
+ y03 += bias
+ y10 += bias
+ y11 += bias
+ y12 += bias
+ y13 += bias
+ y20 += bias
+ y21 += bias
+ y22 += bias
+ y23 += bias
+ y30 += bias
+ y31 += bias
+ y32 += bias
+ y33 += bias
+
+ # Activation
+ if ACT_TYPE == 1:
+ y00 = tl.maximum(y00, 0)
+ y01 = tl.maximum(y01, 0)
+ y02 = tl.maximum(y02, 0)
+ y03 = tl.maximum(y03, 0)
+ y10 = tl.maximum(y10, 0)
+ y11 = tl.maximum(y11, 0)
+ y12 = tl.maximum(y12, 0)
+ y13 = tl.maximum(y13, 0)
+ y20 = tl.maximum(y20, 0)
+ y21 = tl.maximum(y21, 0)
+ y22 = tl.maximum(y22, 0)
+ y23 = tl.maximum(y23, 0)
+ y30 = tl.maximum(y30, 0)
+ y31 = tl.maximum(y31, 0)
+ y32 = tl.maximum(y32, 0)
+ y33 = tl.maximum(y33, 0)
+ elif ACT_TYPE == 2:
+ y00 = tl.minimum(tl.maximum(y00, 0), 6)
+ y01 = tl.minimum(tl.maximum(y01, 0), 6)
+ y02 = tl.minimum(tl.maximum(y02, 0), 6)
+ y03 = tl.minimum(tl.maximum(y03, 0), 6)
+ y10 = tl.minimum(tl.maximum(y10, 0), 6)
+ y11 = tl.minimum(tl.maximum(y11, 0), 6)
+ y12 = tl.minimum(tl.maximum(y12, 0), 6)
+ y13 = tl.minimum(tl.maximum(y13, 0), 6)
+ y20 = tl.minimum(tl.maximum(y20, 0), 6)
+ y21 = tl.minimum(tl.maximum(y21, 0), 6)
+ y22 = tl.minimum(tl.maximum(y22, 0), 6)
+ y23 = tl.minimum(tl.maximum(y23, 0), 6)
+ y30 = tl.minimum(tl.maximum(y30, 0), 6)
+ y31 = tl.minimum(tl.maximum(y31, 0), 6)
+ y32 = tl.minimum(tl.maximum(y32, 0), 6)
+ y33 = tl.minimum(tl.maximum(y33, 0), 6)
+ elif ACT_TYPE == 3:
+ y00 = (
+ 0.5 * y00 * (1.0 + _tanh(0.7978845608 * (y00 + 0.044715 * y00 * y00 * y00)))
+ )
+ y01 = (
+ 0.5 * y01 * (1.0 + _tanh(0.7978845608 * (y01 + 0.044715 * y01 * y01 * y01)))
+ )
+ y02 = (
+ 0.5 * y02 * (1.0 + _tanh(0.7978845608 * (y02 + 0.044715 * y02 * y02 * y02)))
+ )
+ y03 = (
+ 0.5 * y03 * (1.0 + _tanh(0.7978845608 * (y03 + 0.044715 * y03 * y03 * y03)))
+ )
+ y10 = (
+ 0.5 * y10 * (1.0 + _tanh(0.7978845608 * (y10 + 0.044715 * y10 * y10 * y10)))
+ )
+ y11 = (
+ 0.5 * y11 * (1.0 + _tanh(0.7978845608 * (y11 + 0.044715 * y11 * y11 * y11)))
+ )
+ y12 = (
+ 0.5 * y12 * (1.0 + _tanh(0.7978845608 * (y12 + 0.044715 * y12 * y12 * y12)))
+ )
+ y13 = (
+ 0.5 * y13 * (1.0 + _tanh(0.7978845608 * (y13 + 0.044715 * y13 * y13 * y13)))
+ )
+ y20 = (
+ 0.5 * y20 * (1.0 + _tanh(0.7978845608 * (y20 + 0.044715 * y20 * y20 * y20)))
+ )
+ y21 = (
+ 0.5 * y21 * (1.0 + _tanh(0.7978845608 * (y21 + 0.044715 * y21 * y21 * y21)))
+ )
+ y22 = (
+ 0.5 * y22 * (1.0 + _tanh(0.7978845608 * (y22 + 0.044715 * y22 * y22 * y22)))
+ )
+ y23 = (
+ 0.5 * y23 * (1.0 + _tanh(0.7978845608 * (y23 + 0.044715 * y23 * y23 * y23)))
+ )
+ y30 = (
+ 0.5 * y30 * (1.0 + _tanh(0.7978845608 * (y30 + 0.044715 * y30 * y30 * y30)))
+ )
+ y31 = (
+ 0.5 * y31 * (1.0 + _tanh(0.7978845608 * (y31 + 0.044715 * y31 * y31 * y31)))
+ )
+ y32 = (
+ 0.5 * y32 * (1.0 + _tanh(0.7978845608 * (y32 + 0.044715 * y32 * y32 * y32)))
+ )
+ y33 = (
+ 0.5 * y33 * (1.0 + _tanh(0.7978845608 * (y33 + 0.044715 * y33 * y33 * y33)))
+ )
+
+ # Store 4x4 output tile
+ n_valid = n < N
+ y_base = Y + n * stride_y_n + offs_k * stride_y_k
+
+ if n_valid:
+ for r in tl.static_range(4):
+ p = p_start + r
+ if p < P:
+ for s in tl.static_range(4):
+ q = q_start + s
+ if q < Q:
+ if r == 0:
+ if s == 0:
+ val = y00
+ elif s == 1:
+ val = y01
+ elif s == 2:
+ val = y02
+ else:
+ val = y03
+ elif r == 1:
+ if s == 0:
+ val = y10
+ elif s == 1:
+ val = y11
+ elif s == 2:
+ val = y12
+ else:
+ val = y13
+ elif r == 2:
+ if s == 0:
+ val = y20
+ elif s == 1:
+ val = y21
+ elif s == 2:
+ val = y22
+ else:
+ val = y23
+ else:
+ if s == 0:
+ val = y30
+ elif s == 1:
+ val = y31
+ elif s == 2:
+ val = y32
+ else:
+ val = y33
+ tl.store(
+ y_base + p * stride_y_p + q * stride_y_q, val, mask=k_mask
+ )
+
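+# Reference for the output transform (a NumPy sketch, illustrative only): the
+# element-wise y00..y33 equal A^T @ m @ A per output channel:
+#
+#     AT = np.array([[1, 1, 1, 1, 1, 0],
+#                    [0, 1, -1, 2, -2, 0],
+#                    [0, 1, 1, 4, 4, 0],
+#                    [0, 1, -1, 8, -8, 1]], dtype=np.float32)
+#     y = AT @ m @ AT.T  # m is one 6x6 GEMM result tile
+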
+
+@triton.autotune(
+ configs=AUTOTUNE_FUSED_F4X3_CONFIGS,
+ key=["T", "K_out", "C_pad"],
+ cache_results=True,
+)
+@triton.jit
+def _winograd_f4x3_fused_gemm_output_kernel(
+ V,
+ U,
+ BIAS,
+ Y,
+ N: tl.constexpr,
+ K_out: tl.constexpr,
+ P: tl.constexpr,
+ Q: tl.constexpr,
+ C_pad: tl.constexpr,
+ tile_H: tl.constexpr,
+ tile_W: tl.constexpr,
+ T: tl.constexpr,
+ BLOCK_T: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ BLOCK_C: tl.constexpr,
+ HAS_BIAS: tl.constexpr,
+ ACT_TYPE: tl.constexpr,
+ LAYOUT: tl.constexpr = 0,
+):
+ """Fused GEMM + output transform for Winograd F(4x4,3x3).
+ Processes 6 alphas per column (6 columns = 36 total) sequentially,
+ accumulating column-transform results to reduce register pressure."""
+ # V layout: [36, T, C_pad] contiguous
+ stride_v_c: tl.constexpr = 1
+ stride_v_tile: tl.constexpr = C_pad
+ stride_v_alpha: tl.constexpr = T * C_pad
+ # U layout: [36, K_out, C_pad] contiguous
+ stride_u_c: tl.constexpr = 1
+ stride_u_k: tl.constexpr = C_pad
+ stride_u_alpha: tl.constexpr = K_out * C_pad
+ # Y layout: LAYOUT=0 NCHW, LAYOUT=1 NHWC
+ if LAYOUT == 0:
+ stride_y_q: tl.constexpr = 1
+ stride_y_p: tl.constexpr = Q
+ stride_y_k: tl.constexpr = P * Q
+ stride_y_n: tl.constexpr = K_out * P * Q
+ else:
+ stride_y_q: tl.constexpr = K_out
+ stride_y_p: tl.constexpr = Q * K_out
+ stride_y_k: tl.constexpr = 1
+ stride_y_n: tl.constexpr = P * Q * K_out
+
+ pid_t = tl.program_id(0)
+ pid_k = tl.program_id(1)
+
+ offs_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T)
+ offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
+ t_mask = offs_t < T
+ k_mask = offs_k < K_out
+
+ offs_c = tl.arange(0, BLOCK_C)
+
+ # Process column by column to reduce register pressure.
+ # For each column s (0..5), accumulate 6 alpha GEMMs (rows 0..5),
+ # then apply A^T column transform (6→4), storing s[row][col] results.
+ # A^T = [[1,1,1,1,1,0],[0,1,-1,2,-2,0],[0,1,1,4,4,0],[0,1,-1,8,-8,1]]
+ # We need s[0..3][0..5] = 24 accumulators after column transform.
+ # Then row transform: s[row][0..5] → y[row][0..3] = 16 output values.
+
+ # Accumulate all 24 s-values (4 rows × 6 columns)
+ s00 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32)
+ s01 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32)
+ s02 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32)
+ s03 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32)
+ s04 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32)
+ s05 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32)
+ s10 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32)
+ s11 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32)
+ s12 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32)
+ s13 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32)
+ s14 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32)
+ s15 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32)
+ s20 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32)
+ s21 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32)
+ s22 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32)
+ s23 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32)
+ s24 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32)
+ s25 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32)
+ s30 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32)
+ s31 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32)
+ s32 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32)
+ s33 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32)
+ s34 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32)
+ s35 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32)
+
+ v_mask_2d = t_mask[:, None] & (offs_c < C_pad)[None, :]
+ u_mask_2d = k_mask[:, None] & (offs_c < C_pad)[None, :]
+
+ for c0 in range(0, C_pad, BLOCK_C):
+ c_offs = c0 + offs_c
+ if c0 > 0:
+ v_mask_2d = t_mask[:, None] & (c_offs < C_pad)[None, :]
+ u_mask_2d = k_mask[:, None] & (c_offs < C_pad)[None, :]
+
+ # Process all 36 alphas, applying column transform on-the-fly
+ for col in tl.static_range(6):
+ # Load 6 V tiles (rows 0-5 for this column) and 6 U tiles
+ # alpha = row * 6 + col
+ v_base_c = (
+ V + offs_t[:, None] * stride_v_tile + c_offs[None, :] * stride_v_c
+ )
+ u_base_c = U + offs_k[:, None] * stride_u_k + c_offs[None, :] * stride_u_c
+
+ # Load V and U for 6 rows, compute GEMM, apply column transform
+ v0 = tl.load(
+ v_base_c + (0 * 6 + col) * stride_v_alpha, mask=v_mask_2d, other=0.0
+ )
+ u0 = tl.load(
+ u_base_c + (0 * 6 + col) * stride_u_alpha, mask=u_mask_2d, other=0.0
+ )
+ m0 = tl.dot(v0, tl.trans(u0), out_dtype=tl.float32)
+
+ v1 = tl.load(
+ v_base_c + (1 * 6 + col) * stride_v_alpha, mask=v_mask_2d, other=0.0
+ )
+ u1 = tl.load(
+ u_base_c + (1 * 6 + col) * stride_u_alpha, mask=u_mask_2d, other=0.0
+ )
+ m1 = tl.dot(v1, tl.trans(u1), out_dtype=tl.float32)
+
+ v2 = tl.load(
+ v_base_c + (2 * 6 + col) * stride_v_alpha, mask=v_mask_2d, other=0.0
+ )
+ u2 = tl.load(
+ u_base_c + (2 * 6 + col) * stride_u_alpha, mask=u_mask_2d, other=0.0
+ )
+ m2 = tl.dot(v2, tl.trans(u2), out_dtype=tl.float32)
+
+ v3 = tl.load(
+ v_base_c + (3 * 6 + col) * stride_v_alpha, mask=v_mask_2d, other=0.0
+ )
+ u3 = tl.load(
+ u_base_c + (3 * 6 + col) * stride_u_alpha, mask=u_mask_2d, other=0.0
+ )
+ m3 = tl.dot(v3, tl.trans(u3), out_dtype=tl.float32)
+
+ v4 = tl.load(
+ v_base_c + (4 * 6 + col) * stride_v_alpha, mask=v_mask_2d, other=0.0
+ )
+ u4 = tl.load(
+ u_base_c + (4 * 6 + col) * stride_u_alpha, mask=u_mask_2d, other=0.0
+ )
+ m4 = tl.dot(v4, tl.trans(u4), out_dtype=tl.float32)
+
+ v5 = tl.load(
+ v_base_c + (5 * 6 + col) * stride_v_alpha, mask=v_mask_2d, other=0.0
+ )
+ u5 = tl.load(
+ u_base_c + (5 * 6 + col) * stride_u_alpha, mask=u_mask_2d, other=0.0
+ )
+ m5 = tl.dot(v5, tl.trans(u5), out_dtype=tl.float32)
+
+ # Column transform: A^T @ [m0..m5] → [s0, s1, s2, s3]
+ sc0 = m0 + m1 + m2 + m3 + m4
+ sc1 = m1 - m2 + 2 * m3 - 2 * m4
+ sc2 = m1 + m2 + 4 * m3 + 4 * m4
+ sc3 = m1 - m2 + 8 * m3 - 8 * m4 + m5
+
+ # Accumulate into the appropriate column
+ if col == 0:
+ s00 += sc0
+ s10 += sc1
+ s20 += sc2
+ s30 += sc3
+ elif col == 1:
+ s01 += sc0
+ s11 += sc1
+ s21 += sc2
+ s31 += sc3
+ elif col == 2:
+ s02 += sc0
+ s12 += sc1
+ s22 += sc2
+ s32 += sc3
+ elif col == 3:
+ s03 += sc0
+ s13 += sc1
+ s23 += sc2
+ s33 += sc3
+ elif col == 4:
+ s04 += sc0
+ s14 += sc1
+ s24 += sc2
+ s34 += sc3
+ elif col == 5:
+ s05 += sc0
+ s15 += sc1
+ s25 += sc2
+ s35 += sc3
+
+ # Row transform: A^T applied to rows
+ y00 = s00 + s01 + s02 + s03 + s04
+ y01 = s01 - s02 + 2 * s03 - 2 * s04
+ y02 = s01 + s02 + 4 * s03 + 4 * s04
+ y03 = s01 - s02 + 8 * s03 - 8 * s04 + s05
+
+ y10 = s10 + s11 + s12 + s13 + s14
+ y11 = s11 - s12 + 2 * s13 - 2 * s14
+ y12 = s11 + s12 + 4 * s13 + 4 * s14
+ y13 = s11 - s12 + 8 * s13 - 8 * s14 + s15
+
+ y20 = s20 + s21 + s22 + s23 + s24
+ y21 = s21 - s22 + 2 * s23 - 2 * s24
+ y22 = s21 + s22 + 4 * s23 + 4 * s24
+ y23 = s21 - s22 + 8 * s23 - 8 * s24 + s25
+
+ y30 = s30 + s31 + s32 + s33 + s34
+ y31 = s31 - s32 + 2 * s33 - 2 * s34
+ y32 = s31 + s32 + 4 * s33 + 4 * s34
+ y33 = s31 - s32 + 8 * s33 - 8 * s34 + s35
+
+ # Bias
+ if HAS_BIAS:
+ bias = tl.load(BIAS + offs_k, mask=k_mask, other=0.0)
+ b = bias[None, :]
+ y00 += b
+ y01 += b
+ y02 += b
+ y03 += b
+ y10 += b
+ y11 += b
+ y12 += b
+ y13 += b
+ y20 += b
+ y21 += b
+ y22 += b
+ y23 += b
+ y30 += b
+ y31 += b
+ y32 += b
+ y33 += b
+
+ # Activation
+ if ACT_TYPE == 1:
+ y00 = tl.maximum(y00, 0)
+ y01 = tl.maximum(y01, 0)
+ y02 = tl.maximum(y02, 0)
+ y03 = tl.maximum(y03, 0)
+ y10 = tl.maximum(y10, 0)
+ y11 = tl.maximum(y11, 0)
+ y12 = tl.maximum(y12, 0)
+ y13 = tl.maximum(y13, 0)
+ y20 = tl.maximum(y20, 0)
+ y21 = tl.maximum(y21, 0)
+ y22 = tl.maximum(y22, 0)
+ y23 = tl.maximum(y23, 0)
+ y30 = tl.maximum(y30, 0)
+ y31 = tl.maximum(y31, 0)
+ y32 = tl.maximum(y32, 0)
+ y33 = tl.maximum(y33, 0)
+ elif ACT_TYPE == 2:
+ y00 = tl.minimum(tl.maximum(y00, 0), 6)
+ y01 = tl.minimum(tl.maximum(y01, 0), 6)
+ y02 = tl.minimum(tl.maximum(y02, 0), 6)
+ y03 = tl.minimum(tl.maximum(y03, 0), 6)
+ y10 = tl.minimum(tl.maximum(y10, 0), 6)
+ y11 = tl.minimum(tl.maximum(y11, 0), 6)
+ y12 = tl.minimum(tl.maximum(y12, 0), 6)
+ y13 = tl.minimum(tl.maximum(y13, 0), 6)
+ y20 = tl.minimum(tl.maximum(y20, 0), 6)
+ y21 = tl.minimum(tl.maximum(y21, 0), 6)
+ y22 = tl.minimum(tl.maximum(y22, 0), 6)
+ y23 = tl.minimum(tl.maximum(y23, 0), 6)
+ y30 = tl.minimum(tl.maximum(y30, 0), 6)
+ y31 = tl.minimum(tl.maximum(y31, 0), 6)
+ y32 = tl.minimum(tl.maximum(y32, 0), 6)
+ y33 = tl.minimum(tl.maximum(y33, 0), 6)
+ elif ACT_TYPE == 3:
+ y00 = (
+ 0.5 * y00 * (1.0 + _tanh(0.7978845608 * (y00 + 0.044715 * y00 * y00 * y00)))
+ )
+ y01 = (
+ 0.5 * y01 * (1.0 + _tanh(0.7978845608 * (y01 + 0.044715 * y01 * y01 * y01)))
+ )
+ y02 = (
+ 0.5 * y02 * (1.0 + _tanh(0.7978845608 * (y02 + 0.044715 * y02 * y02 * y02)))
+ )
+ y03 = (
+ 0.5 * y03 * (1.0 + _tanh(0.7978845608 * (y03 + 0.044715 * y03 * y03 * y03)))
+ )
+ y10 = (
+ 0.5 * y10 * (1.0 + _tanh(0.7978845608 * (y10 + 0.044715 * y10 * y10 * y10)))
+ )
+ y11 = (
+ 0.5 * y11 * (1.0 + _tanh(0.7978845608 * (y11 + 0.044715 * y11 * y11 * y11)))
+ )
+ y12 = (
+ 0.5 * y12 * (1.0 + _tanh(0.7978845608 * (y12 + 0.044715 * y12 * y12 * y12)))
+ )
+ y13 = (
+ 0.5 * y13 * (1.0 + _tanh(0.7978845608 * (y13 + 0.044715 * y13 * y13 * y13)))
+ )
+ y20 = (
+ 0.5 * y20 * (1.0 + _tanh(0.7978845608 * (y20 + 0.044715 * y20 * y20 * y20)))
+ )
+ y21 = (
+ 0.5 * y21 * (1.0 + _tanh(0.7978845608 * (y21 + 0.044715 * y21 * y21 * y21)))
+ )
+ y22 = (
+ 0.5 * y22 * (1.0 + _tanh(0.7978845608 * (y22 + 0.044715 * y22 * y22 * y22)))
+ )
+ y23 = (
+ 0.5 * y23 * (1.0 + _tanh(0.7978845608 * (y23 + 0.044715 * y23 * y23 * y23)))
+ )
+ y30 = (
+ 0.5 * y30 * (1.0 + _tanh(0.7978845608 * (y30 + 0.044715 * y30 * y30 * y30)))
+ )
+ y31 = (
+ 0.5 * y31 * (1.0 + _tanh(0.7978845608 * (y31 + 0.044715 * y31 * y31 * y31)))
+ )
+ y32 = (
+ 0.5 * y32 * (1.0 + _tanh(0.7978845608 * (y32 + 0.044715 * y32 * y32 * y32)))
+ )
+ y33 = (
+ 0.5 * y33 * (1.0 + _tanh(0.7978845608 * (y33 + 0.044715 * y33 * y33 * y33)))
+ )
+
+ # Decode tile indices and compute per-tile base pointers
+ n_idx = offs_t // (tile_H * tile_W)
+ rem = offs_t % (tile_H * tile_W)
+ th = rem // tile_W
+ tw = rem % tile_W
+ p_start = th * 4
+ q_start = tw * 4
+
+ y_base = Y + n_idx * stride_y_n
+ y_ptrs_base = y_base[:, None] + offs_k[None, :] * stride_y_k
+ full_mask = t_mask[:, None] & k_mask[None, :]
+
+ # Store 4x4 output tile using 2D scatter stores
+ for di in tl.static_range(4):
+ for dj in tl.static_range(4):
+ mask_ij = (
+ full_mask & (p_start + di < P)[:, None] & (q_start + dj < Q)[:, None]
+ )
+ ptrs_ij = (
+ y_ptrs_base
+ + (p_start[:, None] + di) * stride_y_p
+ + (q_start[:, None] + dj) * stride_y_q
+ )
+ if di == 0:
+ if dj == 0:
+ tl.store(ptrs_ij, y00, mask=mask_ij)
+ elif dj == 1:
+ tl.store(ptrs_ij, y01, mask=mask_ij)
+ elif dj == 2:
+ tl.store(ptrs_ij, y02, mask=mask_ij)
+ else:
+ tl.store(ptrs_ij, y03, mask=mask_ij)
+ elif di == 1:
+ if dj == 0:
+ tl.store(ptrs_ij, y10, mask=mask_ij)
+ elif dj == 1:
+ tl.store(ptrs_ij, y11, mask=mask_ij)
+ elif dj == 2:
+ tl.store(ptrs_ij, y12, mask=mask_ij)
+ else:
+ tl.store(ptrs_ij, y13, mask=mask_ij)
+ elif di == 2:
+ if dj == 0:
+ tl.store(ptrs_ij, y20, mask=mask_ij)
+ elif dj == 1:
+ tl.store(ptrs_ij, y21, mask=mask_ij)
+ elif dj == 2:
+ tl.store(ptrs_ij, y22, mask=mask_ij)
+ else:
+ tl.store(ptrs_ij, y23, mask=mask_ij)
+ else:
+ if dj == 0:
+ tl.store(ptrs_ij, y30, mask=mask_ij)
+ elif dj == 1:
+ tl.store(ptrs_ij, y31, mask=mask_ij)
+ elif dj == 2:
+ tl.store(ptrs_ij, y32, mask=mask_ij)
+ else:
+ tl.store(ptrs_ij, y33, mask=mask_ij)
diff --git a/aiter/ops/triton/_triton_kernels/conv/conv_general.py b/aiter/ops/triton/_triton_kernels/conv/conv_general.py
new file mode 100644
index 0000000000..9c5ade7fd6
--- /dev/null
+++ b/aiter/ops/triton/_triton_kernels/conv/conv_general.py
@@ -0,0 +1,163 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
+
+from __future__ import annotations
+import triton
+import triton.language as tl
+from .helpers import _tanh, AUTOTUNE_GENERAL_CONFIGS
+
+
+@triton.autotune(
+ configs=AUTOTUNE_GENERAL_CONFIGS,
+ key=["M_total", "K_out", "K_pad"],
+ reset_to_zero=["Y"],
+ warmup=50,
+ rep=200,
+ cache_results=True,
+)
+@triton.jit
+def _conv2d_general_kernel(
+ X,
+ WK,
+ BIAS,
+ Y,
+ N: tl.constexpr,
+ C: tl.constexpr,
+ H: tl.constexpr,
+ W_in: tl.constexpr,
+ K_out: tl.constexpr,
+ R: tl.constexpr,
+ S: tl.constexpr,
+ P: tl.constexpr,
+ Q: tl.constexpr,
+ K_pad: tl.constexpr,
+ stride_h: tl.constexpr,
+ stride_w: tl.constexpr,
+ pad_h: tl.constexpr,
+ pad_w: tl.constexpr,
+ dil_h: tl.constexpr,
+ dil_w: tl.constexpr,
+ M_total: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_SIZE_M: tl.constexpr,
+ HAS_BIAS: tl.constexpr,
+ ACT_TYPE: tl.constexpr,
+ LAYOUT: tl.constexpr,
+):
+ """General conv kernel with precomputed bases.
+ LAYOUT: 0=NCHW, 1=NHWC
+ """
+ # WK is always [K_out, K_pad] contiguous
+ stride_wk_kout: tl.constexpr = K_pad
+ stride_wk_kred: tl.constexpr = 1
+ if LAYOUT == 0:
+ # NCHW: X[N, C, H, W_in], Y[N, K_out, P, Q]
+ stride_x_n: tl.constexpr = C * H * W_in
+ stride_x_c: tl.constexpr = H * W_in
+ stride_x_h: tl.constexpr = W_in
+ stride_x_w: tl.constexpr = 1
+ stride_y_n: tl.constexpr = K_out * P * Q
+ stride_y_k: tl.constexpr = P * Q
+ stride_y_p: tl.constexpr = Q
+ stride_y_q: tl.constexpr = 1
+ else:
+ # NHWC: X[N, H, W_in, C], Y[N, P, Q, K_out]
+ stride_x_n: tl.constexpr = H * W_in * C
+ stride_x_c: tl.constexpr = 1
+ stride_x_h: tl.constexpr = W_in * C
+ stride_x_w: tl.constexpr = C
+ stride_y_n: tl.constexpr = P * Q * K_out
+ stride_y_k: tl.constexpr = 1
+ stride_y_p: tl.constexpr = Q * K_out
+ stride_y_q: tl.constexpr = K_out
+
+ pid = tl.program_id(axis=0)
+
+ num_pid_m = tl.cdiv(M_total, BLOCK_M)
+ num_pid_n = tl.cdiv(K_out, BLOCK_N)
+
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * GROUP_SIZE_M
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+ pid_n = (pid % num_pid_in_group) // group_size_m
+
+ if pid_m >= num_pid_m:
+ return
+
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ kout_mask = offs_n < K_out
+
+ # Decode offs_m -> (n_idx, p_idx, q_idx)
+ n_idx = offs_m[:, None] // (P * Q)
+ pq = offs_m[:, None] % (P * Q)
+ p_idx = pq // Q
+ q_idx = pq % Q
+
+ n_valid = n_idx < N
+
+ # Precompute base positions
+ base_oh = p_idx * stride_h - pad_h
+ base_ow = q_idx * stride_w - pad_w
+ stride_dh = dil_h * stride_x_h
+ stride_dw = dil_w * stride_x_w
+ x_base = X + n_idx * stride_x_n + base_oh * stride_x_h + base_ow * stride_x_w
+ wk_base = WK + offs_n[None, :] * stride_wk_kout
+
+ Y_ptrs = (
+ Y
+ + n_idx * stride_y_n
+ + offs_n[None, :] * stride_y_k
+ + p_idx * stride_y_p
+ + q_idx * stride_y_q
+ )
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ offs_k = tl.arange(0, BLOCK_K)
+ rs_stride = R * S
+
+ for k0 in range(0, K_pad, BLOCK_K):
+ kred = k0 + offs_k
+
+ WK_ptrs = wk_base + kred[:, None] * stride_wk_kred
+ w_tile = tl.load(WK_ptrs, mask=kout_mask[None, :], other=0.0)
+
+ c = kred // rs_stride
+ rs = kred % rs_stride
+ r = rs // S
+ s = rs % S
+
+ oh = base_oh + r * dil_h
+ ow = base_ow + s * dil_w
+
+ X_ptrs = x_base + c * stride_x_c + r * stride_dh + s * stride_dw
+ x_mask = (
+ n_valid & (oh >= 0) & (ow >= 0) & (oh < H) & (ow < W_in) & (c[None, :] < C)
+ )
+ x_tile = tl.load(X_ptrs, mask=x_mask, other=0.0)
+
+ acc += tl.dot(x_tile, w_tile, out_dtype=tl.float32)
+
+ if HAS_BIAS:
+ b = tl.load(BIAS + offs_n, mask=offs_n < K_out, other=0.0)
+ acc += b[None, :]
+
+ if ACT_TYPE == 1:
+ acc = tl.maximum(acc, 0)
+ elif ACT_TYPE == 2:
+ acc = tl.minimum(tl.maximum(acc, 0), 6)
+ elif ACT_TYPE == 3:
+ acc = (
+ 0.5 * acc * (1.0 + _tanh(0.7978845608 * (acc + 0.044715 * acc * acc * acc)))
+ )
+
+ tl.store(
+ Y_ptrs,
+ acc,
+ mask=(n_valid & (p_idx < P) & (q_idx < Q) & kout_mask[None, :]),
+ )
diff --git a/aiter/ops/triton/_triton_kernels/conv/helpers.py b/aiter/ops/triton/_triton_kernels/conv/helpers.py
new file mode 100644
index 0000000000..e80f81e052
--- /dev/null
+++ b/aiter/ops/triton/_triton_kernels/conv/helpers.py
@@ -0,0 +1,188 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
+
+from __future__ import annotations
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _tanh(x):
+ # Clamp to ±10 so exp(2x) cannot overflow fp32; tanh saturates to ±1 there anyway.
+ x = tl.minimum(tl.maximum(x, -10.0), 10.0)
+ e2x = tl.exp(2 * x)
+ return (e2x - 1) / (e2x + 1)
+
+
+# ========================================================================
+# AUTOTUNE CONFIGS — centralized to avoid duplication and cross-imports
+# ========================================================================
+
+# -- 1x1 kernel --
+AUTOTUNE_1x1_CONFIGS = [
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 64, "GROUP_SIZE_M": 8},
+ num_warps=8,
+ num_stages=1,
+ ),
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 4},
+ num_warps=8,
+ num_stages=1,
+ ),
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_SIZE_M": 4},
+ num_warps=4,
+ num_stages=1,
+ ),
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_SIZE_M": 8},
+ num_warps=8,
+ num_stages=1,
+ ),
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP_SIZE_M": 4},
+ num_warps=4,
+ num_stages=1,
+ ),
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64, "GROUP_SIZE_M": 8},
+ num_warps=8,
+ num_stages=1,
+ ),
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_SIZE_M": 8},
+ num_warps=4,
+ num_stages=1,
+ ),
+]
+
+# -- General conv kernel --
+AUTOTUNE_GENERAL_CONFIGS = [
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_SIZE_M": 8},
+ num_warps=8,
+ num_stages=1,
+ ),
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP_SIZE_M": 8},
+ num_warps=8,
+ num_stages=1,
+ ),
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_SIZE_M": 4},
+ num_warps=8,
+ num_stages=1,
+ ),
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP_SIZE_M": 4},
+ num_warps=4,
+ num_stages=1,
+ ),
+]
+
+# -- 3x3 NHWC kernel --
+AUTOTUNE_3x3_NHWC_CONFIGS = [
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_C": 64, "GROUP_SIZE_M": 8},
+ num_warps=8,
+ num_stages=1,
+ ),
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_C": 64, "GROUP_SIZE_M": 8},
+ num_warps=4,
+ num_stages=1,
+ ),
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_C": 64, "GROUP_SIZE_M": 8},
+ num_warps=4,
+ num_stages=1,
+ ),
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_C": 64, "GROUP_SIZE_M": 4},
+ num_warps=4,
+ num_stages=1,
+ ),
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_C": 32, "GROUP_SIZE_M": 4},
+ num_warps=4,
+ num_stages=1,
+ ),
+]
+
+# -- 3x3 channel-blocked kernel --
+AUTOTUNE_3x3_CBLOCKED_CONFIGS = [
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_C": 64, "GROUP_SIZE_M": 8},
+ num_warps=4,
+ num_stages=1,
+ ),
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_C": 64, "GROUP_SIZE_M": 4},
+ num_warps=4,
+ num_stages=1,
+ ),
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_C": 128, "GROUP_SIZE_M": 4},
+ num_warps=4,
+ num_stages=1,
+ ),
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_C": 64, "GROUP_SIZE_M": 8},
+ num_warps=4,
+ num_stages=1,
+ ),
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_C": 64, "GROUP_SIZE_M": 8},
+ num_warps=8,
+ num_stages=1,
+ ),
+]
+
+# -- Winograd F(4,3) GEMM --
+AUTOTUNE_WINO_GEMM_CONFIGS = [
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_SIZE_M": 8},
+ num_warps=8,
+ num_stages=1,
+ ),
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP_SIZE_M": 8},
+ num_warps=4,
+ num_stages=1,
+ ),
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_SIZE_M": 8},
+ num_warps=4,
+ num_stages=1,
+ ),
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP_SIZE_M": 4},
+ num_warps=4,
+ num_stages=1,
+ ),
+]
+
+# -- Winograd F(4,3) input transform --
+AUTOTUNE_WINO4_INPUT_CONFIGS = [
+ triton.Config({"BLOCK_C": 64}, num_warps=4, num_stages=1),
+ triton.Config({"BLOCK_C": 32}, num_warps=4, num_stages=1),
+]
+
+# -- Winograd F(4,3) output transform --
+AUTOTUNE_WINO4_OUTPUT_CONFIGS = [
+ triton.Config({"BLOCK_K": 64}, num_warps=4, num_stages=1),
+ triton.Config({"BLOCK_K": 128}, num_warps=4, num_stages=1),
+]
+
+# -- Winograd F(4,3) fused GEMM+output --
+AUTOTUNE_FUSED_F4X3_CONFIGS = [
+ triton.Config(
+ {"BLOCK_T": 16, "BLOCK_K": 64, "BLOCK_C": 64}, num_warps=4, num_stages=1
+ ),
+ triton.Config(
+ {"BLOCK_T": 16, "BLOCK_K": 128, "BLOCK_C": 64}, num_warps=8, num_stages=1
+ ),
+ triton.Config(
+ {"BLOCK_T": 32, "BLOCK_K": 64, "BLOCK_C": 64}, num_warps=4, num_stages=1
+ ),
+]
diff --git a/aiter/ops/triton/conv/DESIGN.md b/aiter/ops/triton/conv/DESIGN.md
new file mode 100644
index 0000000000..cdab45ede4
--- /dev/null
+++ b/aiter/ops/triton/conv/DESIGN.md
@@ -0,0 +1,707 @@
+# Design
+
+This document explains how `aiter.ops.triton.conv` is structured, what each kernel does
+and why, the math behind the Winograd path, the heuristic that picks between
+methods, the memory layouts and repacking that connect them, the numerical
+tolerance model used by the test suite, and how to add a new kernel.
+
+> **Audience.** Anyone planning to extend the library, debug a performance
+> regression, or evaluate whether to take a dependency on it.
+
+---
+
+## 1. Goals and non-goals
+
+**In scope.**
+
+- Forward 2-D convolution on AMD ROCm/RDNA (WMMA-capable).
+- `fp16` and `bf16` inputs; configurable output dtype.
+- NCHW and NHWC input layouts.
+- Optional fused bias add and activation (`relu`, `relu6`, `gelu`).
+- Drop-in replacement for `nn.Conv2d` for inference.
+
+**Out of scope (today).**
+
+- **Backward / training.** Inference only. No grad support.
+- **`groups > 1`** (depthwise / grouped). Detected at the example layer and
+ left to PyTorch/MIOpen.
+- **`padding_mode != "zeros"`** (reflect / replicate / circular). Same fall-back.
+- **fp32 / fp8 inputs.** Not supported — `fp16` and `bf16` only.
+- Tuning, autotune configs, and the `_select_3x3_method`
+ routing table are RDNA4-specific for now.
+
+---
+
+## 2. Architecture at a glance
+
+```mermaid
+flowchart TD
+ A([conv2d - user entry]):::entry
+
+ A --> NCHW["layout = nchw<br>conv2d_nchw"]:::router
+ A --> NHWC["layout = nhwc<br>conv2d_nhwc"]:::router
+
+ NCHW --> N1[1x1]:::shape
+ NCHW --> N3[3x3]:::shape
+ NCHW --> NG[other]:::shape
+
+ NHWC1[1x1]:::shape
+ NHWC3[3x3]:::shape
+ NHWCG[other]:::shape
+ NHWC --> NHWC1
+ NHWC --> NHWC3
+ NHWC --> NHWCG
+
+ N1 --> K1[/_conv2d_1x1/]:::kernel
+ N3 --> S1{_select_3x3_method}:::sel
+ NG --> KG[/_conv2d_general/]:::kernel
+
+ NHWC1 --> K1n[/"_conv2d_1x1<br>NHWC"/]:::kernel
+ NHWC3 --> S2{_select_3x3_method}:::sel
+ NHWCG --> KGn[/_conv2d_general/]:::kernel
+
+ S1 --> P1["repack<br>NCHW → NCHWc"]:::pack
+ S1 --> KWF[/winograd_f4x3/]:::wino
+ S1 --> P3["repack<br>NCHW → NCHWc"]:::pack
+
+ P1 --> KCB[/cblocked/]:::kernel
+ P3 --> KWFC[/winograd_f4x3_cblocked/]:::wino
+
+ S2 --> KNH[/nhwc_3x3/]:::kernel
+ S2 --> KWFN[/"winograd_f4x3<br>NHWC"/]:::wino
+
+ classDef entry fill:#1e3a5f,stroke:#4a90e2,stroke-width:2px,color:#fff
+ classDef router fill:#2c5282,stroke:#63b3ed,color:#fff
+ classDef shape fill:#4a5568,stroke:#a0aec0,color:#fff
+ classDef sel fill:#744210,stroke:#f6ad55,color:#fff
+ classDef pack fill:#7b341e,stroke:#fc8181,color:#fff
+ classDef kernel fill:#22543d,stroke:#68d391,color:#fff
+ classDef wino fill:#553c9a,stroke:#b794f4,color:#fff
+```
+
+The Winograd kernels (`winograd_f4x3*`) themselves are a 3-stage pipeline:
+
+```mermaid
+flowchart LR
+ X(["X<br>NCHW or NHWC"]):::data
+ X --> IT["input transform<br>V = BᵀXB"]:::stage
+ IT --> V([V : 36 × T × C_pad]):::data
+
+ V --> GE["batched GEMM<br>36 of T × K_out<br>reducing over C_pad"]:::stage
+ GE --> M([M : 36 × T × K_out]):::data
+ M --> OT["output transform<br>Y = AᵀMA"]:::stage
+ OT --> Y(["Y<br>NCHW or NHWC"]):::data
+
+ V -.fused path.-> FU["fused GEMM + output<br>V → Y, skips M"]:::stage
+ FU -.-> Y
+
+ classDef data fill:#1e3a5f,stroke:#4a90e2,color:#fff
+ classDef stage fill:#553c9a,stroke:#b794f4,color:#fff
+```
+
+The fused kernel (`_winograd_f4x3_fused_gemm_output_kernel`) reads `V` and the
+transformed filter `U` directly, accumulates the column transform's 24 partial
+outputs in registers across the channel-tile loop, applies the final row
+transform once at the end, and writes `Y` without ever materializing the
+intermediate `M[36, T, K_out]` tensor.
+
+**The fused path is opt-in, not router-selected.** `_select_3x3_method` only
+ever returns `winograd_f4x3` or `winograd_f4x3_cblocked` — both the standard
+3-kernel pipeline. The fused variant is reachable explicitly via
+`--method winograd_f4x3_fused` in the bench tool (`op_tests/op_benchmarks/triton/bench_conv2d.py`)
+or by importing `conv2d_winograd_f4x3_fused` directly. It wins on a narrow
+band of small/medium VAE shapes but hits register pressure at the larger
+ResNet 3×3 tile shapes, so it isn't worth a special-case branch in the
+production router.
+
+### Layered Python modules
+
+```
+aiter/ops/triton/conv/
+ __init__.py empty marker (consumers import directly from conv2d.py)
+ conv2d.py public functions + smart routing in conv2d_nchw / conv2d_nhwc
+ _launch.py grid setup, _select_3x3_method, dtype mapping
+ _prepack.py weight/input repacks + LRU caches
+ _utils.py shape math, _is_*_conv predicates, tolerance model, activation
+
+aiter/ops/triton/_triton_kernels/conv/
+ __init__.py empty marker (kernels are imported by full path)
+ helpers.py shared autotune config lists (consumed by @triton.autotune
+ in each kernel file) + _tanh helper
+ conv_1x1.py 1×1 GEMM kernel (NCHW + NHWC via LAYOUT constexpr)
+ conv_3x3.py 3×3 NHWC kernel + 3×3 cblocked (NCHWc) kernel
+ conv_general.py K-major reduction with on-the-fly (c, r, s) decoding
+ conv_3x3_winograd_f4x3.py
+ 5 kernels: input transform (NCHW & NHWC), cblocked input
+ transform, batched GEMM, output transform, fused GEMM+output
+```
+
+The split mirrors *responsibility*, not LOC: `conv2d.py` is "what does the user
+ask for", `_launch.py` is "how do we set up the grid", `_prepack.py` is "what
+shape does the kernel actually want to read", and the `_triton_kernels/conv/` folder is
+"the math". `_utils.py` sits underneath all four as a shared helper layer
+(shape math, eligibility predicates, tolerance model, activation enum).
+
+---
+
+## 3. Public API
+
+The headline is **`conv2d`**:
+
+```python
+def conv2d(x, w_oihw, bias=None, stride=(1,1), padding=(0,0), dilation=(1,1),
+ activation="none", out_dtype=None, layout="nchw"):
+```
+
+| Argument | Meaning |
+|---|---|
+| `x` | Input. NCHW shape, optionally with `channels_last` strides for NHWC mode. |
+| `w_oihw` | Weight in PyTorch's canonical `[K_out, C, R, S]` layout. |
+| `bias` | Optional 1-D bias of length `K_out`, cast to fp32 once at entry. |
+| `stride`, `padding`, `dilation` | Standard `Conv2d` semantics; tuples of ints. |
+| `activation` | `"none" / "relu" / "relu6" / "gelu"` — fused into the kernel epilogue. |
+| `out_dtype` | `None` (default — match input dtype, mirrors `nn.Conv2d`) or one of `torch.float16` / `torch.bfloat16` to override. The fp32 accumulator is downcast at store. |
+| `layout` | `"nchw"` or `"nhwc"` (case-insensitive — passed through `.lower()`). Selects which top-level routing function runs. |
+
+The semi-public method-specific functions (`conv2d_nchw_cblocked`,
+`conv2d_winograd_f4x3_cblocked`, …) take an internal `block_k=64` channel-pack
+tile size used by the prepack caches. It is intentionally not surfaced on the
+public `conv2d` — every autotune config in `helpers.py` assumes 64, so the
+parameter has no good user story.
+
+Both `x.dtype` and (when explicitly passed) `out_dtype` are validated at
+entry: anything other than `torch.float16` / `torch.bfloat16` raises
+`ValueError`. `w_oihw.dtype` is trusted to the kernel.
+
+Everything else (`conv2d_nchw`, `conv2d_1x1`, `conv2d_winograd_f4x3_cblocked`,
+…) is **semi-public**: not in `__all__` but importable by name. The benchmark
+harness uses these directly to compare methods on the same shape.
+
+`conv2d.py` also exposes `_last_triton_kernel: Optional[str]` — set by
+`conv2d_nchw` / `conv2d_nhwc` after each routing decision, read by the bench
+to label rows in the per-layer table.
+
+---
+
+## 4. Smart routing (`_select_3x3_method`)
+
+Only 3×3 has a real choice — 1×1 always uses the 1×1 kernel, 5×5/7×7 always
+fall to `general`. For 3×3, `_launch.py:_select_3x3_method` decides:
+
+```python
+def _select_3x3_method(N, C, H, W, K_out, stride, dilation):
+ # 1) Non-Winograd-eligible (stride>1, dilation>1, or C<4) -> cblocked
+ if not _is_winograd_eligible(3, 3, stride, dilation, C):
+ return "cblocked"
+ # 2) Tile count: F(4,3) emits one 4x4 output tile from each 6x6 input patch,
+ # overlap = 2 (input tiles step by 4 with 6-wide windows).
+ P, Q = _out_hw(H, W, 3, 3, stride, (1,1), dilation)
+ tile_H, tile_W = (P + 3) // 4, (Q + 3) // 4
+ T = N * tile_H * tile_W
+ # 3) Winograd only wins when both C and K are large enough
+ # AND there are enough tiles to amortize transform overhead.
+ if C >= 512 and K_out >= 512 and T >= 98:
+ return "winograd_f4x3_cblocked" if T >= 392 else "winograd_f4x3"
+ return "cblocked"
+```
+
+The thresholds (`C/K ≥ 512`, `T ≥ 98`, `T ≥ 392`) come from a sweep on
+RDNA4 — see the comment block in `_launch.py`. Two implications worth knowing:
+
+1. **The router uses padding `(1,1)` regardless of the caller's padding.** The
+ tile count it computes is approximate by design: the router only needs to
+ pick a method, not the exact `(P, Q)`. The actual padding flows through the
+ chosen kernel unchanged. (This is intentional: the tile count is a routing
+ signal only, never a launch parameter.)
+2. **Below `C/K = 512`, the cblocked direct kernel is faster than Winograd.**
+ The Winograd transform overhead (read 6×6 patches, two 6×6 matrix
+ multiplies by Bᵀ and B, ~70 fp32 FMAs per (tile, channel) for the input
+ transform alone) dominates the FMA savings until the GEMM body is large
+ enough.
+
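+Worked through the code above for three illustrative shapes (the return
+values follow mechanically from the thresholds; these are example shapes,
+not new benchmark data):
+
+```python
+# Deep ResNet-style layer: P = Q = 7 -> 2x2 tiles -> T = 4 < 98.
+# Winograd can't amortize its transforms; direct cblocked wins.
+_select_3x3_method(N=1, C=512, H=7, W=7, K_out=512,
+                   stride=(1, 1), dilation=(1, 1))    # -> "cblocked"
+
+# VAE-style layer: P = Q = 64 -> 16x16 tiles -> T = 256, so 98 <= T < 392.
+_select_3x3_method(N=1, C=512, H=64, W=64, K_out=512,
+                   stride=(1, 1), dilation=(1, 1))    # -> "winograd_f4x3"
+
+# Same layer at batch 2: T = 512 >= 392 -> the NCHWc-input variant.
+_select_3x3_method(N=2, C=512, H=64, W=64, K_out=512,
+                   stride=(1, 1), dilation=(1, 1))    # -> "winograd_f4x3_cblocked"
+```
+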
+Stride > 1 or dilation > 1 disqualifies Winograd entirely — the F(4,3)
+algorithm is built around a 6-wide window stepping by 4, and that geometry
+breaks under non-unit stride.
+
+NHWC reuses the same router but collapses its output to two destinations:
+the standard NHWC 3×3 kernel, or the **non-cblocked** Winograd path with its
+input transform kernel switched to NHWC reading via the `LAYOUT=1` constexpr.
+Even when the router returns `winograd_f4x3_cblocked`, NHWC mode falls back to
+plain `conv2d_winograd_f4x3` because there is no NHWC cblocked input repack —
+cblocked is an NCHW-only optimization. The GEMM and output transform are
+layout-independent because they operate on `V[36, T, C_pad]` and emit directly
+into the user's chosen output layout.
+
+---
+
+## 5. Per-kernel deep-dive
+
+### 5.0 A platform note on `num_stages`
+
+Every autotune config in `helpers.py` pins **`num_stages=1`**. That is
+deliberate.
+
+The `num_stages > 1` Triton knob is meant to lower to a software-pipelined
+loop: the compiler hoists global → LDS loads for iteration *i+1* across the
+matmul of iteration *i* and rotates through `num_stages` LDS buffers so
+memory and compute overlap. On RDNA there is no asynchronous global-to-LDS
+copy instruction for the compiler to schedule ahead — global loads are
+issued and waited on in-order through `s_waitcnt`. Triton's RDNA backend
+therefore does not currently produce a pipelined loop for `num_stages > 1`;
+the extra LDS buffers are allocated but the load hoisting never
+materializes. Empirically `num_stages > 1` on RDNA either matches or loses
+to `num_stages=1` (more LDS pressure, no overlap gained), so we don't sweep
+it.
+
+This is why every tile choice in the sections below is paired with
+`num_warps` only — there is no stage axis to tune. Memory-latency hiding on
+RDNA comes from the hardware: wave-level scheduling across CUs and
+interleaving arithmetic between a load's issue and its `s_waitcnt`, rather
+than from compiler-issued software pipelining.
+
+### 5.1 `_conv2d_1x1_kernel`
+
+A 1×1 conv is mathematically a GEMM:
+
+```
+y[n,k,p,q] = Σ_c x[n,c,p,q] · w[k,c] (stride/padding still apply to (p,q))
+```
+
+The kernel fuses this with the index unwrap `m → (n, p, q)` and a
+`LAYOUT` constexpr (0 = NCHW, 1 = NHWC) that selects the strides on read/write.
+Highlights:
+
+- **Tile shape:** `BLOCK_M × BLOCK_N × BLOCK_K`, autotuned over the
+ `AUTOTUNE_1x1_CONFIGS` grid in `helpers.py`. On RDNA4 the winners cluster
+ around `BM=128, BN=128`, `BK ∈ {32, 64}`, 8 warps (`num_stages=1` —
+ see 5.0).
+- **L2 cache swizzle.** Tiles are reordered into super-groups of
+ `GROUP_SIZE_M` along the `M` axis so each weight (`N`-axis) tile is reused
+ across `GROUP_SIZE_M` consecutive workgroups before moving on — the same
+ weight columns stay hot in L2 for the duration of a group. Standard MM
+ trick.
+- **Stride > 1 / non-zero padding** are handled by mapping the tile's
+ `(p, q)` back to `(h, w) = (sh·p − ph, sw·q − pw)` and checking bounds. The
+ hot path (stride 1, padding 0) and the slow path go through the same code
+ but the mask collapses to "true" for the hot path so the compiler should
+ hoist it.
+
+### 5.2 `_conv2d_3x3_nhwc_kernel`
+
+NHWC-native 3×3 with **K-major weight layout** `W3[K_out, 9, C_pad]`:
+
+- For each `(p, q)` output column the kernel walks the 9 spatial taps, loads
+ the corresponding `[BLOCK_M, BLOCK_C]` slice of input and the matching
+ `[BLOCK_C, BLOCK_N]` slice of W3, and accumulates
+ `tl.dot(...)` into `[BLOCK_M, BLOCK_N]`.
+- Channel padding `C_pad` is rounded up to a multiple of `BLOCK_C` and the
+ trailing weight lanes are pre-zeroed in the prepack. The spatial validity
+ mask is hoisted out of the inner C loop; the only per-iteration mask is
+ the channel-bound check (`c_offs < C`).
+- Hardcoded `stride_x_c=1` (channels are contiguous in NHWC) means the
+ load addressing collapses to a single base + linear stride; the compiler
+ emits one vectorized load per row.
+- **Same L2 swizzle as 5.1**: workgroups are reordered into super-groups of
+ `GROUP_SIZE_M` along the `M` axis (autotuned to 4 or 8 in
+ `AUTOTUNE_3x3_NHWC_CONFIGS`) so each weight (`N`-axis) tile stays hot in
+ L2 across `GROUP_SIZE_M` consecutive workgroups.
+
+### 5.3 `_conv2d_3x3_cblocked_kernel`
+
+Same 3×3 math and same `M = N·P·Q` GEMM decomposition as the NHWC kernel,
+but with **NCHWc input layout** `X[N, C_blocks, H, W_in, Cb]` (`Cb=64` by
+default; the kernel takes it as a constexpr). Why two kernels?
+
+- The NHWC kernel keeps the user's input layout intact (no input repack)
+ and gets coalesced channel loads "for free" because channels are already
+ the inner contiguous axis.
+- The user's NCHW input has channels on the *outer* stride (`H·W` apart),
+ so a direct read from NCHW would scatter the channel loads. cblocked
+ **repacks the input once** into `[N, C_blocks, H, W, Cb]` to restore a
+ small inner contiguous channel axis (`Cb` wide) that the same `M = N·P·Q`
+ GEMM walk can read coalesced. Output is plain NCHW so there's no output
+ repack.
+
+Both kernels share the GEMM order; the difference is purely in the input
+addressing.
+
+Highlights:
+
+- **`stride_c_local=1` is hardcoded** in the load path (`c_local[None, :]`
+ is added directly to the pointer without a stride multiplier), so the
+ compiler emits one vectorized load per channel chunk — structurally the
+ same hardcoded-stride trick as `stride_x_c=1` in the NHWC kernel.
+- **Channel addressing.** Inside the C loop, `c_offs // Cb` selects the
+ block index and `c_offs % Cb` the offset within the block. This stays
+ coalesced **only when `BLOCK_C ≤ Cb`** — at the boundary, `cblock_idx`
+ jumps and the load addresses become discontiguous. This is an implicit constraint
+ on `AUTOTUNE_3x3_CBLOCKED_CONFIGS`; new configs must respect it.
+- **Same L2 swizzle as 5.1/5.2** — workgroups are reordered into
+ super-groups of `GROUP_SIZE_M` along the `M` axis so each weight
+ (`N`-axis) tile stays hot in L2 across `GROUP_SIZE_M` consecutive
+ workgroups.
+
+The cblocked kernel is the default for NCHW 3×3 below the Winograd threshold
+because, even after paying the input repack cost, it consistently beats both
+"general with K-major" and the NHWC kernel with implicit transposes. The
+kernel+repack number in the bench includes the per-batch input repack;
+the user-facing `conv2d_nchw` calls `get_or_make_input_pack_cblocked`, which
+caches by `(storage_ptr, shape, dtype, ...)` so back-to-back calls with the
+same input tensor only repack once.
+
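+The repack itself is conceptually a reshape-and-transpose. A minimal
+reference sketch (illustrative only — the real `prepack_nchw_to_cblocked`
+also zero-pads `C` up to a multiple of `Cb` and sits behind the cache):
+
+```python
+import torch
+
+def nchw_to_cblocked_reference(x: torch.Tensor, Cb: int = 64) -> torch.Tensor:
+    """Reference NCHW -> NCHWc repack; assumes C % Cb == 0 for brevity."""
+    N, C, H, W = x.shape
+    # [N, C, H, W] -> [N, C_blocks, Cb, H, W] -> [N, C_blocks, H, W, Cb]
+    return x.reshape(N, C // Cb, Cb, H, W).permute(0, 1, 3, 4, 2).contiguous()
+```
+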
+### 5.4 `_conv2d_general_kernel`
+
+The fallback for everything that isn't 1×1, 3×3, or Winograd-eligible —
+i.e. any kernel size other than 1×1 or 3×3 (5×5, 7×7, dilated 5×5, etc.).
+Note: dilated 3×3 still routes to cblocked, not here. Strategy:
+
+- Pack weights once into K-major `W_K[K_out, K_pad]` where
+ `K_pad = pad(C·R·S, block_k)`. The trailing `K_pad − C·R·S` weight lanes
+ are zero so any reduction step in the tail contributes 0.
+- Tile the output as a GEMM over `M = N·P·Q` rows × `N = K_out` cols ×
+ `K = K_pad` reduction.
+- **No im2col buffer.** The kernel walks `K_pad` and decodes
+ `k → (c, r, s)` on the fly: `c = k // (R·S)`, `rs = k mod (R·S)`,
+ `r = rs // S`, `s = rs mod S`. Each `(c, r, s)` triple yields the input
+ coordinate `(h, w) = (sh·p − ph + dh·r, sw·q − pw + dw·s)` and a bounds
+ mask `(oh, ow) ∈ [0,H)×[0,W)` **plus a `c < C` channel-bound check**
+ (the K_pad tail decodes `c ≥ C`, so the input-side mask prevents OOB
+ reads — the zero-padded weight side just guarantees those phantom
+ contributions are 0). fp32 accumulator, downcast at store.
+- **`LAYOUT` constexpr** (0=NCHW, 1=NHWC) selects input/output strides; the
+ same kernel serves both layouts. The router calls it with `layout="nhwc"`
+ for non-1×1/3×3 NHWC shapes.
+- **Same L2 swizzle** as 5.1/5.2/5.3 — `GROUP_SIZE_M` super-grouping along M.
+
+### 5.5 The Winograd path — five kernels
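+The K-major pack in the first bullet is, modulo the caching layer, a flatten
+plus zero-pad. A reference sketch (not the exact `_prepack.py` code):
+
+```python
+import torch
+
+def oihw_to_kmajor_reference(w: torch.Tensor, block_k: int = 64) -> torch.Tensor:
+    """Reference [K_out, C, R, S] -> [K_out, K_pad] pack for the general kernel."""
+    K_out, C, R, S = w.shape
+    k_red = C * R * S
+    k_pad = ((k_red + block_k - 1) // block_k) * block_k
+    w_k = w.new_zeros(K_out, k_pad)           # trailing lanes stay zero
+    # Row-major flatten over [c][r][s] matches the kernel's k -> (c, r, s) decode.
+    w_k[:, :k_red] = w.reshape(K_out, k_red)
+    return w_k
+```
+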
+
+Winograd F(4×4, 3×3) reduces the **144** multiplies per 4×4 output tile of
+a direct 3×3 (16 outputs × 9 taps) to 36 multiplies *total* (one batched
+36-way GEMM) — **4× fewer multiplies** in the inner GEMM, modulo
+floating-point ordering. We pay for it in: (a) input/output transforms
+that touch every element with small constants, (b) increased numerical
+sensitivity in fp16/bf16. The math is in section 6; here are the five kernels:
+
+| Kernel | What it does | Output shape |
+|---|---|---|
+| `_winograd_f4x3_input_transform_kernel` | One 6×6 patch per `(t, channel-tile)`, computes `BᵀXB` in fp32, stores to V | `V[36, T, C_pad]` |
+| `_winograd_f4x3_cblocked_input_transform_kernel` | Same, but reads from NCHWc `[N, C_blocks, H, W, Cb]` for coalesced loads | `V[36, T, C_pad]` |
+| `_winograd_f4x3_batched_gemm_kernel` | 36 independent GEMMs `T × C_pad × K_out`, output `M[36, T, K_out]` | `M[36, T, K_out]` |
+| `_winograd_f4x3_output_transform_kernel` | One 6×6 `M` slice per `(t, k)`, computes `AᵀMA` (4×4), bias + activation, scatter to NCHW or NHWC output | `Y[N, K_out, P, Q]` |
+| `_winograd_f4x3_fused_gemm_output_kernel` | Streams the 36 GEMMs in 6 column-groups; applies the column transform `Aᵀ · M[:,col]` on-the-fly per group to accumulate 24 `s` values (skipping materialization of `M`), then applies the row transform once at the end. Replaces kernels 3 + 4 with a single launch at the cost of higher per-CTA register pressure (24 `BLOCK_T × BLOCK_K` fp32 accumulators). | `Y[N, K_out, P, Q]` |
+
+`conv2d_winograd_f4x3` and `conv2d_winograd_f4x3_cblocked` use the
+3-kernel pipeline (input transform + GEMM + output transform).
+`conv2d_winograd_f4x3_fused` uses the fused 2-kernel pipeline (input
+transform + fused GEMM/output). The fused variant trades a smaller
+intermediate (skips writing/reading `M[36, T, K_out]`) for higher per-CTA
+register pressure; on RDNA4 it's slightly faster on large `(T, K_out)` and
+slightly slower on small ones, which is why both exist in the registry but
+the auto-router picks only the unfused variants.
+
+The **filter transform** `U = G g Gᵀ` runs once on the host in
+`prepack_winograd_filter_f4x3` (in fp32), then is cast to the activation
+dtype and stored as `U[36, K_out, C_pad]`. The transform is per-weight, not
+per-call, so it amortizes across every forward pass.
+
+---
+
+## 6. Winograd F(4×4, 3×3) — full derivation
+
+Notation: `g` = 3×3 filter, `d` = 6×6 input tile, `Y` = 4×4 output tile.
+
+### 6.1 Why Winograd
+
+A direct 3×3 conv produces 4 output values from 6 input values along one
+axis using `4 × 3 = 12` multiplies. Winograd's minimal 1-D algorithm
+F(4, 3) needs only **6** multiplies for the same 4 outputs. Tensoring along
+both spatial dims gives F(4×4, 3×3): `4×4 = 16` outputs from a 6×6 input,
+in `6×6 = 36` element-wise products instead of `16 × 9 = 144`.
+
+Per output element: 144 / 16 = 9 muls direct, 36 / 16 = 2.25 muls Winograd
+→ **4× fewer multiplies** in the inner tensor product. The headline 4×
+ignores the input/output transforms (`BᵀdB`, `AᵀMA`), which aren't free
+and are why the realized speedup is smaller than 4×.
+
+**Reporting convention.** Both the bench harness and the published charts
+divide measured wall time by the *direct-convolution* FLOP count
+`2·N·K·C·R·S·H·W`, regardless of which algorithm actually ran. This is
+the universal convention — it gives a single yardstick on which different
+algorithms (direct, im2col-GEMM, Winograd, FFT) are comparable, since
+literal hardware-MAC counts are not. The implication: a Winograd row's
+TFLOPS overstates the literal MACs/sec the hardware executes by ~4× for
+F(4×4, 3×3) and ~2.25× for F(2×2, 3×3). MIOpen's Winograd solvers
+(`ConvBinWinogradRxSf3x2`, `ConvWinoFuryRxS<2-3>`) are reported the same
+way, so Triton-vs-MIOpen comparisons remain apples-to-apples.
+
+### 6.2 The transform matrices
+
+For F(4, 3) the standard Cook-Toom matrices (with `{0, ±1, ±2, ∞}`
+interpolation points) are:
+
+```
+      ┌  6   0   0 ┐                  ┌ 4   0  -5   0  1  0 ┐
+      │ -4  -4  -4 │                  │ 0  -4  -4   1  1  0 │
+ G =  │ -4   4  -4 │ × (1/24)    Bᵀ = │ 0   4  -4  -1  1  0 │
+      │  1   2   4 │                  │ 0  -2  -1   2  1  0 │
+      │  1  -2   4 │                  │ 0   2  -1  -2  1  0 │
+      └  0   0  24 ┘                  └ 0   4   0  -5  0  1 ┘   (×1, no scale)
+
+      ┌ 1  1   1  1   1  0 ┐
+ Aᵀ = │ 0  1  -1  2  -2  0 │
+      │ 0  1   1  4   4  0 │
+      └ 0  1  -1  8  -8  1 ┘   (4×6, scale 1)
+```
+
+(`_prepack.py:prepack_winograd_filter_f4x3` writes G with the `1/24` already
+absorbed: rows scaled by `1/4, -1/6, -1/6, 1/24, 1/24, 1` so a host-side
+`G g Gᵀ` produces `U` directly.)
+
+The math is then:
+
+```
+ U = G g Gᵀ (6×6, computed once per weight, host-side, fp32)
+ V = Bᵀ d B (6×6, computed per input tile, on-chip, fp32)
+ M = U ⊙ V (6×6, element-wise)
+ Y = Aᵀ M A (4×4, output tile)
+```
+
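+A quick end-to-end check of the algebra (NumPy, fp64; the matrices are
+exactly the ones above — one 6×6 patch against a direct 3×3):
+
+```python
+import numpy as np
+
+G = np.array([[6, 0, 0], [-4, -4, -4], [-4, 4, -4],
+              [1, 2, 4], [1, -2, 4], [0, 0, 24]], dtype=np.float64) / 24.0
+Bt = np.array([[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0],
+               [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0],
+               [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]], dtype=np.float64)
+At = np.array([[1, 1, 1, 1, 1, 0], [0, 1, -1, 2, -2, 0],
+               [0, 1, 1, 4, 4, 0], [0, 1, -1, 8, -8, 1]], dtype=np.float64)
+
+rng = np.random.default_rng(0)
+g = rng.standard_normal((3, 3))          # one 3x3 filter
+d = rng.standard_normal((6, 6))          # one 6x6 input patch
+
+U = G @ g @ G.T                          # filter transform (host-side)
+V = Bt @ d @ Bt.T                        # input transform
+Y = At @ (U * V) @ At.T                  # 4x4 output tile
+
+# Direct 3x3 (cross-correlation, as in DL convs) over the same patch
+Y_ref = np.array([[np.sum(d[i:i + 3, j:j + 3] * g) for j in range(4)]
+                  for i in range(4)])
+assert np.allclose(Y, Y_ref)
+```
+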
+**Batched.** With `T = N · ⌈P/4⌉ · ⌈Q/4⌉` tiles and `C` channels per tile,
+the per-tile element-wise product becomes a tile-batched GEMM:
+
+```
+ M[α, t, k] = Σ_c U[α, k, c] · V[α, t, c] α ∈ [0, 36)
+```
+
+— **36 independent GEMMs** of shape `(T, C) × (C, K_out)`. This is what the
+batched GEMM kernel does: `tl.program_id(1)` selects one of the 36 `α`
+slices and `tl.program_id(0)` selects a `(t, k)` tile (with the standard
+`GROUP_SIZE_M` L2 swizzle), then runs a normal Triton MM over `c`.
+
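+In reference form, with `V[36, T, C_pad]` from the input transform and
+`U[36, K_out, C_pad]` from the filter prepack (PyTorch; the kernel computes
+the same contraction tile by tile):
+
+```python
+import torch
+
+# M[a, t, k] = sum_c V[a, t, c] * U[a, k, c] — all 36 GEMMs in one call
+M = torch.einsum("atc,akc->atk", V, U)
+```
+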
+### 6.3 What the kernels actually look like
+
+The transform matrices are **embedded in the JIT'd kernel as Python
+constants** (not loaded from memory), so the compiler can constant-fold and
+strength-reduce. The input transform `V = Bᵀ d B` is computed as the
+standard two-pass factorization — column transform first (`t = Bᵀ d`),
+then row transform (`V = t B`):
+
+```
+ # column transform (t = Bᵀ d), 36 values, each 4–6 fp32 ops
+ t00 = 4*d00 - 5*d20 + d40
+ t01 = 4*d01 - 5*d21 + d41
+ ...
+ # row transform (V = t B), 36 values, each 4–6 fp32 ops
+ v00 = 4*t00 - 5*t02 + t04
+ v01 = -4*t01 - 4*t02 + t03 + t04
+ ...
+```
+
+Each row of `t` is reused across all 6 columns of `V` in that row, so the
+two-pass form is ~72 small fp32 expressions total (36 `t` + 36 `v`) versus
+the much larger sum-of-products that a fully-expanded `Bᵀ d B` would
+unroll to. Input is loaded once per `(t, channel-tile)` as a 6×6 spatial
+block over `BLOCK_C` channels, cast to fp32, transformed, stored to
+`V[36, T, C_pad]`.
+
+The output transform is symmetric: read `M[36, t, k]` (6×6 per tile),
+compute `AᵀMA` (4×4), add bias, apply activation, scatter to `Y` (NCHW or
+NHWC, picked via a `LAYOUT` constexpr). The scatter has to write 16 values
+per tile to a strided output; this is where the kernel does the most index
+arithmetic.
+
+The **fused GEMM+output** kernel processes `M` column-by-column over its 6
+columns (not the 4 output columns). For each `col ∈ [0, 6)`, it runs 6
+mini-GEMMs (one per α-row at that column) and applies the column transform
+`Aᵀ` on-the-fly to compress the 6 GEMM results into 4 `s` values for that
+column. These accumulate into 24 (= 4 × 6) `BLOCK_T × BLOCK_K` fp32
+accumulators across the channel-tile loop. After the channel loop, a
+single row transform produces the 4×4 output. This skips the round-trip
+of `M` through global memory but needs many more live registers per
+program — hence the two variants.
+
+### 6.4 Numerical caveat
+
+The output transform `AᵀMA` has elements as large as `±8`. Rounding errors
+in `M` from the GEMM accumulation compound through the transform on both
+sides of `AᵀMA`. Row 3 of `Aᵀ` (`[0, 1, -1, 8, -8, 1]`) has L1 norm `19`;
+applied on both sides for the worst element `Y[3, 3]`, this bounds the
+error amplification at `19 × 19 = 361×`. On bf16 (mantissa = 7 bits,
+ε ≈ 2⁻⁷) that's enough to typically drop **3–4 bits of precision** vs the
+direct conv. Two consequences:
+
+1. The accumulators are **fp32 throughout** (Triton's `tl.dot` produces
+ fp32; we keep it fp32 through the output transform and only downcast at
+ store). This is non-negotiable.
+2. We disable Winograd for `C < 4` in `_is_winograd_eligible`, because at
+ `C=3` (RGB conv1) there aren't enough sum terms to absorb the
+ amplified rounding within the test tolerance.
+
+The test harness applies a **6× tolerance multiplier** to Winograd kernels
+(`_winograd_tolerances`); see §8.
+
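+The `361×` figure is easy to re-derive from `Aᵀ` (using the `At` array from
+the NumPy sketch in §6.2):
+
+```python
+amp = np.abs(At).sum(axis=1).max() ** 2   # worst row L1 norm, squared
+assert amp == 361.0                       # 19 * 19, the Y[3,3] bound
+```
+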
+---
+
+## 7. Memory layouts and repacking
+
+There are four "shapes" the kernels actually consume, plus the user's NCHW
+input:
+
+| Layout | Kernel | Where the repack happens | Cached? |
+|---|---|---|---|
+| `[K_out, K_pad]` (K-major weight) | `general` | `_prepack.py:prepack_oihw_to_kmajor` | LRU 256, `_PACK_CACHE` |
+| `[K_out, 9, C_pad]` (3×3 weight) | `3x3_nhwc`, `3x3_cblocked` | `prepack_oihw_to_3x3` | LRU 256, `_PACK_CACHE_3x3` |
+| `[N, C_blocks, H, W, Cb]` (NCHWc input, `Cb=64`) | `3x3_cblocked`, `winograd_f4x3_cblocked` (input transform) | `prepack_nchw_to_cblocked` | **Single-entry dict** by design |
+| `[36, K_out, C_pad]` (Winograd weight) | `winograd_f4x3_*` | `prepack_winograd_filter_f4x3` | LRU 256, `_PACK_CACHE_WINOGRAD_F4X3` |
+
+### Why three weight caches are LRU but the input cache is single-entry
+
+Weights are reused every forward pass. A 53-layer ResNet has 53 unique
+weight tensors; with a 256-entry LRU we keep them all warm and the second
+forward pass through the model has zero repack cost.
+
+Input activations are *unique intermediate tensors* — the output of layer N
+is consumed exactly once, by layer N+1, and never seen again. Caching them
+LRU would just bloat memory; a single-entry dict is enough to dedup
+repeated calls with the same tensor (which the bench loop *does* do — it
+runs each layer with `warmup=15, rep=50` over the same input to get a
+stable timing). The `.clear()` in the bench's per-call setup deliberately
+models the per-batch repack cost in real inference, which is reported as
+the "kernel + repack" row.
+
+### Cache key safety
+
+Each cache key includes `(storage_ptr, shape, dtype, block_k, _version)`.
+The cache also stores a strong reference to the source tensor in the value,
+so its storage cannot be freed and re-used by a different tensor while the
+entry lives — this prevents `storage_ptr` collisions from producing a
+false hit. The single-entry input cache also re-checks the source pointer
+on hit (defense in depth).
+
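+A minimal sketch of that key/strong-ref discipline (illustrative only — the
+real `_LRUPackCache` adds the env-var bound and more bookkeeping):
+
+```python
+from collections import OrderedDict
+import torch
+
+class LRUPackCacheSketch:
+    def __init__(self, maxsize: int = 256):
+        self._entries = OrderedDict()
+        self._maxsize = maxsize
+
+    def get_or_make(self, src: torch.Tensor, block_k: int, make):
+        key = (src.data_ptr(), tuple(src.shape), src.dtype, block_k,
+               src._version)                  # _version catches in-place edits
+        hit = self._entries.get(key)
+        if hit is not None:
+            self._entries.move_to_end(key)
+            return hit[1]
+        packed = make(src)
+        self._entries[key] = (src, packed)    # strong ref pins src's storage
+        if len(self._entries) > self._maxsize:
+            self._entries.popitem(last=False) # evict least-recently-used
+        return packed
+```
+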
+`AITER_TRITON_CONV_PACK_CACHE_SIZE=N` (env var) overrides the LRU bound.
+Default 256 is enough for the models exercised by the bench harness without
+eviction; larger models need the env override.
+
+---
+
+## 8. Numerical model and tolerances
+
+`_utils.py:dynamic_conv_tolerances` is the formula used by the test
+harness:
+
+```python
+eps = {fp16: 2**-10, bf16: 2**-7, fp32: 2**-23}[dtype]
+rtol = 6e-3 if K_red < 1024 else 8e-3 if K_red < 4096 else 1.2e-2
+atol = max(eps * 8, 10.0 * eps * sqrt(K_red))
+```
+
+Where `K_red = C · R · S`. Two facts behind this:
+
+1. **Multiplication of two ε-rounded inputs has relative error ε.**
+2. **Summing N independent rounding errors in fp32 grows as `ε · √N`** (random
+ walk), not `ε · N` (worst-case). Real kernels are not adversarial: the
+ `√N` bound matches what we observe across thousands of fuzzer shapes.
+
+The `10×` multiplier covers worst-case ordering differences between our
+Triton accumulation order and PyTorch's (different tile shape → different
+sum order → different rounding).
+
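+For example, a 3×3 layer with `C = 512` in fp16 has `K_red = 512·3·3 = 4608`,
+giving:
+
+```python
+eps  = 2**-10                                    # fp16
+rtol = 1.2e-2                                    # K_red >= 4096
+atol = max(eps * 8, 10.0 * eps * 4608 ** 0.5)    # ~= 0.66
+```
+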
+`_winograd_tolerances` then bumps `rtol` by **6×** and `atol` by
+`max(6×, 0.6)` for any Winograd kernel — the `0.6` floor catches
+small-`K_red` cases where the √N-scaled base atol is too tight to absorb
+the F(4,3) amplification. The 6× factor is empirical: the analytical
+bound from §6.4 is 361× worst-case, but the *typical* amplification on
+natural images is much smaller because the worst-case tile structure is
+rare. 6× catches the long tail without flagging healthy fp16 ordering
+noise.
+
+If a new kernel needs the Winograd bump, mark it `is_winograd=True` in
+`op_tests/triton_tests/conv/_helpers.py` (`METHOD_REGISTRY`); the
+`_get_tolerances` dispatch picks the right tolerance automatically.
+
+---
+
+## 9. Method registry
+
+`op_tests/triton_tests/conv/_helpers.py` (`METHOD_REGISTRY`) is the **single source of truth**
+for kernel dispatch in the test harness. Adding a new method takes one entry:
+
+```python
+METHOD_REGISTRY = {
+ "default": MethodEntry(conv2d_nchw, None, False, "", "default"),
+ "cblocked": MethodEntry(conv2d_nchw_cblocked, _3x3_guard, False, "[cblocked]", "cblocked"),
+ "winograd_f4x3": MethodEntry(conv2d_winograd_f4x3, _wino_guard, True, "[winograd_f4x3]", "WF(4,3)"),
+ "winograd_f4x3_fused": MethodEntry(conv2d_winograd_f4x3_fused, _wino_guard, True, "[wino_f4x3_fused]", "WF4fused"),
+ "winograd_f4x3_cblocked": MethodEntry(conv2d_winograd_f4x3_cblocked, _wino_guard, True, "[winograd_f4x3_cblocked]", "WF4cb"),
+}
+```
+
+| Field | Purpose |
+|---|---|
+| `kernel_fn` | The public `conv2d_*` wrapper (not the raw kernel) — handles repack and grid setup, or routes to a specialized variant (`conv2d_nchw` does smart routing). |
+| `guard_fn(R, S, stride, dilation, C)` | Returns `True` if the method is applicable. `None` means "always". |
+| `is_winograd` | Selects the 6× tolerance bump in `_helpers._get_tolerances`. |
+| `bench_tag` | Suffix added to the per-test result name, e.g. `"[cblocked]"`. |
+| `short_name` | Reserved for future bench output; not used by `bench_conv2d.py` today. |
+
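+A guard is just a shape predicate. As a sketch (the real `_wino_guard` lives
+in `_helpers.py`; this mirrors the eligibility rule from §4):
+
+```python
+def _wino_guard_sketch(R, S, stride, dilation, C):
+    """True iff Winograd F(4,3) applies: 3x3, unit stride/dilation, C >= 4."""
+    return (R, S) == (3, 3) and stride == (1, 1) and dilation == (1, 1) and C >= 4
+```
+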
+Wiring downstream of the registry — the parametrized pytest tests, the
+tolerance dispatch — is automatic. You do not edit any of those.
+
+---
+
+## 10. Adding a new kernel — checklist
+
+Concretely, to add (say) a `winograd_f6x3` variant:
+
+1. **Implement the kernel.** New file under `aiter/ops/triton/_triton_kernels/conv/`.
+ Use the autotune config sets in `helpers.py` if a standard one fits; only
+ write a new one if a different shape really needs it.
+2. **Add a launch wrapper** in `_launch.py` (`_launch_winograd_f6x3`) that
+ sets up the grid and turns Python ints into Triton constexprs.
+3. **Add a public function** in `conv2d.py`
+ (`conv2d_winograd_f6x3(...)`). Compute output shape, allocate the
+ output tensor, fetch any prepacked weights from `_prepack.py`, call
+ the launch wrapper. Consumers import it directly from this module —
+ no `__init__.py` re-export needed.
+4. **Add a prepack** in `_prepack.py` if the kernel needs a different
+ weight layout. Use `_LRUPackCache` with a key that includes
+ `(storage_ptr, shape, dtype, block_k, _version)` and store a strong
+ ref to the source tensor in the cached value. For input-side prepacks
+ (rare — only `_PACK_CACHE_CBLOCKED` does this today), use a
+ single-entry plain dict cleared per-call instead — see §7.
+5. **Register** in `op_tests/triton_tests/conv/_helpers.py` (`METHOD_REGISTRY`). Set
+ `is_winograd=True` if the kernel uses Winograd-style transforms
+ (you'll need the 6× tolerance bump). If the kernel uses a different
+ Winograd tile (e.g. F(6,3)), also add a new `variant=` branch in
+ `_utils._winograd_tolerances` and route to it from
+ `_helpers._get_tolerances` — F(4,3)'s 6× bump
+ is calibrated specifically for that tile size. Also add the kernel to
+ `op_tests/op_benchmarks/triton/bench_conv2d.py`'s `METHODS` dict so
+ `--method <name>` works.
+6. **Optionally route** from `_select_3x3_method` if the new kernel
+ should be auto-selected. Update the heuristic comment block with the
+ shape range where it wins. If you do route from `_select_3x3_method`,
+ also add a `_last_triton_kernel = "..."` line in the matching branch
+ of `conv2d_nchw` / `conv2d_nhwc` so the bench labels the row
+ correctly.
+7. **Run the suite** — `pytest op_tests/triton_tests/conv/` parametrizes
+ `test_no_bias` and `test_activations` over every entry in
+ `METHOD_REGISTRY`, so the new method gets correctness coverage
+ automatically. Bench it via
+ `python -m op_tests.op_benchmarks.triton.bench_conv2d --method ...`.
+
+---
+
+## 11. Known limitations and future work
+
+- **`groups > 1`.** Kernels are `groups=1` only. Implementation entry
+ points: a new `aiter/ops/triton/_triton_kernels/conv/conv_depthwise.py`
+ for the depthwise case, K-major-per-group prepack in `_prepack.py`,
+ route from `_select_3x3_method`.
+- **`padding_mode != "zeros"`.** Reflect / replicate / circular padding
+ is unsupported. Adding it is mostly a per-kernel change to the
+ bounds-mask logic — replace `mask = (h ≥ 0) & (h < H) & ...` with a
+ reflected/wrapped index computation.
+- **Backward / training.** No gradients. Adding training would mean a
+ second set of kernels (`grad_input`, `grad_weight`) and is out of
+ scope for this library's intended use as an inference replacement.
diff --git a/aiter/ops/triton/conv/README.md b/aiter/ops/triton/conv/README.md
new file mode 100644
index 0000000000..4ba1bc59f1
--- /dev/null
+++ b/aiter/ops/triton/conv/README.md
@@ -0,0 +1,229 @@
+# conv2d (Triton, AMD ROCm)
+
+> **`Conv2d` for AMD ROCm — a drop-in replacement for `torch.nn.Conv2d`,
+> optimized for AMD RDNA GPUs.**
+
+
+A hand-written Triton 2-D convolution library optimized for AMD RDNA
+GPUs. Five kernel families (1×1, 3×3 cblocked, 3×3 NHWC, Winograd
+F(4×4, 3×3), general) behind one shape-driven router and one entry
+point. Drop-in for the forward path of `nn.Conv2d`.
+
+---
+
+## Why this op exists
+
+PyTorch on AMD goes through MIOpen, which ships hand-tuned solvers per
+architecture, per dtype, per layout. That works well on the combinations
+the solvers were specifically tuned for, but every new dtype × layout ×
+architecture combination needs its own tuning pass — so coverage is
+uneven across the matrix (e.g. on RDNA4 the fp16 path is well-served,
+while bf16 falls back to direct/GEMM solvers that are noticeably slower
+at large channel counts; most modern checkpoints — LLMs, diffusion VAEs
+— ship in bf16).
+
+This op takes the opposite approach: a single set of Triton kernels
+that runs **fp16 and bf16 through the same code path**, supports
+**both NCHW and NHWC end-to-end** (NHWC inputs run on an NHWC kernel —
+no NHWC↔NCHW conversion), and gets reasonable performance across the
+full matrix **without per-architecture hand tuning**. A shape-driven
+router picks between five kernel families (1×1, 3×3 cblocked, 3×3 NHWC,
+Winograd F(4×4, 3×3), general) so the right kernel runs per layer
+automatically. Some kernels do repack inputs/weights into kernel-local
+formats (channel-blocked tiles for cblocked, G/Bᵀ transforms for
+Winograd) — these packs are LRU-cached so steady-state cost is
+negligible.
+
+---
+
+## Performance
+
+Designed to deliver strong throughput on AMD RDNA4 across both fp16 and
+bf16. The bench harness logs which MIOpen solver was selected per layer
+and reports aggregate TFLOPS for both backends, so you can verify the
+behavior on your stack.
+
+To measure on your stack:
+
+```bash
+python -m op_tests.op_benchmarks.triton.bench_conv2d \
+ --model <model_name> \
+ --dtype <fp16|bf16> \
+ [--miopen-solvers] # opt-in; ~60-120s upfront subprocess
+```
+
+The bench harness produces three box-drawn tables: LAYER-BY-LAYER (per-layer
+Triton vs MIOpen TFLOPS, Triton kernel name, optionally MIOpen solver,
+kernel+repack column for shapes that prepack), MIOpen SOLVER SUMMARY (only
+with `--miopen-solvers`), and OVERALL PERFORMANCE (mean/median/aggregate
+TFLOPS, total time, layer wins, correctness).
+
+> **Note on TFLOPS**: numbers are *direct-convolution-equivalent* throughput
+> (the standard convention used by cuDNN, MIOpen, and the Winograd
+> literature), applied identically to both backends. Winograd kernels —
+> Triton's F(4×4, 3×3) and MIOpen's F(2×2, 3×3) / Fury alike — execute
+> fewer literal hardware MACs than this denominator counts (≈4× fewer for
+> F(4,3), ≈2.25× for F(2,3)). The comparison is apples-to-apples.
+
+---
+
+## Quick start
+
+### Use the function directly
+
+```python
+import torch
+from aiter.ops.triton.conv.conv2d import conv2d
+
+x = torch.randn(4, 256, 56, 56, device="cuda", dtype=torch.float16)
+w = torch.randn(512, 256, 3, 3, device="cuda", dtype=torch.float16)
+
+y = conv2d(
+ x, w, bias=None,
+ stride=(1, 1), padding=(1, 1), dilation=(1, 1),
+ activation="relu", # "none" | "relu" | "relu6" | "gelu"
+ out_dtype=None, # None → match input dtype (default)
+ layout="nchw", # "nchw" or "nhwc"
+)
+```
+
+A shape-driven router picks one of five kernel families:
+
+| Family | When it runs |
+|---|---|
+| 1×1 GEMM | `R==1, S==1` |
+| 3×3 cblocked (NCHW) | 3×3, channel-blocked input for coalesced loads |
+| 3×3 NHWC | 3×3 with channels-last input — no input repack |
+| Winograd F(4×4, 3×3) | 3×3, stride=1, dilation=1, `C ≥ 512`, `K ≥ 512`, enough output tiles |
+| General | any other kernel size (5×5, 7×7, …), including dilated and strided variants |
+
+### Use as `nn.Conv2d` drop-in
+
+The kernel families above are functional; wrapping them in an `nn.Module`
+(walk a model, swap each `nn.Conv2d` for a Triton-backed module that
+calls `conv2d(...)` in its `forward`) works as expected and produces
+images visually indistinguishable from the PyTorch / MIOpen reference.
+
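+A minimal version of that swap (a sketch — `TritonConv2d` and `swap_convs`
+are illustrative names, not classes this library ships; assumes tuple
+`stride`/`padding`/`dilation` on the source module):
+
+```python
+import torch.nn as nn
+from aiter.ops.triton.conv.conv2d import conv2d
+
+class TritonConv2d(nn.Module):
+    def __init__(self, src: nn.Conv2d):
+        super().__init__()
+        self.weight, self.bias = src.weight, src.bias
+        self.stride, self.padding, self.dilation = src.stride, src.padding, src.dilation
+
+    def forward(self, x):
+        return conv2d(x, self.weight, bias=self.bias, stride=self.stride,
+                      padding=self.padding, dilation=self.dilation)
+
+def swap_convs(model: nn.Module):
+    for name, child in model.named_children():
+        if (isinstance(child, nn.Conv2d) and child.groups == 1
+                and child.padding_mode == "zeros"):
+            setattr(model, name, TritonConv2d(child))
+        else:
+            swap_convs(child)
+```
+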
+Pixel-level agreement on FLUX.2-klein-9B (50 diffusion steps, same prompt
+and seed under both backends, only VAE convs swapped to Triton): max diff
+**6 / 255**, mean diff **0.17 / 255**. See `examples/flux2_inference.py`
+to reproduce and inspect the generated images.
+
+---
+
+## Constraints
+
+- `groups` must equal 1 (depthwise / grouped not yet implemented).
+- `padding_mode` must be `"zeros"`. The pad *amount* (`padding=`, e.g.
+ `(1, 1)` or asymmetric `(0, 2)`) is unrestricted; only the pad *value*
+ is — `"reflect"`, `"replicate"`, and `"circular"` fall back to PyTorch /
+ MIOpen.
+- Inputs must be `fp16` or `bf16`.
+- Forward only (no backward / training).
+
+---
+
+## Reproducing the tests and benchmarks
+
+Run from the AITER repo root (`/app/aiter` in this tree, or `PYTHONPATH=/app/aiter`).
+
+### Correctness (CDNA CI runs this)
+
+```bash
+pytest op_tests/triton_tests/conv/ # full matrix, 74 tests
+pytest op_tests/triton_tests/conv/ -k "no_bias and fp16_nchw" # subset
+pytest op_tests/triton_tests/conv/ -k "test_edge" # one test family
+```
+
+Tests are parametrized over `(dtype, layout, method)`. Every kernel in
+`_helpers.ORDERED_METHODS` is exercised against fp16 and bf16 on NCHW.
+NHWC is single-dispatch (only `conv2d_nhwc`), so each NHWC test runs once
+per dtype.
+
+### Benchmark
+
+Three modes, all in `bench_conv2d.py`.
+
+**Single shape** (one parseable result line — for ad-hoc measurements):
+
+```bash
+python -m op_tests.op_benchmarks.triton.bench_conv2d \
+ --N 1 --C 64 --H 56 --W 56 --K 64 --R 3 --S 3 --pad-h 1 --pad-w 1
+```
+
+**Built-in 12-shape sweep** (no model required, three box-drawn tables):
+
+```bash
+python -m op_tests.op_benchmarks.triton.bench_conv2d --dtype fp16
+```
+
+**Real-model sweep** (reads shapes from `model_shapes.json`):
+
+```bash
+python -m op_tests.op_benchmarks.triton.bench_conv2d --model resnet50
+python -m op_tests.op_benchmarks.triton.bench_conv2d --model "FLUX.2" --miopen-solvers
+```
+
+Cross-axis flags:
+
+```
+--dtype {fp16,bf16} # default fp16
+--layout {nchw,nhwc} # default nchw
+--method {auto,default,cblocked,nhwc,winograd_f4x3,winograd_f4x3_fused,winograd_f4x3_cblocked}
+--metric {time,throughput} # default throughput
+--no-bias # bench the bias=None code path
+--miopen-solvers # detect MIOpen solver names (sweep mode; ~60-120s subprocess)
+--show-kernel-name # include routed kernel name in single-shape output
+```
+
+Real-model shapes are pre-extracted (no torchvision/diffusers needed at
+bench time). To add a new model, run `extract_conv_shapes.py` once and
+merge its JSON output into `model_shapes.json`:
+
+```bash
+python -m op_tests.op_benchmarks.triton.model_benchmarking_tool.extract_conv_shapes \
+ --model resnet50 # or sd35_vae / flux2_vae with --model-path
+```
+
+Tested on ROCm 7.2 / PyTorch `2.9.1+gitff65f5b` / Triton 3.7 (commit `23f4e522d`).
+
+---
+
+## Documentation
+
+- **[`DESIGN.md`](DESIGN.md)** — architecture, per-kernel deep-dive, full
+ Winograd F(4,3) derivation (G/Bᵀ/Aᵀ matrices, 361× amplification
+ analysis, why Winograd is disabled for `C < 4`), the
+ `_select_3x3_method` heuristic, memory layouts and repacking,
+ numerical model, extension guide.
+
+---
+
+## Repository layout
+
+```
+aiter/ops/triton/conv/ Kernel library
+ conv2d.py Public API + smart routing
+ _launch.py Grid setup + _select_3x3_method
+ _prepack.py Weight/input repack caches (LRU)
+ _utils.py Shape math, tolerance model
+ README.md, DESIGN.md
+
+aiter/ops/triton/_triton_kernels/conv/ @triton.jit kernels
+ (1x1, 3x3 cblocked, 3x3 NHWC, general, 5 Winograd kernels)
+
+op_tests/triton_tests/conv/ Pytest unit tests (CDNA CI runs this)
+ test_conv2d.py The only collected test file
+ _helpers.py TestSuite, registry, shape generators
+
+op_tests/op_benchmarks/triton/
+ bench_conv2d.py Self-contained bench tool (single + sweep)
+ model_benchmarking_tool/
+ extract_conv_shapes.py One-time offline shape extraction
+ model_shapes.json Pre-extracted conv shapes (resnet50, SD3.5, FLUX2)
+```
diff --git a/aiter/ops/triton/conv/__init__.py b/aiter/ops/triton/conv/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/aiter/ops/triton/conv/_launch.py b/aiter/ops/triton/conv/_launch.py
new file mode 100644
index 0000000000..e71b1d5171
--- /dev/null
+++ b/aiter/ops/triton/conv/_launch.py
@@ -0,0 +1,579 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
+
+from __future__ import annotations
+import torch
+
+try:
+ import triton
+ import triton.language as tl
+except Exception:
+ triton = None
+ tl = None
+
+from aiter.ops.triton.conv._utils import _out_hw, _is_winograd_eligible
+from aiter.ops.triton._triton_kernels.conv.conv_1x1 import _conv2d_1x1_kernel
+from aiter.ops.triton._triton_kernels.conv.conv_general import _conv2d_general_kernel
+from aiter.ops.triton._triton_kernels.conv.conv_3x3 import (
+ _conv2d_3x3_nhwc_kernel,
+ _conv2d_3x3_cblocked_kernel,
+)
+from aiter.ops.triton._triton_kernels.conv.conv_3x3_winograd_f4x3 import (
+ _winograd_f4x3_input_transform_kernel,
+ _winograd_f4x3_cblocked_input_transform_kernel,
+ _winograd_f4x3_batched_gemm_kernel,
+ _winograd_f4x3_output_transform_kernel,
+ _winograd_f4x3_fused_gemm_output_kernel,
+)
+
+
+def _torch_dtype_to_tl(dtype):
+ """Map torch dtype to triton dtype for constexpr params."""
+ if dtype == torch.float16:
+ return tl.float16
+ elif dtype == torch.bfloat16:
+ return tl.bfloat16
+ raise ValueError(f"Unsupported dtype: {dtype}")
+
+
+def _select_3x3_method(N, C, H, W, K_out, stride, dilation):
+ """Pick the best 3x3 kernel method based on shape heuristics.
+
+ Decision tree (from benchmark sweep on RDNA4):
+ 1. Non-Winograd-eligible (stride>1, dilation>1, or C<4) -> cblocked
+ 2. Winograd only wins when BOTH C and K >= 512 with enough tiles (T >= 98).
+ At 256x256 channels, cblocked is tied or slightly better.
+ 3. Among Winograd variants: WF4cb (NCHWc input) beats WF4 (NCHW input)
+ when T >= 392 (large batch * spatial gives more coalescing benefit).
+ Below that, WF4 is slightly faster (less repacking overhead).
+ """
+ if not _is_winograd_eligible(3, 3, stride, dilation, C):
+ return "cblocked"
+ P, Q = _out_hw(H, W, 3, 3, stride, (1, 1), dilation)
+ tile_H = (P + 3) // 4
+ tile_W = (Q + 3) // 4
+ T = N * tile_H * tile_W
+ if C >= 512 and K_out >= 512 and T >= 98:
+ if T >= 392:
+ return "winograd_f4x3_cblocked"
+ return "winograd_f4x3"
+ return "cblocked"
+
+
+def _layout_to_int(layout):
+ """Convert layout string to kernel int: 0=NCHW, 1=NHWC."""
+ layout = layout.lower()
+ if layout not in ("nchw", "nhwc"):
+ raise ValueError(f"layout must be 'nchw' or 'nhwc', got '{layout}'")
+ return 0 if layout == "nchw" else 1
+
+
+def _launch_1x1(
+ x,
+ w_oihw,
+ bias_fp32,
+ y,
+ N,
+ C,
+ H,
+ W_in,
+ K_out,
+ P,
+ Q,
+ stride,
+ padding,
+ activation,
+ layout="nchw",
+):
+ """Launch specialized 1x1 kernel.
+ layout: "nchw" or "nhwc" (case-insensitive).
+ """
+ if triton is None:
+ raise RuntimeError("Triton not available")
+
+ sh, sw = stride
+ ph, pw = padding
+
+ w = w_oihw.squeeze(-1).squeeze(-1).contiguous() # [K_out, C]
+ layout = _layout_to_int(layout)
+
+ def grid(meta):
+ BM = meta["BLOCK_M"]
+ BN = meta["BLOCK_N"]
+ return (triton.cdiv(N * P * Q, BM) * triton.cdiv(K_out, BN),)
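+    # e.g. (hypothetical): N*P*Q = 100352, K_out = 128, BLOCK_M = 128,
+    # BLOCK_N = 64 -> grid = (784 * 2,) = (1568,) programs; the kernel
+    # swizzles this 1-D grid into (pid_m, pid_n) tiles for L2 reuse.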
+
+ ACT_MAP = {"none": 0, "relu": 1, "relu6": 2, "gelu": 3}
+ bias_arg = bias_fp32 if bias_fp32 is not None else w.new_empty(1)
+
+ M_total = N * P * Q
+
+ _conv2d_1x1_kernel[grid](
+ x,
+ w,
+ bias_arg,
+ y,
+ N,
+ C,
+ H,
+ W_in,
+ K_out,
+ P,
+ Q,
+ sh,
+ sw,
+ ph,
+ pw,
+ M_total,
+ HAS_BIAS=1 if bias_fp32 is not None else 0,
+ ACT_TYPE=ACT_MAP.get(activation, 0),
+ LAYOUT=layout,
+ )
+
+
+def _launch_3x3_nhwc(
+ x,
+ w_3x3,
+ bias_fp32,
+ y,
+ N,
+ C,
+ H,
+ W_in,
+ K_out,
+ P,
+ Q,
+ C_pad,
+ stride,
+ padding,
+ dilation,
+ activation,
+):
+ """Launch specialized 3x3 NHWC kernel (hardcoded stride_c=1, stride_k=1)."""
+ if triton is None:
+ raise RuntimeError("Triton not available")
+
+ sh, sw = stride
+ ph, pw = padding
+ dh, dw = dilation
+
+ def grid(meta):
+ BM = meta["BLOCK_M"]
+ BN = meta["BLOCK_N"]
+ return (triton.cdiv(N * P * Q, BM) * triton.cdiv(K_out, BN),)
+
+ ACT_MAP = {"none": 0, "relu": 1, "relu6": 2, "gelu": 3}
+ bias_arg = bias_fp32 if bias_fp32 is not None else w_3x3.new_empty(1)
+
+ M_total = N * P * Q
+
+ _conv2d_3x3_nhwc_kernel[grid](
+ x,
+ w_3x3,
+ bias_arg,
+ y,
+ N,
+ C,
+ H,
+ W_in,
+ K_out,
+ P,
+ Q,
+ C_pad,
+ sh,
+ sw,
+ ph,
+ pw,
+ dh,
+ dw,
+ M_total,
+ HAS_BIAS=1 if bias_fp32 is not None else 0,
+ ACT_TYPE=ACT_MAP.get(activation, 0),
+ )
+
+
+def _launch_3x3_cblocked(
+ x_blocked,
+ w_3x3,
+ bias_fp32,
+ y,
+ N,
+ C,
+ H,
+ W_in,
+ K_out,
+ P,
+ Q,
+ C_pad,
+ Cb,
+ stride,
+ padding,
+ dilation,
+ activation,
+):
+ """Launch specialized 3x3 kernel for channel-blocked input."""
+ if triton is None:
+ raise RuntimeError("Triton not available")
+
+ sh, sw = stride
+ ph, pw = padding
+ dh, dw = dilation
+
+ def grid(meta):
+ BM = meta["BLOCK_M"]
+ BN = meta["BLOCK_N"]
+ return (triton.cdiv(N * P * Q, BM) * triton.cdiv(K_out, BN),)
+
+ ACT_MAP = {"none": 0, "relu": 1, "relu6": 2, "gelu": 3}
+ bias_arg = bias_fp32 if bias_fp32 is not None else w_3x3.new_empty(1)
+
+ M_total = N * P * Q
+
+ _conv2d_3x3_cblocked_kernel[grid](
+ x_blocked,
+ w_3x3,
+ bias_arg,
+ y,
+ N,
+ C,
+ H,
+ W_in,
+ K_out,
+ P,
+ Q,
+ C_pad,
+ Cb,
+ sh,
+ sw,
+ ph,
+ pw,
+ dh,
+ dw,
+ M_total,
+ HAS_BIAS=1 if bias_fp32 is not None else 0,
+ ACT_TYPE=ACT_MAP.get(activation, 0),
+ )
+
+
+def _launch_general(
+ x,
+ w_k,
+ bias_fp32,
+ y,
+ N,
+ C,
+ H,
+ W_in,
+ K_out,
+ R,
+ S,
+ P,
+ Q,
+ K_pad,
+ stride,
+ padding,
+ dilation,
+ block_k,
+ activation,
+ layout="nchw",
+):
+ """Launch general conv kernel.
+ layout: "nchw" or "nhwc" (case-insensitive).
+ """
+ if triton is None:
+ raise RuntimeError("Triton not available")
+
+ sh, sw = stride
+ ph, pw = padding
+ dh, dw = dilation
+ layout = _layout_to_int(layout)
+
+ def grid(meta):
+ BM = meta["BLOCK_M"]
+ BN = meta["BLOCK_N"]
+ return (triton.cdiv(N * P * Q, BM) * triton.cdiv(K_out, BN),)
+
+ ACT_MAP = {"none": 0, "relu": 1, "relu6": 2, "gelu": 3}
+ bias_arg = bias_fp32 if bias_fp32 is not None else w_k.new_empty(1)
+
+ M_total = N * P * Q
+
+ _conv2d_general_kernel[grid](
+ x,
+ w_k,
+ bias_arg,
+ y,
+ N,
+ C,
+ H,
+ W_in,
+ K_out,
+ R,
+ S,
+ P,
+ Q,
+ K_pad,
+ sh,
+ sw,
+ ph,
+ pw,
+ dh,
+ dw,
+ M_total,
+ HAS_BIAS=1 if bias_fp32 is not None else 0,
+ ACT_TYPE=ACT_MAP.get(activation, 0),
+ LAYOUT=layout,
+ )
+
+
+def _launch_winograd_f4x3_fused(
+ x,
+ U,
+ bias_fp32,
+ y,
+ N,
+ C,
+ H,
+ W_in,
+ K_out,
+ P,
+ Q,
+ C_pad,
+ padding,
+ activation,
+ layout="nchw",
+):
+ """Launch Winograd F(4x4,3x3) with fused GEMM+output transform (2 kernels instead of 3)."""
+ if triton is None:
+ raise RuntimeError("Triton not available")
+ ph, pw = padding
+ tile_H = (P + 3) // 4
+ tile_W = (Q + 3) // 4
+ T = N * tile_H * tile_W
+ layout_int = _layout_to_int(layout)
+
+ input_dtype = x.dtype
+ V = torch.empty((36, T, C_pad), device=x.device, dtype=input_dtype)
+
+ # 1. Input transform
+ def input_grid_f4(meta):
+ return (T, triton.cdiv(C_pad, meta["BLOCK_C"]))
+
+ _winograd_f4x3_input_transform_kernel[input_grid_f4](
+ x,
+ V,
+ N,
+ C,
+ C_pad,
+ H,
+ W_in,
+ tile_H,
+ tile_W,
+ T,
+ ph,
+ pw,
+ INPUT_DTYPE=_torch_dtype_to_tl(input_dtype),
+ LAYOUT=layout_int,
+ )
+
+ # 2. Fused GEMM + output transform
+ ACT_MAP = {"none": 0, "relu": 1, "relu6": 2, "gelu": 3}
+ bias_arg = bias_fp32 if bias_fp32 is not None else x.new_empty(1)
+
+ def fused_grid_f4(meta):
+ return (triton.cdiv(T, meta["BLOCK_T"]), triton.cdiv(K_out, meta["BLOCK_K"]))
+
+ _winograd_f4x3_fused_gemm_output_kernel[fused_grid_f4](
+ V,
+ U,
+ bias_arg,
+ y,
+ N,
+ K_out,
+ P,
+ Q,
+ C_pad,
+ tile_H,
+ tile_W,
+ T,
+ HAS_BIAS=1 if bias_fp32 is not None else 0,
+ ACT_TYPE=ACT_MAP.get(activation, 0),
+ LAYOUT=layout_int,
+ )
+
+
+def _launch_winograd_f4x3(
+ x,
+ U,
+ bias_fp32,
+ y,
+ N,
+ C,
+ H,
+ W_in,
+ K_out,
+ P,
+ Q,
+ C_pad,
+ padding,
+ activation,
+ layout="nchw",
+):
+ """Launch Winograd F(4x4,3x3) pipeline: input transform -> batched GEMM -> output transform."""
+ if triton is None:
+ raise RuntimeError("Triton not available")
+ ph, pw = padding
+ tile_H = (P + 3) // 4
+ tile_W = (Q + 3) // 4
+ T = N * tile_H * tile_W
+ layout_int = _layout_to_int(layout)
+
+ input_dtype = x.dtype
+ V = torch.empty((36, T, C_pad), device=x.device, dtype=input_dtype)
+ M = torch.empty((36, T, K_out), device=x.device, dtype=torch.float32)
+
+ # 1. Input transform
+ def input_grid_f4(meta):
+ return (T, triton.cdiv(C_pad, meta["BLOCK_C"]))
+
+ _winograd_f4x3_input_transform_kernel[input_grid_f4](
+ x,
+ V,
+ N,
+ C,
+ C_pad,
+ H,
+ W_in,
+ tile_H,
+ tile_W,
+ T,
+ ph,
+ pw,
+ INPUT_DTYPE=_torch_dtype_to_tl(input_dtype),
+ LAYOUT=layout_int,
+ )
+
+ # 2. Batched GEMM
+ def gemm_grid(meta):
+ BM = meta["BLOCK_M"]
+ BN = meta["BLOCK_N"]
+ return (triton.cdiv(T, BM) * triton.cdiv(K_out, BN), 36)
+
+ _winograd_f4x3_batched_gemm_kernel[gemm_grid](
+ V,
+ U,
+ M,
+ T,
+ K_out,
+ C_pad,
+ )
+
+ # 3. Output transform
+ ACT_MAP = {"none": 0, "relu": 1, "relu6": 2, "gelu": 3}
+ bias_arg = bias_fp32 if bias_fp32 is not None else x.new_empty(1)
+
+ def output_grid_f4(meta):
+ return (T, triton.cdiv(K_out, meta["BLOCK_K"]))
+
+ _winograd_f4x3_output_transform_kernel[output_grid_f4](
+ M,
+ bias_arg,
+ y,
+ N,
+ K_out,
+ P,
+ Q,
+ tile_H,
+ tile_W,
+ T,
+ HAS_BIAS=1 if bias_fp32 is not None else 0,
+ ACT_TYPE=ACT_MAP.get(activation, 0),
+ LAYOUT=layout_int,
+ )
+
+
+def _launch_winograd_f4x3_cblocked(
+ x_blocked,
+ C_pad_blocked,
+ U,
+ bias_fp32,
+ y,
+ N,
+ C,
+ H,
+ W_in,
+ K_out,
+ P,
+ Q,
+ C_pad,
+ padding,
+ activation,
+ block_k,
+):
+ """Launch Winograd F(4x4,3x3) with NCHWc input layout: cblocked input transform -> batched GEMM -> output transform."""
+ if triton is None:
+ raise RuntimeError("Triton not available")
+ ph, pw = padding
+ tile_H = (P + 3) // 4
+ tile_W = (Q + 3) // 4
+ T = N * tile_H * tile_W
+
+ Cb = block_k
+ input_dtype = x_blocked.dtype
+ V = torch.empty((36, T, C_pad), device=x_blocked.device, dtype=input_dtype)
+ M = torch.empty((36, T, K_out), device=x_blocked.device, dtype=torch.float32)
+
+ # 1. Cblocked input transform
+ def input_grid_f4(meta):
+ return (T, triton.cdiv(C_pad, meta["BLOCK_C"]))
+
+ _winograd_f4x3_cblocked_input_transform_kernel[input_grid_f4](
+ x_blocked,
+ V,
+ N,
+ C,
+ C_pad,
+ H,
+ W_in,
+ tile_H,
+ tile_W,
+ T,
+ ph,
+ pw,
+ Cb,
+ INPUT_DTYPE=_torch_dtype_to_tl(input_dtype),
+ )
+
+ def gemm_grid(meta):
+ BM = meta["BLOCK_M"]
+ BN = meta["BLOCK_N"]
+ return (triton.cdiv(T, BM) * triton.cdiv(K_out, BN), 36)
+
+ _winograd_f4x3_batched_gemm_kernel[gemm_grid](
+ V,
+ U,
+ M,
+ T,
+ K_out,
+ C_pad,
+ )
+
+ ACT_MAP = {"none": 0, "relu": 1, "relu6": 2, "gelu": 3}
+ bias_arg = bias_fp32 if bias_fp32 is not None else x_blocked.new_empty(1)
+
+ def output_grid_f4(meta):
+ return (T, triton.cdiv(K_out, meta["BLOCK_K"]))
+
+ _winograd_f4x3_output_transform_kernel[output_grid_f4](
+ M,
+ bias_arg,
+ y,
+ N,
+ K_out,
+ P,
+ Q,
+ tile_H,
+ tile_W,
+ T,
+ HAS_BIAS=1 if bias_fp32 is not None else 0,
+ ACT_TYPE=ACT_MAP.get(activation, 0),
+ )
diff --git a/aiter/ops/triton/conv/_prepack.py b/aiter/ops/triton/conv/_prepack.py
new file mode 100644
index 0000000000..d586034a56
--- /dev/null
+++ b/aiter/ops/triton/conv/_prepack.py
@@ -0,0 +1,197 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
+
+from __future__ import annotations
+import os
+from collections import OrderedDict
+from typing import Dict
+import torch
+
+from aiter.ops.triton.conv._utils import BLOCK_K, _storage_ptr
+
+_PACK_CACHE_MAXSIZE = int(os.environ.get("AITER_TRITON_CONV_PACK_CACHE_SIZE", "256"))
+
+
+class _LRUPackCache:
+ """Bounded LRU for weight prepacks. Stores (src_tensor, item) — the
+ strong ref to src keeps storage alive so the storage_ptr in the key
+ cannot be reused by a different tensor while this entry lives."""
+
+ def __init__(self, maxsize: int = _PACK_CACHE_MAXSIZE):
+ self._d: "OrderedDict[tuple, tuple]" = OrderedDict()
+ self._max = max(1, maxsize)
+
+ def get(self, key):
+ entry = self._d.get(key)
+ if entry is None:
+ return None
+ self._d.move_to_end(key)
+ return entry
+
+ def put(self, key, src, item):
+ self._d[key] = (src, item)
+ self._d.move_to_end(key)
+ while len(self._d) > self._max:
+ self._d.popitem(last=False)
+
+ def clear(self):
+ self._d.clear()
+
+ def __len__(self):
+ return len(self._d)
+
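+# Behavior sketch (hypothetical keys and tensors; mirrors how the helpers
+# below use the cache):
+#   cache = _LRUPackCache(maxsize=2)
+#   cache.put(k_a, w_a, pack_a)   # the strong ref to w_a pins its storage
+#   cache.put(k_b, w_b, pack_b)
+#   cache.get(k_a)                # -> (w_a, pack_a); k_a becomes most-recent
+#   cache.put(k_c, w_c, pack_c)   # evicts k_b, now the least-recently used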
+
+_PACK_CACHE = _LRUPackCache()
+_PACK_CACHE_3x3 = _LRUPackCache()
+# Input pack — kept as single-entry dict by design: in real inference each
+# layer's input is a unique intermediate activation that won't be reused,
+# and the bench clears this per-call to model per-batch repack cost.
+_PACK_CACHE_CBLOCKED: Dict = {}
+_PACK_CACHE_WINOGRAD_F4X3 = _LRUPackCache()
+
+
+def prepack_oihw_to_kmajor(w_oihw: torch.Tensor, block_k: int = BLOCK_K):
+ K_out, C, R, S = w_oihw.shape
+ K_red = C * R * S
+ K_pad = ((K_red + block_k - 1) // block_k) * block_k
+ w_rs = w_oihw.reshape(K_out, K_red)
+ if K_pad != K_red:
+ pad = torch.zeros(
+ (K_out, K_pad - K_red), device=w_oihw.device, dtype=w_oihw.dtype
+ )
+ w_rs = torch.cat([w_rs, pad], dim=1)
+ return w_rs.contiguous(), (K_out, K_pad)
+
+
+def get_or_make_weight_pack(w_oihw: torch.Tensor, block_k: int = BLOCK_K):
+ key = (
+ _storage_ptr(w_oihw),
+ tuple(w_oihw.shape),
+ w_oihw.dtype,
+ block_k,
+ int(getattr(w_oihw, "_version", 0)),
+ )
+ entry = _PACK_CACHE.get(key)
+ if entry is not None:
+ return entry[1]
+ item = prepack_oihw_to_kmajor(w_oihw, block_k)
+ _PACK_CACHE.put(key, w_oihw, item)
+ return item
+
+
+def prepack_oihw_to_3x3(w_oihw: torch.Tensor, block_c: int = BLOCK_K):
+ """Pack weights as [K_out, 9, C_pad] for 3x3 specialized kernel."""
+ K_out, C, R, S = w_oihw.shape
+ assert R == 3 and S == 3
+ C_pad = ((C + block_c - 1) // block_c) * block_c
+ w_rs = w_oihw.reshape(K_out, C, 9).permute(0, 2, 1).contiguous() # [K_out, 9, C]
+ if C_pad != C:
+ pad = torch.zeros(
+ (K_out, 9, C_pad - C), device=w_oihw.device, dtype=w_oihw.dtype
+ )
+ w_rs = torch.cat([w_rs, pad], dim=2)
+ return w_rs.contiguous(), (K_out, C_pad)
+
+
+def get_or_make_weight_pack_3x3(w_oihw: torch.Tensor, block_c: int = BLOCK_K):
+ key = (
+ _storage_ptr(w_oihw),
+ tuple(w_oihw.shape),
+ w_oihw.dtype,
+ block_c,
+ int(getattr(w_oihw, "_version", 0)),
+ )
+ cached = _PACK_CACHE_3x3.get(key)
+ if cached is not None:
+ return cached[1]
+ item = prepack_oihw_to_3x3(w_oihw, block_c)
+ _PACK_CACHE_3x3.put(key, w_oihw, item)
+ return item
+
+
+def prepack_nchw_to_cblocked(x: torch.Tensor, block_c: int = BLOCK_K):
+ """Pack NCHW input into channel-blocked layout [N, C_blocks, H, W, Cb].
+
+ Within each block of Cb channels, data is contiguous (stride=1).
+ """
+ N, C, H, W = x.shape
+ Cb = block_c
+ C_blocks = (C + Cb - 1) // Cb
+ C_pad = C_blocks * Cb
+
+ if C_pad != C:
+ x_padded = torch.zeros((N, C_pad, H, W), device=x.device, dtype=x.dtype)
+ x_padded[:, :C, :, :] = x
+ else:
+ x_padded = x
+
+ x_blocked = (
+ x_padded.reshape(N, C_blocks, Cb, H, W).permute(0, 1, 3, 4, 2).contiguous()
+ )
+ return x_blocked, C_pad
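+
+# Layout self-check (hypothetical sizes; xb[n, cb, h, w, ci] == x[n, cb*Cb + ci, h, w]):
+#   x = torch.randn(2, 70, 5, 5, device="cuda", dtype=torch.float16)
+#   xb, C_pad = prepack_nchw_to_cblocked(x, block_c=64)   # xb: [2, 2, 5, 5, 64], C_pad=128
+#   assert torch.equal(xb[0, 1, 3, 4, 5], x[0, 69, 3, 4]) # channels 70..127 are zero pad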
+
+
+def get_or_make_input_pack_cblocked(x: torch.Tensor, block_c: int = BLOCK_K):
+ key = (
+ _storage_ptr(x),
+ tuple(x.shape),
+ x.dtype,
+ block_c,
+ int(getattr(x, "_version", 0)),
+ )
+ cached = _PACK_CACHE_CBLOCKED.get(key)
+ if cached is not None:
+ src_ref, item = cached
+ if src_ref is not None and _storage_ptr(src_ref) == key[0]:
+ return item
+ item = prepack_nchw_to_cblocked(x, block_c)
+ _PACK_CACHE_CBLOCKED.clear()
+ _PACK_CACHE_CBLOCKED[key] = (x, item)
+ return item
+
+
+def prepack_winograd_filter_f4x3(w_oihw: torch.Tensor, block_c: int = BLOCK_K):
+ """Transform 3x3 filters for Winograd F(4x4,3x3). G @ g @ G^T for each (k,c).
+ Input: [K_out, C, 3, 3] fp16. Output: [36, K_out, C_pad] fp16."""
+ K_out, C, R, S = w_oihw.shape
+ assert R == 3 and S == 3
+ C_pad = ((C + block_c - 1) // block_c) * block_c
+ # G matrix (6x3)
+ G = torch.tensor(
+ [
+ [1.0 / 4, 0.0, 0.0],
+ [-1.0 / 6, -1.0 / 6, -1.0 / 6],
+ [-1.0 / 6, 1.0 / 6, -1.0 / 6],
+ [1.0 / 24, 1.0 / 12, 1.0 / 6],
+ [1.0 / 24, -1.0 / 12, 1.0 / 6],
+ [0.0, 0.0, 1.0],
+ ],
+ dtype=torch.float32,
+ device=w_oihw.device,
+ )
+
+ g = w_oihw.float() # [K_out, C, 3, 3]
+ u = torch.einsum("ij,kcjl,lm->kcim", G, g, G.t())
+ u = u.reshape(K_out, C, 36).permute(2, 0, 1).contiguous()
+ if C_pad != C:
+ pad = torch.zeros(
+ (36, K_out, C_pad - C), device=w_oihw.device, dtype=torch.float32
+ )
+ u = torch.cat([u, pad], dim=2)
+ return u.to(w_oihw.dtype).contiguous(), (K_out, C_pad)
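+
+# Shape sketch (hypothetical sizes): for w_oihw of shape [128, 70, 3, 3],
+# prepack_winograd_filter_f4x3 returns U of shape [36, 128, 128] (C padded
+# 70 -> 128). Each of the 36 Winograd frequencies is its own [K_out, C_pad]
+# GEMM operand, matching V's [36, T, C_pad] layout built in _launch.py.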
+
+
+def get_or_make_winograd_filter_f4x3(w_oihw: torch.Tensor, block_c: int = BLOCK_K):
+ key = (
+ _storage_ptr(w_oihw),
+ tuple(w_oihw.shape),
+ w_oihw.dtype,
+ block_c,
+ int(getattr(w_oihw, "_version", 0)),
+ )
+ cached = _PACK_CACHE_WINOGRAD_F4X3.get(key)
+ if cached is not None:
+ return cached[1]
+ item = prepack_winograd_filter_f4x3(w_oihw, block_c)
+ _PACK_CACHE_WINOGRAD_F4X3.put(key, w_oihw, item)
+ return item
diff --git a/aiter/ops/triton/conv/_utils.py b/aiter/ops/triton/conv/_utils.py
new file mode 100644
index 0000000000..9d0850e8cf
--- /dev/null
+++ b/aiter/ops/triton/conv/_utils.py
@@ -0,0 +1,88 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
+
+from __future__ import annotations
+import torch
+import torch.nn.functional as F
+
+# Channel padding granularity for prepacked weights/inputs. Must align with the
+# BLOCK_K autotune candidates in _triton_kernels/conv/helpers.py — change with care.
+BLOCK_K = 64
+
+
+def dynamic_conv_tolerances(dtype: torch.dtype, K_red: int, ref: torch.Tensor):
+ eps = {
+ torch.float16: 2**-10,
+ torch.bfloat16: 2**-7,
+ torch.float32: 2**-23,
+ }.get(dtype, 2**-10)
+ rtol = 6e-3 if K_red < 1024 else (8e-3 if K_red < 4096 else 1.2e-2)
+ # Error model: fp16 inputs multiplied pairwise have eps relative error per product.
+ # Accumulated in fp32 over K_red terms, max absolute error grows as ~eps * sqrt(K_red).
+ # The 10x multiplier covers worst-case accumulation ordering differences
+ # between our Triton kernels and PyTorch reference.
+ atol = max(eps * 8, 10.0 * eps * (K_red**0.5))
+ return rtol, atol
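+
+# Worked example (fp16, a hypothetical 64-channel 3x3 conv, K_red = 64*9 = 576):
+#   eps  = 2**-10
+#   rtol = 6e-3                       (K_red < 1024)
+#   atol = max(8 * 2**-10, 10 * 2**-10 * 576**0.5) = max(0.0078, 0.234) ~= 0.23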
+
+
+def flops_conv(N, C, K_out, R, S, P, Q):
+ return 2.0 * N * P * Q * K_out * C * R * S
+
+
+def _out_hw(H, W, R, S, stride, padding, dilation):
+ sh, sw = stride
+ ph, pw = padding
+ dh, dw = dilation
+ P = (H + 2 * ph - dh * (R - 1) - 1) // sh + 1
+ Q = (W + 2 * pw - dw * (S - 1) - 1) // sw + 1
+ return P, Q
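+
+# Example: H = W = 56, R = S = 3, stride 2, pad 1, dilation 1 gives
+# P = Q = (56 + 2 - 2 - 1) // 2 + 1 = 28 (standard ResNet downsampling).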
+
+
+def _storage_ptr(t: torch.Tensor) -> int:
+ return (
+ t.untyped_storage().data_ptr()
+ if hasattr(t, "untyped_storage")
+ else t.storage().data_ptr()
+ )
+
+
+def _is_1x1_conv(R, S, dilation):
+ """Check if this is a 1x1 convolution (no spatial reduction in kernel)."""
+ return R == 1 and S == 1 and dilation == (1, 1)
+
+
+def _is_3x3_conv(R, S):
+ """Check if this is a 3x3 convolution."""
+ return R == 3 and S == 3
+
+
+def _is_winograd_eligible(R, S, stride, dilation, C=None):
+ if not (R == 3 and S == 3 and stride == (1, 1) and dilation == (1, 1)):
+ return False
+    # F(4,3) output transform amplifies bf16 rounding by up to 361x: A^T's
+    # row 3 has L1 norm 19, and the transform is applied along both rows and
+    # columns, giving 19**2 = 361. With very few input channels the tolerance
+    # budget is too small to absorb this.
+ if C is not None and C < 4:
+ return False
+ return True
+
+
+def _winograd_tolerances(dtype, K_red, ref, variant="f4x3"):
+ """Return (rtol, atol) for Winograd F(4x4,3x3) correctness checks.
+ Winograd transforms amplify fp16 rounding errors:
+ - F(4x4,3x3): coefficients up to ±8, significant amplification
+ """
+ rtol, atol = dynamic_conv_tolerances(dtype, K_red, ref)
+ if variant == "f4x3":
+ rtol *= 6.0
+ atol = max(atol * 6.0, 0.6)
+ return rtol, atol
+
+
+def apply_activation(y: torch.Tensor, activation: str):
+ if activation == "relu":
+ return F.relu(y)
+ if activation == "relu6":
+ return torch.clamp(y, 0, 6)
+ if activation == "gelu":
+ return F.gelu(y, approximate="tanh")
+ return y
diff --git a/aiter/ops/triton/conv/conv2d.py b/aiter/ops/triton/conv/conv2d.py
new file mode 100644
index 0000000000..afda01cbf9
--- /dev/null
+++ b/aiter/ops/triton/conv/conv2d.py
@@ -0,0 +1,640 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
+
+from __future__ import annotations
+from typing import Optional
+import torch
+
+from aiter.ops.triton.utils.logger import AiterTritonLogger
+from aiter.ops.triton.conv._utils import (
+ BLOCK_K,
+ _out_hw,
+ _is_1x1_conv,
+ _is_3x3_conv,
+ _is_winograd_eligible,
+)
+from aiter.ops.triton.conv._prepack import (
+ get_or_make_weight_pack,
+ get_or_make_weight_pack_3x3,
+ get_or_make_input_pack_cblocked,
+ get_or_make_winograd_filter_f4x3,
+)
+from aiter.ops.triton.conv._launch import (
+ _launch_1x1,
+ _launch_3x3_nhwc,
+ _launch_3x3_cblocked,
+ _launch_general,
+ _launch_winograd_f4x3,
+ _launch_winograd_f4x3_cblocked,
+ _launch_winograd_f4x3_fused,
+ _select_3x3_method,
+)
+
+_LOGGER = AiterTritonLogger()
+
+# Tracks the last Triton kernel selected by the conv2d_nchw / conv2d_nhwc
+# smart routing; read by bench_conv2d.py for reporting.
+_last_triton_kernel: Optional[str] = None
+
+
+def conv2d(
+ x,
+ w_oihw,
+ bias=None,
+ stride=(1, 1),
+ padding=(0, 0),
+ dilation=(1, 1),
+ activation="none",
+ out_dtype: Optional[torch.dtype] = None,
+ layout="nchw",
+):
+ """Forward 2-D conv on AMD ROCm via Triton. Drop-in for the forward of
+ ``torch.nn.functional.conv2d`` (no backward).
+
+ A shape-driven router picks among five kernel families (1x1, 3x3 cblocked,
+ 3x3 NHWC, Winograd F(4x4,3x3), general) per call.
+
+ Inputs must be fp16 or bf16. ``layout="nhwc"`` runs an NHWC-native kernel
+ with no internal layout conversion.
+
+ ``out_dtype=None`` (default) returns output in the input dtype, matching
+ ``torch.nn.Conv2d`` semantics.
+
+ Notes
+ -----
+ - Only ``groups=1`` (depthwise/grouped raises ``AssertionError``).
+ - Only ``padding_mode="zeros"`` (no reflect/replicate/circular).
+    - ``bias=None`` skips the with-bias kernel path entirely; passing a
+      zero-filled bias tensor instead still routes through the with-bias
+      kernel and therefore benchmarks differently.
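+
+    Example (illustrative shapes; any fp16/bf16 CUDA tensors of matching
+    shape work)::
+
+        x = torch.randn(8, 64, 56, 56, device="cuda", dtype=torch.float16)
+        w = torch.randn(128, 64, 3, 3, device="cuda", dtype=torch.float16)
+        y = conv2d(x, w, stride=(1, 1), padding=(1, 1))  # router picks a 3x3 path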
+ """
+ if x.dtype not in (torch.float16, torch.bfloat16):
+ raise ValueError(f"conv2d only supports fp16 and bf16 inputs, got {x.dtype}")
+ if out_dtype is None:
+ out_dtype = x.dtype
+ elif out_dtype not in (torch.float16, torch.bfloat16):
+ raise ValueError(
+ f"out_dtype must be torch.float16 or torch.bfloat16, got {out_dtype}"
+ )
+ layout = layout.lower()
+ if layout not in ("nchw", "nhwc"):
+ raise ValueError(f"layout must be 'nchw' or 'nhwc', got '{layout}'")
+
+ _LOGGER.info(
+ f"CONV2D: x={tuple(x.shape)} w={tuple(w_oihw.shape)} stride={stride} "
+ f"padding={padding} dilation={dilation} layout={layout} "
+ f"dtype={x.dtype} out_dtype={out_dtype} bias={'yes' if bias is not None else 'no'} "
+ f"act={activation}"
+ )
+
+ if layout == "nhwc":
+ return conv2d_nhwc(
+ x, w_oihw, bias, stride, padding, dilation, activation, out_dtype
+ )
+ else:
+ return conv2d_nchw(
+ x, w_oihw, bias, stride, padding, dilation, activation, out_dtype
+ )
+
+
+def conv2d_winograd_f4x3(
+ x,
+ w_oihw,
+ bias=None,
+ stride=(1, 1),
+ padding=(0, 0),
+ dilation=(1, 1),
+ activation="none",
+ out_dtype: Optional[torch.dtype] = None,
+ block_k=BLOCK_K,
+ layout="nchw",
+):
+ """NCHW/NHWC conv2d using Winograd F(4x4,3x3). Raises ValueError for non-eligible convs."""
+ assert x.is_cuda and w_oihw.is_cuda
+ if out_dtype is None:
+ out_dtype = x.dtype
+ N, C, H, W_in = x.shape
+ K_out, Cw, R, S = w_oihw.shape
+ assert Cw == C
+ P, Q = _out_hw(H, W_in, R, S, stride, padding, dilation)
+
+ if not _is_winograd_eligible(R, S, stride, dilation, C):
+ raise ValueError(
+ f"conv2d_winograd_f4x3 requires 3x3 kernel with stride=1, dilation=1, "
+ f"and C >= 4 (F(4,3) output transform amplifies rounding by up to "
+ f"361x; C<4 has too few reduction terms to absorb it), "
+ f"got {R}x{S} stride={stride} dilation={dilation} C={C}"
+ )
+
+ if layout == "nhwc":
+ y = torch.empty((N, K_out, P, Q), device=x.device, dtype=out_dtype).to(
+ memory_format=torch.channels_last
+ )
+ else:
+ y = torch.empty((N, K_out, P, Q), device=x.device, dtype=out_dtype)
+ bias_fp32 = bias.float().contiguous() if bias is not None else None
+ U, (_, C_pad) = get_or_make_winograd_filter_f4x3(w_oihw.contiguous(), block_k)
+ _launch_winograd_f4x3(
+ x,
+ U,
+ bias_fp32,
+ y,
+ N,
+ C,
+ H,
+ W_in,
+ K_out,
+ P,
+ Q,
+ C_pad,
+ padding,
+ activation,
+ layout=layout,
+ )
+ return y
+
+
+def conv2d_winograd_f4x3_cblocked(
+ x,
+ w_oihw,
+ bias=None,
+ stride=(1, 1),
+ padding=(0, 0),
+ dilation=(1, 1),
+ activation="none",
+ out_dtype: Optional[torch.dtype] = None,
+ block_k=BLOCK_K,
+):
+ """NCHW conv2d using Winograd F(4x4,3x3) with NCHWc input layout for coalesced loads.
+ Raises ValueError for non-eligible convs."""
+ assert x.is_cuda and w_oihw.is_cuda
+ if out_dtype is None:
+ out_dtype = x.dtype
+ N, C, H, W_in = x.shape
+ K_out, Cw, R, S = w_oihw.shape
+ assert Cw == C
+ P, Q = _out_hw(H, W_in, R, S, stride, padding, dilation)
+
+ if not _is_winograd_eligible(R, S, stride, dilation, C):
+ raise ValueError(
+ f"conv2d_winograd_f4x3_cblocked requires 3x3 kernel with stride=1, dilation=1, "
+ f"and C >= 4 (F(4,3) output transform amplifies rounding by up to "
+ f"361x; C<4 has too few reduction terms to absorb it), "
+ f"got {R}x{S} stride={stride} dilation={dilation} C={C}"
+ )
+
+ y = torch.empty((N, K_out, P, Q), device=x.device, dtype=out_dtype)
+ bias_fp32 = bias.float().contiguous() if bias is not None else None
+ U, (_, C_pad) = get_or_make_winograd_filter_f4x3(w_oihw.contiguous(), block_k)
+ x_blocked, C_pad_blocked = get_or_make_input_pack_cblocked(x, block_k)
+ _launch_winograd_f4x3_cblocked(
+ x_blocked,
+ C_pad_blocked,
+ U,
+ bias_fp32,
+ y,
+ N,
+ C,
+ H,
+ W_in,
+ K_out,
+ P,
+ Q,
+ C_pad,
+ padding,
+ activation,
+ block_k,
+ )
+ return y
+
+
+def conv2d_winograd_f4x3_fused(
+ x,
+ w_oihw,
+ bias=None,
+ stride=(1, 1),
+ padding=(0, 0),
+ dilation=(1, 1),
+ activation="none",
+ out_dtype: Optional[torch.dtype] = None,
+ block_k=BLOCK_K,
+):
+ """NCHW conv2d using Winograd F(4x4,3x3) with fused GEMM+output transform.
+ Raises ValueError for non-eligible convs."""
+ assert x.is_cuda and w_oihw.is_cuda
+ if out_dtype is None:
+ out_dtype = x.dtype
+ N, C, H, W_in = x.shape
+ K_out, Cw, R, S = w_oihw.shape
+ assert Cw == C
+ P, Q = _out_hw(H, W_in, R, S, stride, padding, dilation)
+
+ if not _is_winograd_eligible(R, S, stride, dilation, C):
+ raise ValueError(
+ f"conv2d_winograd_f4x3_fused requires 3x3 kernel with stride=1, dilation=1, "
+ f"and C >= 4 (F(4,3) output transform amplifies rounding by up to "
+ f"361x; C<4 has too few reduction terms to absorb it), "
+ f"got {R}x{S} stride={stride} dilation={dilation} C={C}"
+ )
+
+ y = torch.empty((N, K_out, P, Q), device=x.device, dtype=out_dtype)
+ bias_fp32 = bias.float().contiguous() if bias is not None else None
+ U, (_, C_pad) = get_or_make_winograd_filter_f4x3(w_oihw.contiguous(), block_k)
+ _launch_winograd_f4x3_fused(
+ x,
+ U,
+ bias_fp32,
+ y,
+ N,
+ C,
+ H,
+ W_in,
+ K_out,
+ P,
+ Q,
+ C_pad,
+ padding,
+ activation,
+ )
+ return y
+
+
+def conv2d_1x1(
+ x,
+ w_oihw,
+ bias=None,
+ stride=(1, 1),
+ padding=(0, 0),
+ dilation=(1, 1),
+ activation="none",
+ out_dtype: Optional[torch.dtype] = None,
+ block_k=BLOCK_K,
+ layout="nchw",
+):
+ """NCHW/NHWC conv2d for 1x1 kernels. Raises ValueError for non-1x1."""
+ assert x.is_cuda and w_oihw.is_cuda
+ if out_dtype is None:
+ out_dtype = x.dtype
+ N, C, H, W_in = x.shape
+ K_out, Cw, R, S = w_oihw.shape
+ assert Cw == C
+ if not _is_1x1_conv(R, S, dilation):
+ raise ValueError(f"conv2d_1x1 requires 1x1 kernel, got {R}x{S}")
+ P, Q = _out_hw(H, W_in, R, S, stride, padding, dilation)
+
+ if layout == "nhwc":
+ y = torch.empty((N, K_out, P, Q), device=x.device, dtype=out_dtype).to(
+ memory_format=torch.channels_last
+ )
+ else:
+ y = torch.empty((N, K_out, P, Q), device=x.device, dtype=out_dtype)
+ bias_fp32 = bias.float().contiguous() if bias is not None else None
+ _launch_1x1(
+ x,
+ w_oihw.contiguous(),
+ bias_fp32,
+ y,
+ N,
+ C,
+ H,
+ W_in,
+ K_out,
+ P,
+ Q,
+ stride,
+ padding,
+ activation,
+ layout=layout,
+ )
+ return y
+
+
+def conv2d_general(
+ x,
+ w_oihw,
+ bias=None,
+ stride=(1, 1),
+ padding=(0, 0),
+ dilation=(1, 1),
+ activation="none",
+ out_dtype: Optional[torch.dtype] = None,
+ block_k=BLOCK_K,
+ layout="nchw",
+):
+ """NCHW/NHWC conv2d using general kernel with prepacked weights (5x5, 7x7, etc.)."""
+ assert x.is_cuda and w_oihw.is_cuda
+ if out_dtype is None:
+ out_dtype = x.dtype
+ N, C, H, W_in = x.shape
+ K_out, Cw, R, S = w_oihw.shape
+ assert Cw == C
+ P, Q = _out_hw(H, W_in, R, S, stride, padding, dilation)
+
+ if layout == "nhwc":
+ y = torch.empty((N, K_out, P, Q), device=x.device, dtype=out_dtype).to(
+ memory_format=torch.channels_last
+ )
+ else:
+ y = torch.empty((N, K_out, P, Q), device=x.device, dtype=out_dtype)
+ bias_fp32 = bias.float().contiguous() if bias is not None else None
+ w_k, (_, K_pad) = get_or_make_weight_pack(w_oihw.contiguous(), block_k)
+ _launch_general(
+ x,
+ w_k,
+ bias_fp32,
+ y,
+ N,
+ C,
+ H,
+ W_in,
+ K_out,
+ R,
+ S,
+ P,
+ Q,
+ K_pad,
+ stride,
+ padding,
+ dilation,
+ block_k,
+ activation,
+ layout=layout,
+ )
+ return y
+
+
+def conv2d_nhwc_3x3(
+ x,
+ w_oihw,
+ bias=None,
+ stride=(1, 1),
+ padding=(0, 0),
+ dilation=(1, 1),
+ activation="none",
+ out_dtype: Optional[torch.dtype] = None,
+ block_k=BLOCK_K,
+):
+ """NHWC conv2d for 3x3 kernels. Raises ValueError for non-3x3."""
+ assert x.is_cuda and w_oihw.is_cuda
+ if out_dtype is None:
+ out_dtype = x.dtype
+ N, C, H, W_in = x.shape
+ K_out, Cw, R, S = w_oihw.shape
+ assert Cw == C
+ if not _is_3x3_conv(R, S):
+ raise ValueError(f"conv2d_nhwc_3x3 requires 3x3 kernel, got {R}x{S}")
+ P, Q = _out_hw(H, W_in, R, S, stride, padding, dilation)
+
+ y = torch.empty((N, K_out, P, Q), device=x.device, dtype=out_dtype).to(
+ memory_format=torch.channels_last
+ )
+ bias_fp32 = bias.float().contiguous() if bias is not None else None
+ w_3x3, (_, C_pad) = get_or_make_weight_pack_3x3(w_oihw.contiguous(), block_k)
+ _launch_3x3_nhwc(
+ x,
+ w_3x3,
+ bias_fp32,
+ y,
+ N,
+ C,
+ H,
+ W_in,
+ K_out,
+ P,
+ Q,
+ C_pad,
+ stride,
+ padding,
+ dilation,
+ activation,
+ )
+ return y
+
+
+def conv2d_nchw(
+ x,
+ w_oihw,
+ bias=None,
+ stride=(1, 1),
+ padding=(0, 0),
+ dilation=(1, 1),
+ activation="none",
+ out_dtype: Optional[torch.dtype] = None,
+ block_k=BLOCK_K,
+):
+ """Hybrid NCHW conv2d: routes to specialized 1x1, 3x3, or general kernel."""
+ assert x.is_cuda and w_oihw.is_cuda
+ if out_dtype is None:
+ out_dtype = x.dtype
+ N, C, H, W_in = x.shape
+ K_out, Cw, R, S = w_oihw.shape
+ assert Cw == C
+
+ global _last_triton_kernel
+ if _is_1x1_conv(R, S, dilation):
+ _last_triton_kernel = "_conv2d_1x1_kernel"
+ return conv2d_1x1(
+ x,
+ w_oihw,
+ bias,
+ stride,
+ padding,
+ dilation,
+ activation,
+ out_dtype,
+ block_k,
+ layout="nchw",
+ )
+ elif _is_3x3_conv(R, S):
+ method = _select_3x3_method(N, C, H, W_in, K_out, stride, dilation)
+ if method == "winograd_f4x3_cblocked":
+ _last_triton_kernel = "_winograd_f4x3_cblocked_* (3 kern)"
+ return conv2d_winograd_f4x3_cblocked(
+ x,
+ w_oihw,
+ bias,
+ stride,
+ padding,
+ dilation,
+ activation,
+ out_dtype,
+ block_k,
+ )
+ elif method == "winograd_f4x3":
+ _last_triton_kernel = "_winograd_f4x3_* (3 kernels)"
+ return conv2d_winograd_f4x3(
+ x,
+ w_oihw,
+ bias,
+ stride,
+ padding,
+ dilation,
+ activation,
+ out_dtype,
+ block_k,
+ )
+ else:
+ _last_triton_kernel = "_conv2d_3x3_cblocked_kernel"
+ return conv2d_nchw_cblocked(
+ x,
+ w_oihw,
+ bias,
+ stride,
+ padding,
+ dilation,
+ activation,
+ out_dtype,
+ block_k,
+ )
+ else:
+ _last_triton_kernel = "_conv2d_general_kernel"
+ return conv2d_general(
+ x,
+ w_oihw,
+ bias,
+ stride,
+ padding,
+ dilation,
+ activation,
+ out_dtype,
+ block_k,
+ layout="nchw",
+ )
+
+
+def conv2d_nhwc(
+ x,
+ w_oihw,
+ bias=None,
+ stride=(1, 1),
+ padding=(0, 0),
+ dilation=(1, 1),
+ activation="none",
+ out_dtype: Optional[torch.dtype] = None,
+ block_k=BLOCK_K,
+):
+ """Conv2d with NHWC (channels-last) input and output.
+
+    Input x is logically NCHW in any memory format; it is converted to
+    channels_last internally. Output y is allocated channels_last
+    (NHWC-contiguous) and returned in logical NCHW shape with channels_last
+    strides.
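+
+    Example (illustrative; ``x``/``w`` as in :func:`conv2d`)::
+
+        y = conv2d_nhwc(x, w)
+        assert y.is_contiguous(memory_format=torch.channels_last)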
+ """
+ assert x.is_cuda and w_oihw.is_cuda
+ if out_dtype is None:
+ out_dtype = x.dtype
+ x = x.to(memory_format=torch.channels_last)
+ N, C, H, W_in = x.shape
+ K_out, Cw, R, S = w_oihw.shape
+ assert Cw == C
+
+ global _last_triton_kernel
+ if _is_1x1_conv(R, S, dilation):
+ _last_triton_kernel = "_conv2d_1x1_kernel"
+ return conv2d_1x1(
+ x,
+ w_oihw,
+ bias,
+ stride,
+ padding,
+ dilation,
+ activation,
+ out_dtype,
+ block_k,
+ layout="nhwc",
+ )
+ elif _is_3x3_conv(R, S):
+ method = _select_3x3_method(N, C, H, W_in, K_out, stride, dilation)
+ if method in ("winograd_f4x3", "winograd_f4x3_cblocked"):
+ _last_triton_kernel = "_winograd_f4x3_* (3 kernels)"
+ return conv2d_winograd_f4x3(
+ x,
+ w_oihw,
+ bias,
+ stride,
+ padding,
+ dilation,
+ activation,
+ out_dtype,
+ block_k,
+ layout="nhwc",
+ )
+ else:
+ _last_triton_kernel = "_conv2d_3x3_nhwc_kernel"
+ return conv2d_nhwc_3x3(
+ x,
+ w_oihw,
+ bias,
+ stride,
+ padding,
+ dilation,
+ activation,
+ out_dtype,
+ block_k,
+ )
+ else:
+ _last_triton_kernel = "_conv2d_general_kernel"
+ return conv2d_general(
+ x,
+ w_oihw,
+ bias,
+ stride,
+ padding,
+ dilation,
+ activation,
+ out_dtype,
+ block_k,
+ layout="nhwc",
+ )
+
+
+def conv2d_nchw_cblocked(
+ x,
+ w_oihw,
+ bias=None,
+ stride=(1, 1),
+ padding=(0, 0),
+ dilation=(1, 1),
+ activation="none",
+ out_dtype: Optional[torch.dtype] = None,
+ block_k=BLOCK_K,
+):
+ """NCHW conv2d with channel-blocked input packing for 3x3 kernels.
+ Raises ValueError for non-3x3."""
+ assert x.is_cuda and w_oihw.is_cuda
+ if out_dtype is None:
+ out_dtype = x.dtype
+ N, C, H, W_in = x.shape
+ K_out, Cw, R, S = w_oihw.shape
+ assert Cw == C
+ P, Q = _out_hw(H, W_in, R, S, stride, padding, dilation)
+
+ if not _is_3x3_conv(R, S):
+ raise ValueError(f"conv2d_nchw_cblocked requires 3x3 kernel, got {R}x{S}")
+
+ y = torch.empty((N, K_out, P, Q), device=x.device, dtype=out_dtype)
+ bias_fp32 = bias.float().contiguous() if bias is not None else None
+ w_3x3, (_, C_pad) = get_or_make_weight_pack_3x3(w_oihw.contiguous(), block_k)
+ Cb = block_k # packing block size matches weight padding block
+ x_blocked, C_pad_x = get_or_make_input_pack_cblocked(x, Cb)
+ # Ensure channel padding is consistent
+ assert (
+ C_pad_x == C_pad
+ ), f"Channel padding mismatch: input {C_pad_x} vs weight {C_pad}"
+ _launch_3x3_cblocked(
+ x_blocked,
+ w_3x3,
+ bias_fp32,
+ y,
+ N,
+ C,
+ H,
+ W_in,
+ K_out,
+ P,
+ Q,
+ C_pad,
+ Cb,
+ stride,
+ padding,
+ dilation,
+ activation,
+ )
+ return y
diff --git a/op_tests/op_benchmarks/triton/bench_conv2d.py b/op_tests/op_benchmarks/triton/bench_conv2d.py
new file mode 100644
index 0000000000..10cd8254ef
--- /dev/null
+++ b/op_tests/op_benchmarks/triton/bench_conv2d.py
@@ -0,0 +1,889 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
+"""Benchmark aiter.ops.triton.conv.conv2d.
+
+Two modes:
+- Single-shape (pass --N --C --H --W --K --R --S [--stride-h ...] etc.):
+  bench one shape and emit a single key=value result line. Useful for
+ ad-hoc one-off measurements or scripting around a specific shape.
+- Sweep (no --N). Iterates either the built-in default shape list or the
+ conv2d shapes for a model in model_shapes.json (--model NAME), and
+ prints three box-drawn tables at the end:
+ 1. LAYER-BY-LAYER BENCHMARK (per-layer Tri vs Torch + correctness)
+ 2. MIOpen SOLVER SUMMARY (only when --miopen-solvers is passed)
+ 3. OVERALL PERFORMANCE (mean/median/aggregate TFLOPS, layer wins)
+
+Each shape is timed with triton.testing.do_bench against
+torch.nn.functional.conv2d as the reference backend (MIOpen on AMD),
+and a correctness check compares the Triton output against F.conv2d
+within the same tolerance model the test suite uses.
+
+The selected Triton kernel name is captured for every shape. The MIOpen
+solver name is captured only when --miopen-solvers is passed, because
+detection requires a separate subprocess with MIOPEN_LOG_LEVEL=6
+(~60-120s fixed upfront cost).
+
+For NCHW non-1x1 shapes, kernel+repack timing is also captured: the input
+prepack cache (_PACK_CACHE_CBLOCKED) is cleared before each timing call,
+giving the steady-state inference cost when input layout changes per call.
+
+No model loading at runtime — model shapes come from the pre-extracted
+model_shapes.json (see extract_conv_shapes.py for how to regenerate it).
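+
+Example invocations (illustrative)::
+
+    python bench_conv2d.py --model resnet50 --miopen-solvers     # sweep mode
+    python bench_conv2d.py --N 8 --C 64 --H 56 --W 56 --K 128 \\
+        --R 3 --S 3 --pad-h 1 --pad-w 1                          # single shape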
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import os
+import re
+import statistics
+import subprocess
+import sys
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+import triton
+
+import aiter.ops.triton.conv.conv2d as _ops_module
+from aiter.ops.triton.conv._utils import (
+ flops_conv,
+ _out_hw,
+ _is_1x1_conv,
+ _is_3x3_conv,
+ dynamic_conv_tolerances,
+ _winograd_tolerances,
+)
+from aiter.ops.triton.conv._prepack import _PACK_CACHE_CBLOCKED
+from aiter.ops.triton.conv.conv2d import (
+ conv2d,
+ conv2d_nchw,
+ conv2d_nchw_cblocked,
+ conv2d_nhwc,
+ conv2d_winograd_f4x3,
+ conv2d_winograd_f4x3_fused,
+ conv2d_winograd_f4x3_cblocked,
+)
+
+METHODS = {
+ "auto": conv2d,
+ "default": conv2d_nchw,
+ "cblocked": conv2d_nchw_cblocked,
+ "nhwc": conv2d_nhwc,
+ "winograd_f4x3": conv2d_winograd_f4x3,
+ "winograd_f4x3_fused": conv2d_winograd_f4x3_fused,
+ "winograd_f4x3_cblocked": conv2d_winograd_f4x3_cblocked,
+}
+
+
+# Default sweep shapes — same set as the test edge cases. Kept in sync by hand
+# (12 tuples; trivial duplication, see _helpers.get_edge_case_shapes).
+DEFAULT_SHAPES = [
+ # (N, C, H, W, K, R, S, stride, padding, dilation, desc)
+ (1, 3, 7, 7, 8, 3, 3, (1, 1), (1, 1), (1, 1), "3x3 same padding"),
+ (1, 3, 8, 8, 16, 1, 1, (1, 1), (0, 0), (1, 1), "1x1 stride1"),
+ (2, 16, 32, 32, 32, 3, 3, (2, 2), (1, 1), (1, 1), "stride2"),
+ (2, 32, 17, 23, 64, 5, 5, (2, 2), (2, 2), (1, 1), "odd dims + pad"),
+ (4, 64, 28, 28, 128, 3, 3, (1, 1), (0, 0), (2, 2), "dilation2"),
+ (2, 512, 7, 7, 1024, 1, 1, (1, 1), (0, 0), (1, 1), "1x1 large channels"),
+ (1, 3, 112, 112, 64, 7, 7, (2, 2), (3, 3), (1, 1), "7x7 large spatial"),
+ (1, 1, 16, 16, 16, 3, 3, (1, 1), (1, 1), (1, 1), "single input channel"),
+ (2, 64, 8, 8, 64, 3, 3, (1, 1), (1, 1), (1, 1), "small spatial 3x3"),
+ (1, 128, 4, 4, 256, 1, 1, (1, 1), (0, 0), (1, 1), "1x1 tiny spatial"),
+ (2, 32, 32, 32, 32, 3, 3, (1, 1), (0, 0), (1, 1), "3x3 no padding"),
+ (2, 64, 28, 28, 128, 3, 3, (2, 2), (1, 1), (1, 1), "3x3 stride2 standard"),
+]
+
+
+# MIOpen solver names → human-readable algorithm types (matches old suite.py).
+MIOPEN_ALGO_MAP = {
+ "ConvWinoFuryRxS<2-3>": "Winograd Fury F(2,3)",
+ "ConvBinWinogradRxSf3x2": "Winograd F(3x3,2x2) binary",
+ "GemmFwd1x1_0_1": "GEMM (no workspace)",
+ "GemmFwdRest": "GEMM fallback",
+}
+
+
+# ----------------------------------------------------------------------------
+# Helpers
+# ----------------------------------------------------------------------------
+
+
+def _torch_dtype(s: str) -> torch.dtype:
+ if s == "fp16":
+ return torch.float16
+ if s == "bf16":
+ return torch.bfloat16
+ raise ValueError(f"unsupported dtype: {s}")
+
+
+def _check_close(got, ref, dtype, K_red, is_winograd: bool) -> bool:
+ """Tolerance-aware correctness check (same model as the pytest suite)."""
+ if is_winograd:
+ rtol, atol = _winograd_tolerances(dtype, K_red, ref, "f4x3")
+ else:
+ rtol, atol = dynamic_conv_tolerances(dtype, K_red, ref)
+ try:
+ torch.testing.assert_close(got.float(), ref.float(), rtol=rtol, atol=atol)
+ return True
+ except AssertionError:
+ return False
+
+
+def _kernel_type_tag(R: int, S: int, dilation: tuple) -> str:
+ if _is_1x1_conv(R, S, dilation):
+ return "[1x1]"
+ if _is_3x3_conv(R, S):
+ return "[3x3]"
+ return "[general]"
+
+
+def _shape_str(N, C, H, W, K, R, S) -> str:
+ return f"({N},{C},{H},{W})→{K}/{R}x{S}"
+
+
+# ----------------------------------------------------------------------------
+# MIOpen solver detection (subprocess-based, opt-in via --miopen-solvers)
+# ----------------------------------------------------------------------------
+
+
+_miopen_solver_cache: dict = {}
+
+
+def precompute_miopen_solvers(shapes, dtype: torch.dtype) -> None:
+ """Detect MIOpen solver per shape via a single subprocess.
+
+ Spawns Python with MIOPEN_LOG_LEVEL=6, runs F.conv2d for each shape,
+ and parses stderr for "Chosen Algorithm:" lines. SHAPE_DONE markers
+ on stderr disambiguate which "Chosen Algorithm" line belongs to which
+ shape (positional alignment is unreliable when MIOpen logs vary).
+
+ Cache populated as a side effect. Use _get_miopen_solver to read.
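+
+    Illustrative stderr interleaving (exact MIOpen text varies by version;
+    only "Chosen Algorithm:" lines and SHAPE_DONE markers are parsed)::
+
+        ... Chosen Algorithm: GemmFwd1x1_0_1 ...
+        SHAPE_DONE:0
+        ... Chosen Algorithm: ConvBinWinogradRxSf3x2 ...
+        SHAPE_DONE:1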
+ """
+ global _miopen_solver_cache
+
+ unique = []
+ seen = set()
+ for entry in shapes:
+ N, C, H, W, K, R, S, stride, padding, dilation = entry[:10]
+ s_h, s_w = stride if isinstance(stride, tuple) else (stride, stride)
+ p_h, p_w = padding if isinstance(padding, tuple) else (padding, padding)
+ d_h, d_w = dilation if isinstance(dilation, tuple) else (dilation, dilation)
+ key = (N, C, H, W, K, R, S, s_h, s_w, p_h, p_w, d_h, d_w)
+ if key not in seen:
+ seen.add(key)
+ unique.append(key)
+ if not unique:
+ return
+
+ dtype_str = {
+ torch.float16: "torch.float16",
+ torch.bfloat16: "torch.bfloat16",
+ }.get(dtype, "torch.float16")
+
+ lines = [
+ "import os, sys",
+ "os.environ['MIOPEN_LOG_LEVEL']='6'",
+ "import torch, torch.nn.functional as F",
+ ]
+ for i, (N, C, H, W, K, R, S, s_h, s_w, p_h, p_w, d_h, d_w) in enumerate(unique):
+ lines.append(f"# shape {i}")
+ lines.append(f"x=torch.randn({N},{C},{H},{W},device='cuda',dtype={dtype_str})")
+ lines.append(f"w=torch.randn({K},{C},{R},{S},device='cuda',dtype={dtype_str})")
+ lines.append(
+ f"F.conv2d(x,w,None,stride=({s_h},{s_w}),padding=({p_h},{p_w}),dilation=({d_h},{d_w}))"
+ )
+ lines.append("torch.cuda.synchronize()")
+ lines.append(f"sys.stderr.write('SHAPE_DONE:{i}\\n');sys.stderr.flush()")
+ script = "\n".join(lines)
+
+ try:
+ result = subprocess.run(
+ [sys.executable, "-c", script],
+ capture_output=True,
+ text=True,
+ timeout=120,
+ env={**os.environ, "MIOPEN_LOG_LEVEL": "6"},
+ )
+ except subprocess.TimeoutExpired:
+ print(
+ f"[miopen-detect] WARNING: subprocess timed out after 120s; "
+ f"MIOpen solver column will be empty for {len(unique)} shape(s).",
+ file=sys.stderr,
+ )
+ return
+ except Exception as e:
+ print(
+ f"[miopen-detect] WARNING: subprocess failed ({e!r}); "
+ f"MIOpen solver column will be empty.",
+ file=sys.stderr,
+ )
+ return
+
+ if result.returncode != 0:
+ tail = "\n".join(result.stderr.strip().split("\n")[-5:])
+ print(
+ f"[miopen-detect] WARNING: subprocess exited with code "
+ f"{result.returncode}; MIOpen solver column will be empty.\n"
+ f" Last stderr lines:\n{tail}",
+ file=sys.stderr,
+ )
+ return
+
+ pending: Optional[str] = None
+ attributed: dict = {}
+ orphan = 0
+ shape_done_re = re.compile(r"^SHAPE_DONE:(\d+)\s*$")
+ chosen_re = re.compile(r"Chosen Algorithm:\s*(\S+)")
+ for line in result.stderr.split("\n"):
+ m = chosen_re.search(line)
+ if m:
+ pending = m.group(1).strip(" ,")
+ continue
+ m = shape_done_re.match(line)
+ if m:
+ idx = int(m.group(1))
+ if pending is not None:
+ attributed[idx] = pending
+ else:
+ orphan += 1
+ pending = None
+ for idx, solver in attributed.items():
+ _miopen_solver_cache[unique[idx]] = solver
+
+ missing = len(unique) - len(attributed)
+ if missing > 0:
+ print(
+ f"[miopen-detect] WARNING: {missing}/{len(unique)} shape(s) have no "
+ f"MIOpen solver detected ({orphan} marker(s) had no preceding "
+ f"'Chosen Algorithm' line). Common causes: MIOpen log format changed, "
+ f"MIOPEN_LOG_LEVEL was overridden, or the shape failed in the subprocess.",
+ file=sys.stderr,
+ )
+
+
+def _get_miopen_solver(N, C, H, W, K, R, S, stride, padding, dilation) -> str:
+ s_h, s_w = stride if isinstance(stride, tuple) else (stride, stride)
+ p_h, p_w = padding if isinstance(padding, tuple) else (padding, padding)
+ d_h, d_w = dilation if isinstance(dilation, tuple) else (dilation, dilation)
+ return _miopen_solver_cache.get(
+ (N, C, H, W, K, R, S, s_h, s_w, p_h, p_w, d_h, d_w), ""
+ )
+
+
+# ----------------------------------------------------------------------------
+# Per-shape bench (returns rich dict)
+# ----------------------------------------------------------------------------
+
+
+def bench_one_shape(
+ N: int,
+ C: int,
+ H: int,
+ W: int,
+ K: int,
+ R: int,
+ S: int,
+ stride: tuple,
+ padding: tuple,
+ dilation: tuple,
+ dtype: torch.dtype,
+ method: str,
+ layout: str,
+ bias: bool = True,
+ measure_repack: bool = True,
+) -> dict:
+ """Time + correctness-check one shape. Returns a dict with full metadata.
+
+ Keys: ms_tri, ms_torch, ms_tri_e2e (or None), tflops_tri, tflops_torch,
+ tflops_tri_e2e (or None), correct, kernel_name, has_repack, flops.
+
+ measure_repack: if True (default), additionally times the kernel+input-repack
+ path for NCHW non-1x1 shapes by clearing _PACK_CACHE_CBLOCKED between calls.
+ """
+ if not torch.cuda.is_available():
+ raise RuntimeError("CUDA not available; conv2d bench requires a GPU.")
+
+ P, Q = _out_hw(H, W, R, S, stride, padding, dilation)
+ if P < 1 or Q < 1:
+ raise ValueError(
+ f"output spatial dims < 1 for shape "
+ f"N={N} C={C} H={H} W={W} K={K} R={R} S={S} "
+ f"stride={stride} padding={padding} dilation={dilation}"
+ )
+
+ device = "cuda"
+ x = torch.randn((N, C, H, W), device=device, dtype=dtype)
+ w = torch.randn((K, C, R, S), device=device, dtype=dtype)
+ b = torch.randn((K,), device=device, dtype=dtype) if bias else None
+
+ if layout == "nhwc":
+ x_in = x.to(memory_format=torch.channels_last)
+ kernel_fn = conv2d_nhwc
+ else:
+ x_in = x
+ if method not in METHODS:
+ raise ValueError(f"unknown method: {method}; choices: {list(METHODS)}")
+ kernel_fn = METHODS[method]
+
+ def run_triton():
+ return kernel_fn(
+ x_in,
+ w,
+ b,
+ stride,
+ padding,
+ dilation,
+ activation="none",
+ out_dtype=dtype,
+ )
+
+ def run_torch():
+ return F.conv2d(x_in, w, b, stride=stride, padding=padding, dilation=dilation)
+
+ # One run: captures the output (for correctness) AND _last_triton_kernel.
+ y_tri = run_triton()
+ torch.cuda.synchronize()
+ kernel_name = getattr(_ops_module, "_last_triton_kernel", "") or ""
+ is_winograd = "winograd" in kernel_name.lower() or "wino" in kernel_name.lower()
+
+ y_ref = run_torch()
+ correct = _check_close(
+ y_tri, y_ref, dtype, K_red=C * R * S, is_winograd=is_winograd
+ )
+
+ ms_tri = triton.testing.do_bench(run_triton, warmup=15, rep=50)
+ ms_th = triton.testing.do_bench(run_torch, warmup=15, rep=50)
+
+ # Kernel+repack timing: clear input pack cache between calls. Only NCHW
+ # non-1x1 — 1x1 takes raw weights (no repacking) and NHWC has its own
+ # path that doesn't use _PACK_CACHE_CBLOCKED.
+ has_repack = (
+ measure_repack and layout != "nhwc" and not _is_1x1_conv(R, S, dilation)
+ )
+ if has_repack:
+
+ def run_triton_e2e():
+ _PACK_CACHE_CBLOCKED.clear()
+ return run_triton()
+
+ ms_tri_e2e = triton.testing.do_bench(run_triton_e2e, warmup=15, rep=50)
+ else:
+ ms_tri_e2e = None
+
+ flops = flops_conv(N, C, K, R, S, P, Q)
+ tflops_tri = flops / (ms_tri * 1e-3) / 1e12
+ tflops_th = flops / (ms_th * 1e-3) / 1e12
+ tflops_tri_e2e = flops / (ms_tri_e2e * 1e-3) / 1e12 if ms_tri_e2e else None
+
+ return {
+ "ms_tri": ms_tri,
+ "ms_torch": ms_th,
+ "ms_tri_e2e": ms_tri_e2e,
+ "tflops_tri": tflops_tri,
+ "tflops_torch": tflops_th,
+ "tflops_tri_e2e": tflops_tri_e2e,
+ "correct": correct,
+ "kernel_name": kernel_name,
+ "has_repack": has_repack,
+ "flops": flops,
+ }
+
+
+# ----------------------------------------------------------------------------
+# Single-shape mode (used by bench_models.py)
+# ----------------------------------------------------------------------------
+
+
+def _format_single_shape_line(args, result: dict) -> str:
+ """Single-line key=value output for bench_models.py to parse.
+
+ Last whitespace-separated token is the primary metric value.
+ """
+ primary = result["ms_tri"] if args.metric == "time" else result["tflops_tri"]
+ parts = [
+ f"N={args.N}",
+ f"C={args.C}",
+ f"H={args.H}",
+ f"W={args.W}",
+ f"K={args.K}",
+ f"R={args.R}",
+ f"S={args.S}",
+ f"method={args.method}",
+ f"layout={args.layout}",
+ f"ms_tri={result['ms_tri']:.4f}",
+ f"ms_torch={result['ms_torch']:.4f}",
+ f"tflops_tri={result['tflops_tri']:.4f}",
+ f"tflops_torch={result['tflops_torch']:.4f}",
+ f"correct={int(result['correct'])}",
+ ]
+ if args.show_kernel_name:
+ parts.append(f"kernel={result['kernel_name'] or 'unknown'}")
+ parts.append(f"{primary:.4f}")
+ return " ".join(parts)
+
+
+def run_single_shape(args) -> None:
+ dtype = _torch_dtype(args.dtype)
+ stride = (args.stride_h, args.stride_w)
+ padding = (args.pad_h, args.pad_w)
+ dilation = (args.dilation_h, args.dilation_w)
+ # Single-shape mode (bench_models.py consumer): skip kernel+repack timing
+ # to keep per-call cost predictable for the framework.
+ result = bench_one_shape(
+ args.N,
+ args.C,
+ args.H,
+ args.W,
+ args.K,
+ args.R,
+ args.S,
+ stride,
+ padding,
+ dilation,
+ dtype,
+ args.method,
+ args.layout,
+ bias=not args.no_bias,
+ measure_repack=False,
+ )
+ print(_format_single_shape_line(args, result))
+
+
+# ----------------------------------------------------------------------------
+# Box-drawn table printers
+# ----------------------------------------------------------------------------
+
+
+def _box_table(headers, rows, align: Optional[list] = None) -> str:
+ """Render a list of header-string + row-tuples into a box-drawn table.
+
+ align: per-column alignment, "l" (left, default) or "r" (right).
+ """
+ n = len(headers)
+ if align is None:
+ align = ["l"] * n
+ widths = [
+ max(len(headers[j]), max((len(str(r[j])) for r in rows), default=0))
+ for j in range(n)
+ ]
+
+ def fmt_row(vals):
+ cells = []
+ for j, v in enumerate(vals):
+ s = str(v)
+ if align[j] == "r":
+ cells.append(f" {s:>{widths[j]}} ")
+ else:
+ cells.append(f" {s:<{widths[j]}} ")
+ return "│" + "│".join(cells) + "│"
+
+ sep_top = "┌" + "┬".join("─" * (w + 2) for w in widths) + "┐"
+ sep_mid = "├" + "┼".join("─" * (w + 2) for w in widths) + "┤"
+ sep_bot = "└" + "┴".join("─" * (w + 2) for w in widths) + "┘"
+
+ lines = [sep_top, fmt_row(headers), sep_mid]
+ for i, row in enumerate(rows):
+ lines.append(fmt_row(row))
+ if i < len(rows) - 1:
+ lines.append(sep_mid)
+ lines.append(sep_bot)
+ return "\n".join(lines)
+
+
+def _print_layer_table(
+ layers: list, has_any_repack: bool, miopen_enabled: bool
+) -> None:
+ print("\n" + "=" * 80)
+ print("LAYER-BY-LAYER BENCHMARK")
+ print("=" * 80)
+
+ headers = ["#", "Layer", "Type", "Shape"]
+ if miopen_enabled:
+ headers.append("MIOpen Solver")
+ headers.append("Triton Kernel")
+ headers.append("Tri Kernel TF/s")
+ if has_any_repack:
+ headers.append("Tri Kernel+Repack TF/s")
+ headers.extend(["Torch TF/s", "Winner"])
+
+ rows = []
+ for i, lr in enumerate(layers):
+ row = [str(i), lr["name"], lr["type"], lr["shape"]]
+ if miopen_enabled:
+ row.append(lr["miopen_solver"] or "—")
+ row.append(lr["kernel_name"] or "—")
+ row.append(f"{lr['tflops_tri']:.2f}")
+ if has_any_repack:
+ row.append(
+ f"{lr['tflops_tri_e2e']:.2f}"
+ if lr["tflops_tri_e2e"] is not None
+ else "—"
+ )
+ row.append(f"{lr['tflops_torch']:.2f}")
+        # Winner is decided on kernel TF/s (not kernel+repack), matching the
+        # old suite.py behavior.
+ row.append("Triton" if lr["tflops_tri"] > lr["tflops_torch"] else "Torch")
+ rows.append(row)
+
+ print(_box_table(headers, rows))
+
+
+def _print_miopen_solver_table(layers: list) -> None:
+ """Group layers by MIOpen solver, print one row per solver."""
+ from collections import OrderedDict
+
+ solver_layers: dict = OrderedDict()
+ for i, lr in enumerate(layers):
+ s = lr.get("miopen_solver") or "unknown"
+ solver_layers.setdefault(s, []).append(f"L{i}")
+
+ if not any(s != "unknown" for s in solver_layers):
+ return # Nothing detected; skip the table entirely.
+
+ print("\n" + "=" * 80)
+ print("MIOpen SOLVER SUMMARY")
+ print("=" * 80)
+
+ rows = []
+ for solver, ls in solver_layers.items():
+ algo = MIOPEN_ALGO_MAP.get(solver, solver)
+ layer_str = ", ".join(ls)
+ if len(layer_str) > 80:
+ layer_str = ", ".join(ls[:10]) + f" ... ({len(ls)} layers total)"
+ rows.append([solver, algo, layer_str])
+ print(_box_table(("MIOpen Solver", "Algorithm Type", "Used For"), rows))
+
+
+def _print_overall_perf_table(layers: list, has_any_repack: bool) -> None:
+ """Mean/median/aggregate TFLOPS, total time, layer wins."""
+ print("\n" + "=" * 80)
+ print("OVERALL PERFORMANCE")
+ print("=" * 80)
+
+ tri_tf = [lr["tflops_tri"] for lr in layers]
+ th_tf = [lr["tflops_torch"] for lr in layers]
+ tri_ms = [lr["ms_tri"] for lr in layers]
+ th_ms = [lr["ms_torch"] for lr in layers]
+
+ # Aggregate = sum(flops) / sum(time)
+ sum_flops = sum(lr["flops"] for lr in layers)
+ sum_time_tri = sum(lr["ms_tri"] * 1e-3 for lr in layers)
+ sum_time_th = sum(lr["ms_torch"] * 1e-3 for lr in layers)
+ agg_tri = sum_flops / sum_time_tri / 1e12 if sum_time_tri else 0.0
+ agg_th = sum_flops / sum_time_th / 1e12 if sum_time_th else 0.0
+
+ n = len(layers)
+ tri_wins = sum(1 for lr in layers if lr["tflops_tri"] > lr["tflops_torch"])
+
+ rows = [
+ [
+ "Mean TFLOPS (kernel)",
+ f"{statistics.mean(tri_tf):.2f}",
+ f"{statistics.mean(th_tf):.2f}",
+ ],
+ ]
+ if has_any_repack:
+ e2e_tf = [
+ (
+ lr["tflops_tri_e2e"]
+ if lr["tflops_tri_e2e"] is not None
+ else lr["tflops_tri"]
+ )
+ for lr in layers
+ ]
+ e2e_ms = [
+ lr["ms_tri_e2e"] if lr["ms_tri_e2e"] is not None else lr["ms_tri"]
+ for lr in layers
+ ]
+ sum_time_e2e = sum(t * 1e-3 for t in e2e_ms)
+ agg_tri_e2e = sum_flops / sum_time_e2e / 1e12 if sum_time_e2e else 0.0
+ e2e_wins = sum(1 for lr, ee in zip(layers, e2e_tf) if ee > lr["tflops_torch"])
+ rows.append(
+ [
+ "Mean TFLOPS (kernel+repack)",
+ f"{statistics.mean(e2e_tf):.2f}",
+ f"{statistics.mean(th_tf):.2f}",
+ ]
+ )
+ rows.append(
+ [
+ "Median TFLOPS (kernel)",
+ f"{statistics.median(tri_tf):.2f}",
+ f"{statistics.median(th_tf):.2f}",
+ ]
+ )
+ if has_any_repack:
+ rows.append(
+ [
+ "Median TFLOPS (kernel+repack)",
+ f"{statistics.median(e2e_tf):.2f}",
+ f"{statistics.median(th_tf):.2f}",
+ ]
+ )
+ rows.append(["Aggregate TFLOPS (kernel)", f"{agg_tri:.2f}", f"{agg_th:.2f}"])
+ if has_any_repack:
+ rows.append(
+ ["Aggregate TFLOPS (kernel+repack)", f"{agg_tri_e2e:.2f}", f"{agg_th:.2f}"]
+ )
+ rows.append(["Total kernel time (ms)", f"{sum(tri_ms):.2f}", f"{sum(th_ms):.2f}"])
+ if has_any_repack:
+ rows.append(
+ ["Total kernel+repack time (ms)", f"{sum(e2e_ms):.2f}", f"{sum(th_ms):.2f}"]
+ )
+ rows.append(["Layer wins (kernel)", f"{tri_wins}/{n}", f"{n - tri_wins}/{n}"])
+ if has_any_repack:
+ rows.append(
+ ["Layer wins (kernel+repack)", f"{e2e_wins}/{n}", f"{n - e2e_wins}/{n}"]
+ )
+ rows.append(
+ ["Correctness", f"{sum(1 for lr in layers if lr['correct'])}/{n} passed", "—"]
+ )
+
+ print(_box_table(("Metric", "Triton", "PyTorch (MIOpen)"), rows))
+
+
+# ----------------------------------------------------------------------------
+# Sweep mode
+# ----------------------------------------------------------------------------
+
+
+_MODEL_SHAPES_PATH = os.path.join(
+ os.path.dirname(os.path.abspath(__file__)),
+ "model_benchmarking_tool",
+ "model_shapes.json",
+)
+
+
+def _load_model_shapes(model_pattern: str) -> tuple[str, list]:
+ """Load conv2d shapes for a model from model_shapes.json.
+
+ model_pattern: case-insensitive substring matched against model keys.
+    Returns (matched_model_name, list of shape tuples in the same form as
+    DEFAULT_SHAPES; the desc field is "<model> L<index>").
+ """
+ with open(_MODEL_SHAPES_PATH) as f:
+ data = json.load(f)
+
+ matches = [
+ m for m in data if model_pattern.lower() in m.lower() and "conv2d" in data[m]
+ ]
+ if not matches:
+ avail = sorted(m for m, k in data.items() if "conv2d" in k)
+ raise ValueError(
+ f"No model with 'conv2d' shapes matches {model_pattern!r}. "
+ f"Available: {avail}"
+ )
+ if len(matches) > 1:
+ raise ValueError(
+ f"Pattern {model_pattern!r} matched multiple models: {matches}. "
+ f"Use a more specific pattern."
+ )
+
+ model = matches[0]
+ shapes = []
+ for i, s in enumerate(data[model]["conv2d"]):
+ shapes.append(
+ (
+ s["N"],
+ s["C"],
+ s["H"],
+ s["W"],
+ s["K"],
+ s["R"],
+ s["S"],
+ (s.get("stride_h", 1), s.get("stride_w", 1)),
+ (s.get("pad_h", 0), s.get("pad_w", 0)),
+ (s.get("dilation_h", 1), s.get("dilation_w", 1)),
+ f"{model} L{i}",
+ )
+ )
+ return model, shapes
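+
+# Expected model_shapes.json structure (illustrative entry; keys match the
+# reads above, and stride/pad/dilation default to 1/0/1 when absent):
+#   {"resnet50": {"conv2d": [
+#       {"N": 8, "C": 3, "H": 224, "W": 224, "K": 64, "R": 7, "S": 7,
+#        "stride_h": 2, "stride_w": 2, "pad_h": 3, "pad_w": 3}, ...]}}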
+
+
+def run_sweep(args) -> None:
+ """Iterate the chosen shape set, then print three summary tables."""
+ dtype = _torch_dtype(args.dtype)
+
+ if args.model:
+ try:
+ model, shapes = _load_model_shapes(args.model)
+ print(
+ f"# Sweep source: model_shapes.json :: {model} ({len(shapes)} layers)"
+ )
+ except (FileNotFoundError, ValueError) as e:
+ print(f"ERROR: {e}", file=sys.stderr)
+ sys.exit(1)
+ else:
+ shapes = DEFAULT_SHAPES
+ print(f"# Sweep source: built-in DEFAULT_SHAPES ({len(shapes)} shapes)")
+
+ print(
+ f"# dtype={args.dtype} method={args.method} layout={args.layout} "
+ f"miopen_solvers={'on' if args.miopen_solvers else 'off'}"
+ )
+
+ # Optional MIOpen solver detection (subprocess; ~60-120s startup)
+ if args.miopen_solvers:
+ print("# Detecting MIOpen solvers (subprocess; this can take a minute)...")
+ precompute_miopen_solvers(shapes, dtype)
+ print("# MIOpen solver detection complete.")
+
+ # Bench each shape, collect rows.
+ layers = []
+ for entry in shapes:
+ N, C, H, W, K, R, S, stride, padding, dilation, name = entry
+ try:
+ r = bench_one_shape(
+ N,
+ C,
+ H,
+ W,
+ K,
+ R,
+ S,
+ stride,
+ padding,
+ dilation,
+ dtype,
+ args.method,
+ args.layout,
+ bias=not args.no_bias,
+ measure_repack=True,
+ )
+ except Exception as e:
+ print(f" {name:<24} ERROR: {type(e).__name__}: {e}", file=sys.stderr)
+ continue
+ miopen = (
+ _get_miopen_solver(N, C, H, W, K, R, S, stride, padding, dilation)
+ if args.miopen_solvers
+ else ""
+ )
+ layers.append(
+ {
+ "name": name,
+ "type": _kernel_type_tag(R, S, dilation),
+ "shape": _shape_str(N, C, H, W, K, R, S),
+ "kernel_name": r["kernel_name"],
+ "miopen_solver": miopen,
+ "tflops_tri": r["tflops_tri"],
+ "tflops_tri_e2e": r["tflops_tri_e2e"],
+ "tflops_torch": r["tflops_torch"],
+ "ms_tri": r["ms_tri"],
+ "ms_tri_e2e": r["ms_tri_e2e"],
+ "ms_torch": r["ms_torch"],
+ "correct": r["correct"],
+ "flops": r["flops"],
+ }
+ )
+
+ if not layers:
+ print("No layers benched (all errored?).", file=sys.stderr)
+ return
+
+ has_any_repack = any(lr["ms_tri_e2e"] is not None for lr in layers)
+
+ _print_layer_table(layers, has_any_repack, miopen_enabled=args.miopen_solvers)
+ if args.miopen_solvers:
+ _print_miopen_solver_table(layers)
+ _print_overall_perf_table(layers, has_any_repack)
+
+
+# ----------------------------------------------------------------------------
+# CLI
+# ----------------------------------------------------------------------------
+
+
+def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace:
+ p = argparse.ArgumentParser(
+ prog="bench_conv2d",
+ description="Benchmark aiter.ops.triton.conv.conv2d (single shape or sweep).",
+ allow_abbrev=False,
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ p.add_argument(
+ "--dtype",
+ "--conv_dtype",
+ "--conv-dtype",
+ choices=["fp16", "bf16"],
+ default="fp16",
+ )
+ p.add_argument(
+ "--method",
+ choices=list(METHODS.keys()),
+ default="auto",
+ help="kernel to bench. 'auto' uses the conv2d router.",
+ )
+ p.add_argument(
+ "--layout",
+ "--conv_layout",
+ "--conv-layout",
+ choices=["nchw", "nhwc"],
+ default="nchw",
+ )
+ p.add_argument("--metric", choices=["time", "throughput"], default="throughput")
+ p.add_argument(
+ "--no-bias",
+ "--no_bias",
+ action="store_true",
+ help="bench the bias=None code path",
+ )
+ p.add_argument(
+ "--show-kernel-name",
+ "--show_kernel_name",
+ action="store_true",
+ help="include the routed Triton kernel name in single-shape output",
+ )
+ p.add_argument(
+ "--miopen-solvers",
+ "--miopen_solvers",
+ action="store_true",
+ help="detect MIOpen solver names via a subprocess (sweep mode only; "
+ "adds ~60-120s upfront cost)",
+ )
+ p.add_argument(
+ "--model",
+ type=str,
+ default=None,
+ help="sweep mode: load conv2d shapes for this model from "
+ "model_shapes.json (case-insensitive substring match). If omitted, "
+ "the built-in DEFAULT_SHAPES list is used.",
+ )
+
+ # Single-shape mode (used by bench_models.py and one-off measurements).
+ p.add_argument("--N", type=int, default=None)
+ p.add_argument("--C", type=int, default=None)
+ p.add_argument("--H", type=int, default=None)
+ p.add_argument("--W", type=int, default=None)
+ p.add_argument("--K", type=int, default=None)
+ p.add_argument("--R", type=int, default=None)
+ p.add_argument("--S", type=int, default=None)
+ p.add_argument("--stride-h", "--stride_h", type=int, default=1)
+ p.add_argument("--stride-w", "--stride_w", type=int, default=1)
+ p.add_argument("--pad-h", "--pad_h", type=int, default=0)
+ p.add_argument("--pad-w", "--pad_w", type=int, default=0)
+ p.add_argument("--dilation-h", "--dilation_h", type=int, default=1)
+ p.add_argument("--dilation-w", "--dilation_w", type=int, default=1)
+
+ args = p.parse_args(argv)
+
+ single = [args.N, args.C, args.H, args.W, args.K, args.R, args.S]
+ if any(v is not None for v in single):
+ if any(v is None for v in single):
+ p.error("single-shape mode requires all of --N --C --H --W --K --R --S")
+ args.single_shape = True
+ else:
+ args.single_shape = False
+ return args
+
+
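+# Example invocations (illustrative; flag names as defined in parse_args):
+#   python bench_conv2d.py --model resnet50 --dtype bf16 --miopen-solvers
+#   python bench_conv2d.py --N 1 --C 64 --H 56 --W 56 --K 256 --R 1 --S 1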
+def main(argv: Optional[list[str]] = None) -> None:
+ args = parse_args(argv)
+ if args.single_shape:
+ run_single_shape(args)
+ else:
+ run_sweep(args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/op_tests/op_benchmarks/triton/model_benchmarking_tool/extract_conv_shapes.py b/op_tests/op_benchmarks/triton/model_benchmarking_tool/extract_conv_shapes.py
new file mode 100644
index 0000000000..899deefae1
--- /dev/null
+++ b/op_tests/op_benchmarks/triton/model_benchmarking_tool/extract_conv_shapes.py
@@ -0,0 +1,230 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
+"""Offline extraction of Conv2d layer shapes from real model checkpoints.
+
+Run once per model. Walks the model with forward hooks, captures the
+per-layer (N, C, H, W, K, R, S, stride, pad, dilation) tuples, dedupes,
+and prints JSON ready to paste under ``"<model_key>": {"conv2d": [...]}``
+in ``model_shapes.json``.
+
+Consumed by ``bench_conv2d.py --model NAME`` (sweep mode).
+
+Usage::
+
+ python -m op_tests.op_benchmarks.triton.model_benchmarking_tool.extract_conv_shapes \\
+ --model resnet50
+
+ python -m op_tests.op_benchmarks.triton.model_benchmarking_tool.extract_conv_shapes \\
+ --model sd35_vae --model-path /app/models/stable-diffusion-3.5-medium
+
+ python -m op_tests.op_benchmarks.triton.model_benchmarking_tool.extract_conv_shapes \\
+ --model flux2_vae --model-path /app/models/FLUX.2-klein-9B
+
+Weight values do not affect the captured shapes (Conv2d dimensions are
+fixed by the layer definition, not the weight tensor contents). Concretely:
+- resnet50 uses ``weights=None`` (random init, no checkpoint download).
+- sd35_vae / flux2_vae use ``AutoencoderKL[Flux2].from_pretrained(...)``,
+ which loads the local checkpoint at ``--model-path`` (architecture +
+ weights). The weights are unused for shape extraction; they come along
+ because diffusers' easiest API loads both.
+
+Skips ``groups != 1`` layers (kernel doesn't support grouped/depthwise yet)
+and ``padding_mode != 'zeros'`` layers.
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import sys
+from typing import Callable, Optional
+
+import torch
+import torch.nn as nn
+
+
+def _conv_shape_dict(mod: nn.Conv2d, x_shape: torch.Size) -> dict:
+ _, _, H, W = x_shape
+ R, S = (
+ mod.kernel_size
+ if isinstance(mod.kernel_size, tuple)
+ else (mod.kernel_size, mod.kernel_size)
+ )
+ sh, sw = mod.stride if isinstance(mod.stride, tuple) else (mod.stride, mod.stride)
+ ph, pw = (
+ mod.padding if isinstance(mod.padding, tuple) else (mod.padding, mod.padding)
+ )
+ dh, dw = (
+ mod.dilation
+ if isinstance(mod.dilation, tuple)
+ else (mod.dilation, mod.dilation)
+ )
+ return {
+ "N": int(x_shape[0]),
+ "C": int(mod.in_channels),
+ "H": int(H),
+ "W": int(W),
+ "K": int(mod.out_channels),
+ "R": int(R),
+ "S": int(S),
+ "stride_h": int(sh),
+ "stride_w": int(sw),
+ "pad_h": int(ph),
+ "pad_w": int(pw),
+ "dilation_h": int(dh),
+ "dilation_w": int(dw),
+ }
+
+
+def _walk(module: nn.Module, run_forward: Callable[[], None]) -> list[dict]:
+ """Hook every Conv2d, run forward once, return deduped shape dicts."""
+ captured: dict[str, torch.Size] = {}
+
+ def make_hook(key):
+ def hook(m, inp, out):
+ if key not in captured and inp and isinstance(inp[0], torch.Tensor):
+ captured[key] = inp[0].shape
+
+ return hook
+
+ handles = []
+ for name, mod in module.named_modules():
+ if isinstance(mod, nn.Conv2d):
+ handles.append(mod.register_forward_hook(make_hook(name)))
+ try:
+ with torch.no_grad():
+ run_forward()
+ finally:
+ for h in handles:
+ h.remove()
+
+ shapes = []
+ skipped_groups = 0
+ skipped_padmode = 0
+ skipped_no_shape = 0
+ for name, mod in module.named_modules():
+ if not isinstance(mod, nn.Conv2d):
+ continue
+ if mod.groups != 1:
+ skipped_groups += 1
+ continue
+ if getattr(mod, "padding_mode", "zeros") != "zeros":
+ skipped_padmode += 1
+ continue
+ if name not in captured:
+ skipped_no_shape += 1
+ continue
+ shapes.append(_conv_shape_dict(mod, captured[name]))
+
+ print(
+ f"# captured {len(shapes)} conv layers "
+ f"(skipped: {skipped_groups} grouped, {skipped_padmode} non-zero pad-mode, "
+ f"{skipped_no_shape} unreached)",
+ file=sys.stderr,
+ )
+
+ # Dedupe — many models have the same shape repeated across blocks.
+ seen = set()
+ deduped = []
+ for s in shapes:
+ key = tuple(s.items())
+ if key in seen:
+ continue
+ seen.add(key)
+ deduped.append(s)
+ print(f"# {len(deduped)} unique shapes after dedupe", file=sys.stderr)
+ return deduped
+
+
+# Hardcoded internally — these don't affect captured shape values, only the
+# forward-pass runtime. fp16 + cuda is fine for every model we extract from.
+_DTYPE = torch.float16
+_DEVICE = "cuda"
+
+
+def extract_resnet50(N: int, H: int, W: int) -> list[dict]:
+ from torchvision.models import resnet50
+
+ model = resnet50(weights=None).to(device=_DEVICE, dtype=_DTYPE).eval()
+ x = torch.randn(N, 3, H, W, device=_DEVICE, dtype=_DTYPE)
+
+ def fwd():
+ model(x)
+
+    return _walk(model, fwd)
+
+
+def extract_sd35_vae(model_path: str, N: int, H: int, W: int) -> list[dict]:
+ from diffusers import AutoencoderKL
+
+ sub = model_path if model_path.rstrip("/").endswith("/vae") else model_path + "/vae"
+ vae = AutoencoderKL.from_pretrained(sub, torch_dtype=_DTYPE).to(_DEVICE).eval()
+ img = torch.randn(N, 3, H, W, device=_DEVICE, dtype=_DTYPE)
+
+ def fwd():
+ latent = vae.encode(img).latent_dist.sample()
+ vae.decode(latent)
+
+    return _walk(vae, fwd)
+
+
+def extract_flux2_vae(model_path: str, N: int, H: int, W: int) -> list[dict]:
+ from diffusers import AutoencoderKLFlux2
+
+ sub = model_path if model_path.rstrip("/").endswith("/vae") else model_path + "/vae"
+ vae = AutoencoderKLFlux2.from_pretrained(sub, torch_dtype=_DTYPE).to(_DEVICE).eval()
+ img = torch.randn(N, 3, H, W, device=_DEVICE, dtype=_DTYPE)
+
+ def fwd():
+ latent = vae.encode(img).latent_dist.sample()
+ vae.decode(latent)
+
+    return _walk(vae, fwd)
+
+
+def main(argv: Optional[list[str]] = None) -> None:
+    p = argparse.ArgumentParser(
+        prog="extract_conv_shapes",
+        description=__doc__,
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+    )
+ p.add_argument(
+ "--model", required=True, choices=["resnet50", "sd35_vae", "flux2_vae"]
+ )
+ p.add_argument(
+ "--model-path",
+ default=None,
+ help="local checkpoint dir (required for sd35_vae / flux2_vae)",
+ )
+ p.add_argument("--N", type=int, default=1, help="batch size")
+ p.add_argument(
+ "--H", type=int, default=None, help="input H (default: model-typical)"
+ )
+ p.add_argument(
+ "--W", type=int, default=None, help="input W (default: model-typical)"
+ )
+ args = p.parse_args(argv)
+
+ if args.model == "resnet50":
+ H = args.H if args.H is not None else 224
+ W = args.W if args.W is not None else 224
+ shapes = extract_resnet50(args.N, H, W)
+ model_key = "resnet50"
+ elif args.model == "sd35_vae":
+ if not args.model_path:
+ p.error("--model-path required for sd35_vae")
+ H = args.H if args.H is not None else 512
+ W = args.W if args.W is not None else 512
+ shapes = extract_sd35_vae(args.model_path, args.N, H, W)
+ model_key = "stable-diffusion-3.5-medium"
+ elif args.model == "flux2_vae":
+ if not args.model_path:
+ p.error("--model-path required for flux2_vae")
+ H = args.H if args.H is not None else 512
+ W = args.W if args.W is not None else 512
+ shapes = extract_flux2_vae(args.model_path, args.N, H, W)
+ model_key = "FLUX.2-klein-9B"
+ else:
+ raise AssertionError("unreachable")
+
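+    # The printed JSON merges directly into model_shapes.json; e.g. for
+    # resnet50 it starts: {"resnet50": {"conv2d": [{"N": 1, "C": 3, ...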
+ print(json.dumps({model_key: {"conv2d": shapes}}, indent=4))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/op_tests/op_benchmarks/triton/model_benchmarking_tool/model_shapes.json b/op_tests/op_benchmarks/triton/model_benchmarking_tool/model_shapes.json
index 0fd7d93f6a..afa49c92ef 100644
--- a/op_tests/op_benchmarks/triton/model_benchmarking_tool/model_shapes.json
+++ b/op_tests/op_benchmarks/triton/model_benchmarking_tool/model_shapes.json
@@ -44,12 +44,12 @@
"TP_dim": "K"
}
],
- "rmsnorm": [
+ "rmsnorm": [
{
"N": 16384
}
],
- "rope": [
+ "rope": [
{
"num_heads": 128,
"num_kv_heads": 8,
@@ -121,12 +121,12 @@
"TP_dim": "K"
}
],
- "rmsnorm": [
+ "rmsnorm": [
{
"N": 8192
}
],
- "rope": [
+ "rope": [
{
"num_heads": 64,
"num_kv_heads": 8,
@@ -198,12 +198,12 @@
"TP_dim": "K"
}
],
- "rmsnorm": [
+ "rmsnorm": [
{
"N": 4096
}
],
- "rope": [
+ "rope": [
{
"num_heads": 32,
"num_kv_heads": 8,
@@ -256,12 +256,12 @@
"TopK": 4
}
],
- "rmsnorm": [
+ "rmsnorm": [
{
"N": 2880
}
],
- "rope": [
+ "rope": [
{
"num_heads": 64,
"num_kv_heads": 8,
@@ -441,7 +441,7 @@
"TopK": 8
}
],
- "rmsnorm": [
+ "rmsnorm": [
{
"N": 7168
},
@@ -452,7 +452,7 @@
"N": 512
}
],
- "rope": [
+ "rope": [
{
"num_heads": 128,
"num_kv_heads": 128,
@@ -537,12 +537,12 @@
"TopK": 1
}
],
- "rmsnorm": [
+ "rmsnorm": [
{
"N": 5120
}
],
- "rope": [
+ "rope": [
{
"num_heads": 40,
"num_kv_heads": 8,
@@ -605,7 +605,7 @@
"TopK": 8
}
],
- "rmsnorm": [
+ "rmsnorm": [
{
"N": 4096
},
@@ -613,7 +613,7 @@
"N": 128
}
],
- "rope": [
+ "rope": [
{
"num_heads": 64,
"num_kv_heads": 4,
@@ -639,5 +639,1022 @@
"dv": 128
}
]
+ },
+ "resnet50": {
+ "conv2d": [
+ {
+ "N": 1,
+ "C": 3,
+ "H": 224,
+ "W": 224,
+ "K": 64,
+ "R": 7,
+ "S": 7,
+ "stride_h": 2,
+ "stride_w": 2,
+ "pad_h": 3,
+ "pad_w": 3,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 64,
+ "H": 56,
+ "W": 56,
+ "K": 64,
+ "R": 1,
+ "S": 1,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 0,
+ "pad_w": 0,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 64,
+ "H": 56,
+ "W": 56,
+ "K": 64,
+ "R": 3,
+ "S": 3,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 1,
+ "pad_w": 1,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 64,
+ "H": 56,
+ "W": 56,
+ "K": 256,
+ "R": 1,
+ "S": 1,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 0,
+ "pad_w": 0,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 256,
+ "H": 56,
+ "W": 56,
+ "K": 64,
+ "R": 1,
+ "S": 1,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 0,
+ "pad_w": 0,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 256,
+ "H": 56,
+ "W": 56,
+ "K": 128,
+ "R": 1,
+ "S": 1,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 0,
+ "pad_w": 0,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 128,
+ "H": 56,
+ "W": 56,
+ "K": 128,
+ "R": 3,
+ "S": 3,
+ "stride_h": 2,
+ "stride_w": 2,
+ "pad_h": 1,
+ "pad_w": 1,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 128,
+ "H": 28,
+ "W": 28,
+ "K": 512,
+ "R": 1,
+ "S": 1,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 0,
+ "pad_w": 0,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 256,
+ "H": 56,
+ "W": 56,
+ "K": 512,
+ "R": 1,
+ "S": 1,
+ "stride_h": 2,
+ "stride_w": 2,
+ "pad_h": 0,
+ "pad_w": 0,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 512,
+ "H": 28,
+ "W": 28,
+ "K": 128,
+ "R": 1,
+ "S": 1,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 0,
+ "pad_w": 0,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 128,
+ "H": 28,
+ "W": 28,
+ "K": 128,
+ "R": 3,
+ "S": 3,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 1,
+ "pad_w": 1,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 512,
+ "H": 28,
+ "W": 28,
+ "K": 256,
+ "R": 1,
+ "S": 1,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 0,
+ "pad_w": 0,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 256,
+ "H": 28,
+ "W": 28,
+ "K": 256,
+ "R": 3,
+ "S": 3,
+ "stride_h": 2,
+ "stride_w": 2,
+ "pad_h": 1,
+ "pad_w": 1,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 256,
+ "H": 14,
+ "W": 14,
+ "K": 1024,
+ "R": 1,
+ "S": 1,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 0,
+ "pad_w": 0,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 512,
+ "H": 28,
+ "W": 28,
+ "K": 1024,
+ "R": 1,
+ "S": 1,
+ "stride_h": 2,
+ "stride_w": 2,
+ "pad_h": 0,
+ "pad_w": 0,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 1024,
+ "H": 14,
+ "W": 14,
+ "K": 256,
+ "R": 1,
+ "S": 1,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 0,
+ "pad_w": 0,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 256,
+ "H": 14,
+ "W": 14,
+ "K": 256,
+ "R": 3,
+ "S": 3,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 1,
+ "pad_w": 1,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 1024,
+ "H": 14,
+ "W": 14,
+ "K": 512,
+ "R": 1,
+ "S": 1,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 0,
+ "pad_w": 0,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 512,
+ "H": 14,
+ "W": 14,
+ "K": 512,
+ "R": 3,
+ "S": 3,
+ "stride_h": 2,
+ "stride_w": 2,
+ "pad_h": 1,
+ "pad_w": 1,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 512,
+ "H": 7,
+ "W": 7,
+ "K": 2048,
+ "R": 1,
+ "S": 1,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 0,
+ "pad_w": 0,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 1024,
+ "H": 14,
+ "W": 14,
+ "K": 2048,
+ "R": 1,
+ "S": 1,
+ "stride_h": 2,
+ "stride_w": 2,
+ "pad_h": 0,
+ "pad_w": 0,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 2048,
+ "H": 7,
+ "W": 7,
+ "K": 512,
+ "R": 1,
+ "S": 1,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 0,
+ "pad_w": 0,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 512,
+ "H": 7,
+ "W": 7,
+ "K": 512,
+ "R": 3,
+ "S": 3,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 1,
+ "pad_w": 1,
+ "dilation_h": 1,
+ "dilation_w": 1
+ }
+ ]
+ },
+ "stable-diffusion-3.5-medium": {
+ "conv2d": [
+ {
+ "N": 1,
+ "C": 3,
+ "H": 512,
+ "W": 512,
+ "K": 128,
+ "R": 3,
+ "S": 3,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 1,
+ "pad_w": 1,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 128,
+ "H": 512,
+ "W": 512,
+ "K": 128,
+ "R": 3,
+ "S": 3,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 1,
+ "pad_w": 1,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 128,
+ "H": 513,
+ "W": 513,
+ "K": 128,
+ "R": 3,
+ "S": 3,
+ "stride_h": 2,
+ "stride_w": 2,
+ "pad_h": 0,
+ "pad_w": 0,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 128,
+ "H": 256,
+ "W": 256,
+ "K": 256,
+ "R": 3,
+ "S": 3,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 1,
+ "pad_w": 1,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 256,
+ "H": 256,
+ "W": 256,
+ "K": 256,
+ "R": 3,
+ "S": 3,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 1,
+ "pad_w": 1,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 128,
+ "H": 256,
+ "W": 256,
+ "K": 256,
+ "R": 1,
+ "S": 1,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 0,
+ "pad_w": 0,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 256,
+ "H": 257,
+ "W": 257,
+ "K": 256,
+ "R": 3,
+ "S": 3,
+ "stride_h": 2,
+ "stride_w": 2,
+ "pad_h": 0,
+ "pad_w": 0,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 256,
+ "H": 128,
+ "W": 128,
+ "K": 512,
+ "R": 3,
+ "S": 3,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 1,
+ "pad_w": 1,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 512,
+ "H": 128,
+ "W": 128,
+ "K": 512,
+ "R": 3,
+ "S": 3,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 1,
+ "pad_w": 1,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 256,
+ "H": 128,
+ "W": 128,
+ "K": 512,
+ "R": 1,
+ "S": 1,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 0,
+ "pad_w": 0,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 512,
+ "H": 129,
+ "W": 129,
+ "K": 512,
+ "R": 3,
+ "S": 3,
+ "stride_h": 2,
+ "stride_w": 2,
+ "pad_h": 0,
+ "pad_w": 0,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 512,
+ "H": 64,
+ "W": 64,
+ "K": 512,
+ "R": 3,
+ "S": 3,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 1,
+ "pad_w": 1,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 512,
+ "H": 64,
+ "W": 64,
+ "K": 32,
+ "R": 3,
+ "S": 3,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 1,
+ "pad_w": 1,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 16,
+ "H": 64,
+ "W": 64,
+ "K": 512,
+ "R": 3,
+ "S": 3,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 1,
+ "pad_w": 1,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 512,
+ "H": 256,
+ "W": 256,
+ "K": 512,
+ "R": 3,
+ "S": 3,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 1,
+ "pad_w": 1,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 512,
+ "H": 256,
+ "W": 256,
+ "K": 256,
+ "R": 3,
+ "S": 3,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 1,
+ "pad_w": 1,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 512,
+ "H": 256,
+ "W": 256,
+ "K": 256,
+ "R": 1,
+ "S": 1,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 0,
+ "pad_w": 0,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 256,
+ "H": 512,
+ "W": 512,
+ "K": 256,
+ "R": 3,
+ "S": 3,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 1,
+ "pad_w": 1,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 256,
+ "H": 512,
+ "W": 512,
+ "K": 128,
+ "R": 3,
+ "S": 3,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 1,
+ "pad_w": 1,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 256,
+ "H": 512,
+ "W": 512,
+ "K": 128,
+ "R": 1,
+ "S": 1,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 0,
+ "pad_w": 0,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 128,
+ "H": 512,
+ "W": 512,
+ "K": 3,
+ "R": 3,
+ "S": 3,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 1,
+ "pad_w": 1,
+ "dilation_h": 1,
+ "dilation_w": 1
+ }
+ ]
+ },
+ "FLUX.2-klein-9B": {
+ "conv2d": [
+ {
+ "N": 1,
+ "C": 3,
+ "H": 512,
+ "W": 512,
+ "K": 128,
+ "R": 3,
+ "S": 3,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 1,
+ "pad_w": 1,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 128,
+ "H": 512,
+ "W": 512,
+ "K": 128,
+ "R": 3,
+ "S": 3,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 1,
+ "pad_w": 1,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 128,
+ "H": 513,
+ "W": 513,
+ "K": 128,
+ "R": 3,
+ "S": 3,
+ "stride_h": 2,
+ "stride_w": 2,
+ "pad_h": 0,
+ "pad_w": 0,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 128,
+ "H": 256,
+ "W": 256,
+ "K": 256,
+ "R": 3,
+ "S": 3,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 1,
+ "pad_w": 1,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 256,
+ "H": 256,
+ "W": 256,
+ "K": 256,
+ "R": 3,
+ "S": 3,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 1,
+ "pad_w": 1,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 128,
+ "H": 256,
+ "W": 256,
+ "K": 256,
+ "R": 1,
+ "S": 1,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 0,
+ "pad_w": 0,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 256,
+ "H": 257,
+ "W": 257,
+ "K": 256,
+ "R": 3,
+ "S": 3,
+ "stride_h": 2,
+ "stride_w": 2,
+ "pad_h": 0,
+ "pad_w": 0,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 256,
+ "H": 128,
+ "W": 128,
+ "K": 512,
+ "R": 3,
+ "S": 3,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 1,
+ "pad_w": 1,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 512,
+ "H": 128,
+ "W": 128,
+ "K": 512,
+ "R": 3,
+ "S": 3,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 1,
+ "pad_w": 1,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 256,
+ "H": 128,
+ "W": 128,
+ "K": 512,
+ "R": 1,
+ "S": 1,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 0,
+ "pad_w": 0,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 512,
+ "H": 129,
+ "W": 129,
+ "K": 512,
+ "R": 3,
+ "S": 3,
+ "stride_h": 2,
+ "stride_w": 2,
+ "pad_h": 0,
+ "pad_w": 0,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 512,
+ "H": 64,
+ "W": 64,
+ "K": 512,
+ "R": 3,
+ "S": 3,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 1,
+ "pad_w": 1,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 512,
+ "H": 64,
+ "W": 64,
+ "K": 64,
+ "R": 3,
+ "S": 3,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 1,
+ "pad_w": 1,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 32,
+ "H": 64,
+ "W": 64,
+ "K": 512,
+ "R": 3,
+ "S": 3,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 1,
+ "pad_w": 1,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 512,
+ "H": 256,
+ "W": 256,
+ "K": 512,
+ "R": 3,
+ "S": 3,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 1,
+ "pad_w": 1,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 512,
+ "H": 256,
+ "W": 256,
+ "K": 256,
+ "R": 3,
+ "S": 3,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 1,
+ "pad_w": 1,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 512,
+ "H": 256,
+ "W": 256,
+ "K": 256,
+ "R": 1,
+ "S": 1,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 0,
+ "pad_w": 0,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 256,
+ "H": 512,
+ "W": 512,
+ "K": 256,
+ "R": 3,
+ "S": 3,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 1,
+ "pad_w": 1,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 256,
+ "H": 512,
+ "W": 512,
+ "K": 128,
+ "R": 3,
+ "S": 3,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 1,
+ "pad_w": 1,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 256,
+ "H": 512,
+ "W": 512,
+ "K": 128,
+ "R": 1,
+ "S": 1,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 0,
+ "pad_w": 0,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 128,
+ "H": 512,
+ "W": 512,
+ "K": 3,
+ "R": 3,
+ "S": 3,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 1,
+ "pad_w": 1,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 64,
+ "H": 64,
+ "W": 64,
+ "K": 64,
+ "R": 1,
+ "S": 1,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 0,
+ "pad_w": 0,
+ "dilation_h": 1,
+ "dilation_w": 1
+ },
+ {
+ "N": 1,
+ "C": 32,
+ "H": 64,
+ "W": 64,
+ "K": 32,
+ "R": 1,
+ "S": 1,
+ "stride_h": 1,
+ "stride_w": 1,
+ "pad_h": 0,
+ "pad_w": 0,
+ "dilation_h": 1,
+ "dilation_w": 1
+ }
+ ]
}
-}
\ No newline at end of file
+}
diff --git a/op_tests/triton_tests/conv/__init__.py b/op_tests/triton_tests/conv/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/op_tests/triton_tests/conv/_helpers.py b/op_tests/triton_tests/conv/_helpers.py
new file mode 100644
index 0000000000..2b1bfda14d
--- /dev/null
+++ b/op_tests/triton_tests/conv/_helpers.py
@@ -0,0 +1,483 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
+"""Test-side library: TestSuite, method registry, and runners.
+
+Library code (no ``test_`` prefix) — pytest does not collect this file.
+``test_conv2d.py`` imports the runners and registry from here.
+
+Public surface:
+
+- ``TestSuite`` / ``TestResult``
+ Correctness collector. ``check_close`` records pass/fail per case;
+ ``failed_results`` returns the list at end of run.
+
+- ``METHOD_REGISTRY`` (+ ``ORDERED_METHODS`` / ``ALL_METHODS``)
+ Kernel dispatch table. Each entry maps a method name to its
+ public ``conv2d_*`` callable, applicability guard, winograd flag,
+ and bench tag.
+
+- ``run_all_methods(...)``
+ Main dispatch. For a given (x, w, b, stride, padding, dilation):
+ runs the selected NCHW kernel (or all of them, if method="all"),
+ runs ``conv2d_nhwc`` if layout_mode includes nhwc, and checks every
+ output against ``F.conv2d`` within method-appropriate tolerance.
+
+- ``run_edge_cases``, ``run_random_fuzzing``, ``run_no_bias``,
+ ``run_activations``, ``run_cross_method``
+ Test runners called by ``test_conv2d.py``. Each iterates a shape
+ set and delegates to ``run_all_methods``.
+
+- ``COMMON_SHAPES``, ``get_edge_case_shapes()``
+ Shared shape data — 3x3 stride-1 shapes routable by every kernel,
+ and the 12-shape edge-case list respectively.
+"""
+
+from __future__ import annotations
+import random
+import traceback
+from collections import namedtuple
+from dataclasses import dataclass
+from typing import List, Optional
+
+import torch
+import torch.nn.functional as F
+
+from aiter.ops.triton.conv._utils import (
+ dynamic_conv_tolerances,
+ _out_hw,
+ _is_1x1_conv,
+ _is_3x3_conv,
+ _winograd_tolerances,
+ apply_activation,
+)
+from aiter.ops.triton.conv._launch import _select_3x3_method
+from aiter.ops.triton.conv.conv2d import (
+ conv2d_nchw,
+ conv2d_nchw_cblocked,
+ conv2d_nhwc,
+ conv2d_winograd_f4x3,
+ conv2d_winograd_f4x3_fused,
+ conv2d_winograd_f4x3_cblocked,
+)
+
+# -- Architecture gating ------------------------------------------------------
+SUPPORTED_ARCHS = {
+ "RDNA": {"gfx1200", "gfx1201"},
+ "CDNA": set(),
+}
+
+# Flat union for arch-check use sites.
+ALL_SUPPORTED_ARCHS = set().union(*SUPPORTED_ARCHS.values())
+
+
+# -- Method registry ----------------------------------------------------------
+
+MethodEntry = namedtuple(
+ "MethodEntry", ["kernel_fn", "guard_fn", "is_winograd", "bench_tag", "short_name"]
+)
+
+
+def _3x3_guard(R, S, stride, dilation, C):
+ return _is_3x3_conv(R, S)
+
+
+def _wino_guard(R, S, stride, dilation, C):
+    # _is_winograd_eligible's signature varies across upstream versions;
+    # import it at call time and keep the eligibility guard strict.
+ from aiter.ops.triton.conv._utils import _is_winograd_eligible
+
+ return _is_winograd_eligible(R, S, stride, dilation, C)
+
+
+METHOD_REGISTRY = {
+ "default": MethodEntry(conv2d_nchw, None, False, "", "default"),
+ "cblocked": MethodEntry(
+ conv2d_nchw_cblocked, _3x3_guard, False, "[cblocked]", "cblocked"
+ ),
+ "winograd_f4x3": MethodEntry(
+ conv2d_winograd_f4x3, _wino_guard, True, "[winograd_f4x3]", "WF(4,3)"
+ ),
+ "winograd_f4x3_fused": MethodEntry(
+ conv2d_winograd_f4x3_fused, _wino_guard, True, "[wino_f4x3_fused]", "WF4fused"
+ ),
+ "winograd_f4x3_cblocked": MethodEntry(
+ conv2d_winograd_f4x3_cblocked,
+ _wino_guard,
+ True,
+ "[winograd_f4x3_cblocked]",
+ "WF4cb",
+ ),
+}
+
+ORDERED_METHODS = list(METHOD_REGISTRY.keys())
+ALL_METHODS = ORDERED_METHODS + ["all"]
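+
+# Registry consumption sketch (mirrors run_all_methods below; guard_fn=None
+# means always applicable):
+#   entry = METHOD_REGISTRY[m]
+#   if entry.guard_fn is None or entry.guard_fn(R, S, stride, dilation, C):
+#       y = entry.kernel_fn(x, w, b, stride, padding, dilation,
+#                           activation=activation, out_dtype=dtype)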
+
+
+# -- Result + suite -----------------------------------------------------------
+
+
+@dataclass
+class TestResult:
+ name: str
+ passed: bool
+ max_abs_error: float
+ rel_error: float
+ message: str = ""
+
+
+class TestSuite:
+ """Correctness-only test runner. No bench records, no MIOpen tables."""
+
+ __test__ = False # not a pytest TestCase — leading "Test" is incidental
+
+ def __init__(
+ self,
+ device: str,
+ dtype: torch.dtype,
+ verbose: bool = False,
+ print_shapes: bool = False,
+ layout_mode: str = "both",
+ ):
+ self.device = torch.device(device)
+ self.dtype = dtype
+ self.verbose = verbose
+ self.print_shapes = print_shapes
+ self.layout_mode = layout_mode
+ self.results: List[TestResult] = []
+
+ def check_close(
+ self,
+ name: str,
+ got: torch.Tensor,
+ ref: torch.Tensor,
+ K_red: Optional[int] = None,
+ rtol: Optional[float] = None,
+ atol: Optional[float] = None,
+ ) -> TestResult:
+ got32 = got.float()
+ ref32 = ref.float()
+ diff = (got32 - ref32).abs()
+        max_abs = float(diff.max().item()) if diff.numel() else 0.0
+        denom = (float(ref32.abs().max().item()) if ref32.numel() else 0.0) + 1e-6
+        rel = max_abs / denom
+ if rtol is None or atol is None:
+ K_est = int(K_red) if K_red is not None else 1024
+ rtol_calc, atol_calc = dynamic_conv_tolerances(self.dtype, K_est, ref32)
+ rtol = rtol if rtol is not None else rtol_calc
+ atol = atol if atol is not None else atol_calc
+ try:
+ torch.testing.assert_close(got32, ref32, rtol=rtol, atol=atol)
+ passed = True
+ msg = "OK"
+ except AssertionError as e:
+ passed = False
+ msg = str(e).split("\n")[0]
+ res = TestResult(name, passed, max_abs, rel, msg)
+ self.results.append(res)
+ if self.verbose:
+ mark = "✓" if passed else "✗"
+ print(f" {mark} {name:<40} | max_abs={max_abs:.3e} rel={rel:.3e}")
+ return res
+
+ def all_passed(self) -> bool:
+ return all(r.passed for r in self.results)
+
+ def failed_results(self) -> List[TestResult]:
+ return [r for r in self.results if not r.passed]
+
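+# Typical driver flow (sketch; test_conv2d.py drives the same sequence):
+#   suite = TestSuite(device="cuda", dtype=torch.float16, layout_mode="both")
+#   run_edge_cases(suite, method="default")
+#   assert suite.all_passed(), [r.name for r in suite.failed_results()]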
+
+# -- Tolerance + dispatch -----------------------------------------------------
+
+
+def _get_tolerances(
+ method_name, entry, suite, y_ref, N, C, H, W, K_out, R, S, stride, dilation
+):
+ if entry.is_winograd:
+ return _winograd_tolerances(suite.dtype, C * R * S, y_ref, "f4x3")
+ if method_name == "default" and _is_3x3_conv(R, S):
+ routed = _select_3x3_method(N, C, H, W, K_out, stride, dilation)
+ if routed and "winograd" in routed:
+ return _winograd_tolerances(suite.dtype, C * R * S, y_ref, "f4x3")
+ return dynamic_conv_tolerances(suite.dtype, C * R * S, y_ref)
+
+
+def run_all_methods(
+ suite: TestSuite,
+ x: torch.Tensor,
+ w: torch.Tensor,
+ b: Optional[torch.Tensor],
+ stride,
+ padding,
+ dilation,
+ name: str,
+ method: str = "default",
+ activation: str = "none",
+):
+ """Correctness-only dispatch: run selected method(s) and check vs F.conv2d."""
+ N, C, H, W_in = x.shape
+ K_out, _, R, S = w.shape
+
+ y_ref = F.conv2d(
+ x,
+ w,
+ b.to(dtype=suite.dtype) if b is not None else None,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ )
+ y_ref = apply_activation(y_ref, activation)
+
+ if suite.print_shapes:
+ if _is_1x1_conv(R, S, dilation):
+ kernel_type = "[1x1]"
+ elif _is_3x3_conv(R, S):
+ kernel_type = "[3x3]"
+ else:
+ kernel_type = "[general]"
+ print(
+ f" {name} {kernel_type}: X{tuple(x.shape)} W{tuple(w.shape)} -> Y{tuple(y_ref.shape)}"
+ )
+
+ if suite.layout_mode in ("nchw", "both"):
+ methods_to_run = ORDERED_METHODS if method == "all" else [method]
+ for m in methods_to_run:
+ entry = METHOD_REGISTRY[m]
+ if entry.guard_fn and not entry.guard_fn(R, S, stride, dilation, C):
+ continue
+ y_tri = entry.kernel_fn(
+ x,
+ w,
+ b,
+ stride,
+ padding,
+ dilation,
+ activation=activation,
+ out_dtype=suite.dtype,
+ )
+ rtol, atol = _get_tolerances(
+ m, entry, suite, y_ref, N, C, H, W_in, K_out, R, S, stride, dilation
+ )
+ suite.check_close(
+ f"{name} {entry.bench_tag or '[NCHW]'}",
+ y_tri,
+ y_ref,
+ rtol=rtol,
+ atol=atol,
+ )
+
+ if suite.layout_mode in ("nhwc", "both"):
+ y_nhwc = conv2d_nhwc(
+ x,
+ w,
+ b,
+ stride,
+ padding,
+ dilation,
+ activation=activation,
+ out_dtype=suite.dtype,
+ )
+        is_wino_routed = _is_3x3_conv(R, S) and _select_3x3_method(
+            N, C, H, W_in, K_out, stride, dilation
+        ) in ("winograd_f4x3", "winograd_f4x3_cblocked")
+        if is_wino_routed:
+            _r, _a = _winograd_tolerances(suite.dtype, C * R * S, y_ref, "f4x3")
+            suite.check_close(f"{name} [NHWC]", y_nhwc, y_ref, rtol=_r, atol=_a)
+        else:
+            suite.check_close(f"{name} [NHWC]", y_nhwc, y_ref, K_red=C * R * S)
+
+
+# -- Shape sets ---------------------------------------------------------------
+
+# Shapes routable by ALL 5 NCHW kernels (3x3, stride=1, padding=1, dilation=1,
+# C >= 4). Used by test_cross_method to verify every kernel produces the same
+# result (within tolerance) on the same input.
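+# Tuple layout: (N, C, H, W, K, R, S, stride, padding, dilation, description).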
+COMMON_SHAPES = [
+ (1, 64, 56, 56, 64, 3, 3, (1, 1), (1, 1), (1, 1), "common 64ch/56sp"),
+ (1, 128, 28, 28, 128, 3, 3, (1, 1), (1, 1), (1, 1), "common 128ch/28sp"),
+ (1, 256, 14, 14, 256, 3, 3, (1, 1), (1, 1), (1, 1), "common 256ch/14sp"),
+]
+
+
+# -- Edge case shapes ---------------------------------------------------------
+
+
+def get_edge_case_shapes():
+ return [
+ (1, 3, 7, 7, 8, 3, 3, (1, 1), (1, 1), (1, 1), "3x3 same padding"),
+ (1, 3, 8, 8, 16, 1, 1, (1, 1), (0, 0), (1, 1), "1x1 stride1"),
+ (2, 16, 32, 32, 32, 3, 3, (2, 2), (1, 1), (1, 1), "stride2"),
+ (2, 32, 17, 23, 64, 5, 5, (2, 2), (2, 2), (1, 1), "odd dims + pad"),
+ (4, 64, 28, 28, 128, 3, 3, (1, 1), (0, 0), (2, 2), "dilation2"),
+ (2, 512, 7, 7, 1024, 1, 1, (1, 1), (0, 0), (1, 1), "1x1 large channels"),
+ (1, 3, 112, 112, 64, 7, 7, (2, 2), (3, 3), (1, 1), "7x7 large spatial"),
+ (1, 1, 16, 16, 16, 3, 3, (1, 1), (1, 1), (1, 1), "single input channel"),
+ (2, 64, 8, 8, 64, 3, 3, (1, 1), (1, 1), (1, 1), "small spatial 3x3"),
+ (1, 128, 4, 4, 256, 1, 1, (1, 1), (0, 0), (1, 1), "1x1 tiny spatial"),
+ (2, 32, 32, 32, 32, 3, 3, (1, 1), (0, 0), (1, 1), "3x3 no padding"),
+ (2, 64, 28, 28, 128, 3, 3, (2, 2), (1, 1), (1, 1), "3x3 stride2 standard"),
+ ]
+
+
+# -- Test runners (no test_ prefix; pytest will not collect this file) --------
+
+
+def run_edge_cases(suite: TestSuite, activation: str = "none", method: str = "default"):
+ for (
+ N,
+ C,
+ H,
+ W,
+ K_out,
+ R,
+ S,
+ stride,
+ padding,
+ dilation,
+ desc,
+ ) in get_edge_case_shapes():
+ P, Q = _out_hw(H, W, R, S, stride, padding, dilation)
+ if P < 1 or Q < 1:
+ continue
+ x = torch.randn((N, C, H, W), device=suite.device, dtype=suite.dtype)
+ w = torch.randn((K_out, C, R, S), device=suite.device, dtype=suite.dtype)
+ b = torch.randn((K_out,), device=suite.device, dtype=suite.dtype)
+ run_all_methods(
+ suite,
+ x,
+ w,
+ b,
+ stride,
+ padding,
+ dilation,
+ name=desc,
+ method=method,
+ activation=activation,
+ )
+
+
+def run_activations(
+ suite: TestSuite, method: str = "default", activation: str = "relu"
+):
+ N, C, H, W, K_out = 2, 32, 16, 16, 64
+ R, S = 3, 3
+ stride, padding, dilation = (1, 1), (1, 1), (1, 1)
+ x = torch.randn((N, C, H, W), device=suite.device, dtype=suite.dtype)
+ w = torch.randn((K_out, C, R, S), device=suite.device, dtype=suite.dtype)
+ b = torch.randn((K_out,), device=suite.device, dtype=suite.dtype)
+ run_all_methods(
+ suite,
+ x,
+ w,
+ b,
+ stride,
+ padding,
+ dilation,
+ name=f"activation_{activation}_{method}",
+ method=method,
+ activation=activation,
+ )
+
+
+def run_no_bias(suite: TestSuite, method: str = "default"):
+ shapes = [
+ (1, 64, 8, 8, 128, 1, 1, (1, 1), (0, 0), (1, 1), "1x1 no bias"),
+ (2, 32, 16, 16, 64, 3, 3, (1, 1), (1, 1), (1, 1), "3x3 no bias"),
+ (1, 16, 8, 8, 32, 5, 5, (1, 1), (2, 2), (1, 1), "5x5 no bias"),
+ ]
+ for N, C, H, W, K_out, R, S, stride, padding, dilation, desc in shapes:
+ x = torch.randn((N, C, H, W), device=suite.device, dtype=suite.dtype)
+ w = torch.randn((K_out, C, R, S), device=suite.device, dtype=suite.dtype)
+ run_all_methods(
+ suite,
+ x,
+ w,
+ None,
+ stride,
+ padding,
+ dilation,
+ name=desc,
+ method=method,
+ )
+
+
+def run_cross_method(suite: TestSuite):
+ """Run every NCHW-applicable kernel on shapes that all 5 can handle.
+
+    Each kernel is checked against F.conv2d. Transitivity then gives
+    cross-kernel equivalence: if kernels A and B both match the same
+    F.conv2d output within tolerance, they agree with each other to
+    within about twice that tolerance.
+ """
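+    # Triangle inequality behind the bound: if |y_A - y_ref| <= tol and
+    # |y_B - y_ref| <= tol, then |y_A - y_B| <= 2 * tol.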
+ for N, C, H, W, K_out, R, S, stride, padding, dilation, desc in COMMON_SHAPES:
+ x = torch.randn((N, C, H, W), device=suite.device, dtype=suite.dtype)
+ w = torch.randn((K_out, C, R, S), device=suite.device, dtype=suite.dtype)
+ b = torch.randn((K_out,), device=suite.device, dtype=suite.dtype)
+ run_all_methods(
+ suite,
+ x,
+ w,
+ b,
+ stride,
+ padding,
+ dilation,
+ name=desc,
+ method="all", # iterates every kernel in ORDERED_METHODS
+ )
+
+
+def run_random_fuzzing(
+ suite: TestSuite,
+ num_tests: int = 10,
+ activation: str = "none",
+ method: str = "default",
+ seed: int = 42,
+):
+ """Bounded random shape sweep, seeded for reproducibility.
+
+ Default num_tests=10 keeps CI cheap; callers can pass a larger value
+ for ad-hoc development sweeps.
+ """
+ random.seed(seed)
+ for i in range(num_tests):
+ N = random.randint(1, 8)
+ C = random.choice([1, 3, 16, 32, 64, 128, 256])
+ H = random.randint(4, 64)
+ W = random.randint(4, 64)
+ K_out = random.choice([16, 32, 64, 128, 256])
+ R = random.randint(1, min(7, H))
+ S = random.randint(1, min(7, W))
+ sh = random.randint(1, 3)
+ sw = random.randint(1, 3)
+ ph = random.randint(0, R // 2)
+ pw = random.randint(0, S // 2)
+ dh = random.randint(1, 2)
+ dw = random.randint(1, 2)
+ P, Q = _out_hw(H, W, R, S, (sh, sw), (ph, pw), (dh, dw))
+ if P < 1 or Q < 1:
+ continue
+ try:
+ x = torch.randn((N, C, H, W), device=suite.device, dtype=suite.dtype)
+ w = torch.randn((K_out, C, R, S), device=suite.device, dtype=suite.dtype)
+ b = torch.randn((K_out,), device=suite.device, dtype=suite.dtype)
+ tag = f"Random[{i}] ({N},{C},{H},{W})->({N},{K_out},{P},{Q})"
+ run_all_methods(
+ suite,
+ x,
+ w,
+ b,
+ (sh, sw),
+ (ph, pw),
+ (dh, dw),
+ name=tag,
+ method=method,
+ activation=activation,
+ )
+ except Exception as e:
+ tb = traceback.format_exc()
+ suite.results.append(
+ TestResult(
+ f"Random[{i}]",
+ False,
+ float("inf"),
+ float("inf"),
+ f"{type(e).__name__}: {e}\n{tb}",
+ )
+ )
diff --git a/op_tests/triton_tests/conv/test_conv2d.py b/op_tests/triton_tests/conv/test_conv2d.py
new file mode 100644
index 0000000000..c129e5f39b
--- /dev/null
+++ b/op_tests/triton_tests/conv/test_conv2d.py
@@ -0,0 +1,130 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
+"""Pytest unit tests for aiter.ops.triton.conv.conv2d.
+
+Correctness only. All tests compare Triton kernels against
+torch.nn.functional.conv2d on synthetic tensors. No model loading,
+no network, no torchvision.
+
+Test matrix (uniform across the four primary test families):
+
+ NCHW × {fp16, bf16} × every kernel (5 kernels) = 10
+ NHWC × {fp16, bf16} (single dispatch) = 2
+ ---
+ base cases per test family 12
+
+test_edge, test_fuzz, test_no_bias use the base matrix as-is.
+test_activations multiplies the base matrix by 3 (relu/relu6/gelu) -> 36.
+
+Plus test_cross_method (differential correctness) that runs every NCHW
+kernel on shapes routable by all of them and verifies they all match
+F.conv2d. NCHW-only by design; 2 cases (one per dtype).
+
+Total: 12 + 12 + 12 + 36 + 2 = 74 cases.
+
+Where a kernel's guard rejects a shape (e.g. winograd on a 5x5), the
+shape is silently skipped inside run_all_methods.
+
+Performance benchmarking lives in
+op_tests/op_benchmarks/triton/bench_conv2d.py (and, for real-model
+shapes, in op_benchmarks/triton/model_benchmarking_tool/bench_models.py).
+"""
+
+from __future__ import annotations
+
+import pytest
+import torch
+
+from aiter.ops.triton.utils._triton.arch_info import get_arch
+
+from ._helpers import (
+ ALL_SUPPORTED_ARCHS,
+ TestSuite,
+ ORDERED_METHODS,
+ run_edge_cases,
+ run_activations,
+ run_no_bias,
+ run_random_fuzzing,
+ run_cross_method,
+)
+
+# Module-level arch gate. Skip the whole test module on unsupported archs
+# rather than fail per-test. Extend SUPPORTED_ARCHS in _helpers.py when
+# adding CDNA (or other RDNA) support.
+_current_arch = get_arch()
+if _current_arch not in ALL_SUPPORTED_ARCHS:
+ pytest.skip(
+ f"aiter.ops.triton.conv tests run on {sorted(ALL_SUPPORTED_ARCHS)}; "
+ f"current arch {_current_arch!r} not supported",
+ allow_module_level=True,
+ )
+
+
+# Build the (dtype, layout, method) matrix once. NHWC entries only pair with
+# method="default" because conv2d_nhwc is single-dispatch — the method param
+# is a no-op there, so re-running for every method id would just duplicate work.
+def _build_matrix():
+ matrix = []
+ for dtype, dtype_id in [(torch.float16, "fp16"), (torch.bfloat16, "bf16")]:
+ for method in ORDERED_METHODS:
+ matrix.append(((dtype, "nchw", method), f"{dtype_id}_nchw_{method}"))
+ matrix.append(((dtype, "nhwc", "default"), f"{dtype_id}_nhwc"))
+ return matrix
+
+
+_MATRIX = _build_matrix()
+PARAMS = [params for params, _ in _MATRIX]
+IDS = [tid for _, tid in _MATRIX]
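+# Resulting ids: fp16_nchw_default, fp16_nchw_cblocked, ..., fp16_nhwc,
+# bf16_nchw_default, ..., bf16_nhwc (12 parameter sets in total).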
+
+
+def _make_suite(dtype, layout):
+ if not torch.cuda.is_available():
+ pytest.skip("CUDA not available")
+ return TestSuite(device="cuda", dtype=dtype, layout_mode=layout)
+
+
+def _assert_suite(suite: TestSuite):
+ failed = suite.failed_results()
+ assert not failed, f"{len(failed)} tests failed: {[r.name for r in failed]}"
+
+
+# -- The four primary test families, all on the same matrix ------------------
+
+
+@pytest.mark.parametrize("dtype,layout,method", PARAMS, ids=IDS)
+def test_edge(dtype, layout, method):
+ suite = _make_suite(dtype, layout)
+ run_edge_cases(suite, method=method)
+ _assert_suite(suite)
+
+
+@pytest.mark.parametrize("dtype,layout,method", PARAMS, ids=IDS)
+def test_fuzz(dtype, layout, method):
+ suite = _make_suite(dtype, layout)
+ run_random_fuzzing(suite, num_tests=10, method=method)
+ _assert_suite(suite)
+
+
+@pytest.mark.parametrize("dtype,layout,method", PARAMS, ids=IDS)
+def test_no_bias(dtype, layout, method):
+ suite = _make_suite(dtype, layout)
+ run_no_bias(suite, method=method)
+ _assert_suite(suite)
+
+
+@pytest.mark.parametrize("activation", ["relu", "relu6", "gelu"])
+@pytest.mark.parametrize("dtype,layout,method", PARAMS, ids=IDS)
+def test_activations(dtype, layout, method, activation):
+ suite = _make_suite(dtype, layout)
+ run_activations(suite, method=method, activation=activation)
+ _assert_suite(suite)
+
+
+# -- Differential correctness across all 5 NCHW kernels (NCHW-only) ----------
+
+
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
+def test_cross_method(dtype):
+ suite = _make_suite(dtype, "nchw")
+ run_cross_method(suite)
+ _assert_suite(suite)