diff --git a/aiter/ops/triton/_triton_kernels/conv/__init__.py b/aiter/ops/triton/_triton_kernels/conv/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/aiter/ops/triton/_triton_kernels/conv/conv_1x1.py b/aiter/ops/triton/_triton_kernels/conv/conv_1x1.py new file mode 100644 index 0000000000..d4ed10578f --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/conv/conv_1x1.py @@ -0,0 +1,163 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +from __future__ import annotations +import triton +import triton.language as tl + +from .helpers import _tanh, AUTOTUNE_1x1_CONFIGS + + +@triton.autotune( + configs=AUTOTUNE_1x1_CONFIGS, + key=["M_total", "K_out", "C"], + reset_to_zero=["Y"], + warmup=50, + rep=200, + cache_results=True, +) +@triton.jit +def _conv2d_1x1_kernel( + X, + W, + BIAS, + Y, + N: tl.constexpr, + C: tl.constexpr, + H: tl.constexpr, + W_in: tl.constexpr, + K_out: tl.constexpr, + P: tl.constexpr, + Q: tl.constexpr, + stride_h: tl.constexpr, + stride_w: tl.constexpr, + pad_h: tl.constexpr, + pad_w: tl.constexpr, + M_total: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + HAS_BIAS: tl.constexpr, + ACT_TYPE: tl.constexpr, + LAYOUT: tl.constexpr, +): + """ + Specialized 1x1 convolution kernel. + - No R*S loop (R=S=1) + - Direct channel reduction + - Simplified pointer arithmetic + LAYOUT: 0=NCHW, 1=NHWC + """ + # W is always [K_out, C] contiguous + stride_w_k: tl.constexpr = C + stride_w_c: tl.constexpr = 1 + if LAYOUT == 0: + # NCHW: X[N, C, H, W_in], Y[N, K_out, P, Q] + stride_x_n: tl.constexpr = C * H * W_in + stride_x_c: tl.constexpr = H * W_in + stride_x_h: tl.constexpr = W_in + stride_x_w: tl.constexpr = 1 + stride_y_n: tl.constexpr = K_out * P * Q + stride_y_k: tl.constexpr = P * Q + stride_y_p: tl.constexpr = Q + stride_y_q: tl.constexpr = 1 + else: + # NHWC: X[N, H, W_in, C], Y[N, P, Q, K_out] + stride_x_n: tl.constexpr = H * W_in * C + stride_x_c: tl.constexpr = 1 + stride_x_h: tl.constexpr = W_in * C + stride_x_w: tl.constexpr = C + stride_y_n: tl.constexpr = P * Q * K_out + stride_y_k: tl.constexpr = 1 + stride_y_p: tl.constexpr = Q * K_out + stride_y_q: tl.constexpr = K_out + + pid = tl.program_id(axis=0) + + # M = N * P * Q (output spatial), N_dim = K_out (output channels) + num_pid_m = tl.cdiv(M_total, BLOCK_M) + num_pid_n = tl.cdiv(K_out, BLOCK_N) + + # L2 cache swizzle pattern + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + if pid_m >= num_pid_m: + return + + # Compute output tile indices + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_K) + + # Decode (n, p, q) from linear index + n_idx = offs_m // (P * Q) + pq = offs_m % (P * Q) + p_idx = pq // Q + q_idx = pq % Q + + # Valid output mask + m_mask = offs_m < M_total + n_mask = offs_n < K_out + + ih = p_idx * stride_h - pad_h + iw = q_idx * stride_w - pad_w + + # Check spatial bounds + spatial_valid = (ih >= 0) & (ih < H) & (iw >= 0) & (iw < W_in) & (n_idx < N) + + # Base pointers for this output tile + x_base = X + n_idx * stride_x_n + ih * stride_x_h + iw * stride_x_w # [BLOCK_M] + w_base = W + offs_n * stride_w_k # [BLOCK_N] + + 
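With R = S = 1 the kernel above is an implicit GEMM: the rows are the M = N*P*Q output positions, the columns are the K_out filters, and the reduction runs over C alone, exactly as the docstring says. A minimal host-side sanity check of that equivalence for the stride-1, unpadded case (PyTorch used purely as a reference; shapes are hypothetical, not from this diff):

```python
# Sanity check: a 1x1 convolution is a GEMM over the channel dimension.
# Illustrative only; names and shapes are not part of the kernel code.
import torch

N, C, H, W_in, K_out = 2, 64, 16, 16, 128
x = torch.randn(N, C, H, W_in)
w = torch.randn(K_out, C)  # 1x1 weights in the [K_out, C] layout the kernel expects

ref = torch.nn.functional.conv2d(x, w.view(K_out, C, 1, 1))

# Same result as a plain matmul with M = N*H*W rows and K = C reduction:
gemm = x.permute(0, 2, 3, 1).reshape(-1, C) @ w.t()      # [N*H*W, K_out]
gemm = gemm.view(N, H, W_in, K_out).permute(0, 3, 1, 2)  # back to NCHW

assert torch.allclose(ref, gemm, atol=1e-4)
```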
# Accumulator + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + # Channel reduction loop + for k0 in range(0, C, BLOCK_K): + k_offs = k0 + offs_k + k_mask = k_offs < C + + # Load input: X[n, c, ih, iw] -> shape [BLOCK_M, BLOCK_K] + x_ptrs = x_base[:, None] + k_offs[None, :] * stride_x_c + x_mask = spatial_valid[:, None] & k_mask[None, :] + x_tile = tl.load(x_ptrs, mask=x_mask, other=0.0) + + # Load weight: W[k_out, c] -> shape [BLOCK_K, BLOCK_N] + w_ptrs = w_base[None, :] + k_offs[:, None] * stride_w_c + w_mask = k_mask[:, None] & n_mask[None, :] + w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0) + + # Accumulate: [BLOCK_M, BLOCK_K] @ [BLOCK_K, BLOCK_N] + acc += tl.dot(x_tile, w_tile, out_dtype=tl.float32) + + # Bias + if HAS_BIAS: + b = tl.load(BIAS + offs_n, mask=n_mask, other=0.0) + acc += b[None, :] + + # Activation + if ACT_TYPE == 1: # ReLU + acc = tl.maximum(acc, 0) + elif ACT_TYPE == 2: # ReLU6 + acc = tl.minimum(tl.maximum(acc, 0), 6) + elif ACT_TYPE == 3: # GELU + acc = ( + 0.5 * acc * (1.0 + _tanh(0.7978845608 * (acc + 0.044715 * acc * acc * acc))) + ) + + # Store output: Y[n, k, p, q] + y_ptrs = ( + Y + + n_idx[:, None] * stride_y_n + + offs_n[None, :] * stride_y_k + + p_idx[:, None] * stride_y_p + + q_idx[:, None] * stride_y_q + ) + y_mask = m_mask[:, None] & n_mask[None, :] + tl.store(y_ptrs, acc, mask=y_mask) diff --git a/aiter/ops/triton/_triton_kernels/conv/conv_3x3.py b/aiter/ops/triton/_triton_kernels/conv/conv_3x3.py new file mode 100644 index 0000000000..650028c0c5 --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/conv/conv_3x3.py @@ -0,0 +1,312 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +from __future__ import annotations +import triton +import triton.language as tl +from .helpers import ( + _tanh, + AUTOTUNE_3x3_NHWC_CONFIGS, + AUTOTUNE_3x3_CBLOCKED_CONFIGS, +) + + +@triton.autotune( + configs=AUTOTUNE_3x3_NHWC_CONFIGS, + key=["M_total", "K_out", "C_pad"], + reset_to_zero=["Y"], + warmup=50, + rep=200, + cache_results=True, +) +@triton.jit +def _conv2d_3x3_nhwc_kernel( + X, + W3, + BIAS, + Y, + N: tl.constexpr, + C: tl.constexpr, + H: tl.constexpr, + W_in: tl.constexpr, + K_out: tl.constexpr, + P: tl.constexpr, + Q: tl.constexpr, + C_pad: tl.constexpr, + stride_h: tl.constexpr, + stride_w: tl.constexpr, + pad_h: tl.constexpr, + pad_w: tl.constexpr, + dil_h: tl.constexpr, + dil_w: tl.constexpr, + M_total: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_C: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + HAS_BIAS: tl.constexpr, + ACT_TYPE: tl.constexpr, +): + """Specialized 3x3 NHWC kernel: stride_x_c=1 and stride_y_k=1 hardcoded + so the compiler can emit coalesced vector loads/stores.""" + # X layout: [N, H, W_in, C] contiguous NHWC (stride_x_c=1 hardcoded in load logic) + stride_x_w: tl.constexpr = C + stride_x_h: tl.constexpr = W_in * C + stride_x_n: tl.constexpr = H * W_in * C + # W3 layout: [K_out, 9, C_pad] contiguous + stride_w3_c: tl.constexpr = 1 + stride_w3_rs: tl.constexpr = C_pad + stride_w3_kout: tl.constexpr = 9 * C_pad + # Y layout: [N, P, Q, K_out] contiguous NHWC (stride_y_k=1 hardcoded in store logic) + stride_y_q: tl.constexpr = K_out + stride_y_p: tl.constexpr = Q * K_out + stride_y_n: tl.constexpr = P * Q * K_out + + pid = tl.program_id(axis=0) + + num_pid_m = tl.cdiv(M_total, BLOCK_M) + num_pid_n = tl.cdiv(K_out, BLOCK_N) + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * 
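The pid decomposition being set up here is the grouped ordering from the standard Triton matmul recipe, shared by every GEMM-style kernel in this PR: consecutive program ids first walk GROUP_SIZE_M tiles down the M axis, then step along N, so a group of output tiles reuses the same weight tiles while they are still resident in L2. A plain-Python model of the mapping (illustrative, not part of the diff):

```python
# Plain-Python rendering of the GROUP_SIZE_M swizzle; mirrors the
# pid_m / pid_n arithmetic in the kernels above and below.
def swizzle(pid, num_pid_m, num_pid_n, group_size_m):
    num_pid_in_group = group_size_m * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * group_size_m
    gsm = min(num_pid_m - first_pid_m, group_size_m)  # last group may be short
    pid_m = first_pid_m + (pid % num_pid_in_group) % gsm
    pid_n = (pid % num_pid_in_group) // gsm
    return pid_m, pid_n

# With num_pid_m = num_pid_n = 4 and group_size_m = 2, pids 0..3 map to
# (0,0), (1,0), (0,1), (1,1): two M-tiles share each column of weight tiles.
```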
GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + if pid_m >= num_pid_m: + return + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + kout_mask = offs_n < K_out + + # Decode (n, p, q) from linear index + n_idx = offs_m[:, None] // (P * Q) + pq = offs_m[:, None] % (P * Q) + p_idx = pq // Q + q_idx = pq % Q + n_valid = n_idx < N + + # Precompute base positions + base_oh = p_idx * stride_h - pad_h + base_ow = q_idx * stride_w - pad_w + stride_dh = dil_h * stride_x_h + stride_dw = dil_w * stride_x_w + x_base = X + n_idx * stride_x_n + base_oh * stride_x_h + base_ow * stride_x_w + + # Weight base: W3[K_out, 9, C_pad] + w_base = W3 + offs_n[None, :] * stride_w3_kout + + Y_ptrs = ( + Y + + n_idx * stride_y_n + + offs_n[None, :] + + p_idx * stride_y_p + + q_idx * stride_y_q + ) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + offs_c = tl.arange(0, BLOCK_C) + + for r in tl.static_range(3): + oh = base_oh + r * dil_h + valid_oh = n_valid & (oh >= 0) & (oh < H) + x_off_r = r * stride_dh + for s in tl.static_range(3): + rs_idx = r * 3 + s + ow = base_ow + s * dil_w + valid = valid_oh & (ow >= 0) & (ow < W_in) + + for c0 in range(0, C_pad, BLOCK_C): + c_offs = c0 + offs_c + c_mask = c_offs < C + + x_ptrs = x_base + c_offs[None, :] + x_off_r + s * stride_dw + w_ptrs = w_base + rs_idx * stride_w3_rs + c_offs[:, None] * stride_w3_c + + x_tile = tl.load(x_ptrs, mask=valid & c_mask[None, :], other=0.0) + w_tile = tl.load( + w_ptrs, mask=c_mask[:, None] & kout_mask[None, :], other=0.0 + ) + acc += tl.dot(x_tile, w_tile, out_dtype=tl.float32) + + # Epilogue: bias + activation + store + if HAS_BIAS: + b = tl.load(BIAS + offs_n, mask=offs_n < K_out, other=0.0) + acc += b[None, :] + + if ACT_TYPE == 1: + acc = tl.maximum(acc, 0) + elif ACT_TYPE == 2: + acc = tl.minimum(tl.maximum(acc, 0), 6) + elif ACT_TYPE == 3: + acc = ( + 0.5 * acc * (1.0 + _tanh(0.7978845608 * (acc + 0.044715 * acc * acc * acc))) + ) + + tl.store( + Y_ptrs, + acc, + mask=(n_valid & (p_idx < P) & (q_idx < Q) & kout_mask[None, :]), + ) + + +@triton.autotune( + configs=AUTOTUNE_3x3_CBLOCKED_CONFIGS, + key=["M_total", "K_out", "C_pad"], + reset_to_zero=["Y"], + warmup=50, + rep=200, + cache_results=True, +) +@triton.jit +def _conv2d_3x3_cblocked_kernel( + X, + W3, + BIAS, + Y, + N: tl.constexpr, + C: tl.constexpr, + H: tl.constexpr, + W_in: tl.constexpr, + K_out: tl.constexpr, + P: tl.constexpr, + Q: tl.constexpr, + C_pad: tl.constexpr, + Cb: tl.constexpr, + stride_h: tl.constexpr, + stride_w: tl.constexpr, + pad_h: tl.constexpr, + pad_w: tl.constexpr, + dil_h: tl.constexpr, + dil_w: tl.constexpr, + M_total: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_C: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + HAS_BIAS: tl.constexpr, + ACT_TYPE: tl.constexpr, +): + """Specialized 3x3 kernel for channel-blocked [N, C_blocks, H, W, Cb] input. 
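The channel-blocked variant whose docstring opens here reads X in an NCHWc layout: channels are zero-padded to C_pad and split into C_blocks groups of Cb, with the Cb sub-channels innermost so the local channel offset stays unit-stride. Both 3x3 kernels consume weights pre-packed as [K_out, 9, C_pad] with tap index rs = r*3 + s. A hypothetical host-side input repacking (the actual helper is not part of this hunk):

```python
# Hypothetical NCHW -> NCHWc repacking matching the addressing used by
# _conv2d_3x3_cblocked_kernel: cblock = c // Cb, c_local = c % Cb.
import torch

def to_nchwc(x: torch.Tensor, cb: int) -> torch.Tensor:
    n, c, h, w = x.shape
    c_pad = -(-c // cb) * cb                 # round C up to a multiple of Cb
    xp = torch.zeros(n, c_pad, h, w, dtype=x.dtype, device=x.device)
    xp[:, :c] = x                            # zero-pad the tail channels
    # [N, C_pad, H, W] -> [N, C_blocks, Cb, H, W] -> [N, C_blocks, H, W, Cb]
    return xp.view(n, c_pad // cb, cb, h, w).permute(0, 1, 3, 4, 2).contiguous()
```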
+ stride_c_local=1 is hardcoded so the compiler emits coalesced vector loads.""" + # X layout: [N, C_blocks, H, W_in, Cb] where C_blocks = C_pad // Cb + stride_x_w: tl.constexpr = Cb + stride_x_h: tl.constexpr = W_in * Cb + stride_x_cblock: tl.constexpr = H * W_in * Cb + stride_x_n: tl.constexpr = (C_pad // Cb) * H * W_in * Cb + # W3 layout: [K_out, 9, C_pad] contiguous + stride_w3_c: tl.constexpr = 1 + stride_w3_rs: tl.constexpr = C_pad + stride_w3_kout: tl.constexpr = 9 * C_pad + # Y layout: [N, K_out, P, Q] contiguous NCHW + stride_y_q: tl.constexpr = 1 + stride_y_p: tl.constexpr = Q + stride_y_k: tl.constexpr = P * Q + stride_y_n: tl.constexpr = K_out * P * Q + + pid = tl.program_id(axis=0) + + num_pid_m = tl.cdiv(M_total, BLOCK_M) + num_pid_n = tl.cdiv(K_out, BLOCK_N) + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + if pid_m >= num_pid_m: + return + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + kout_mask = offs_n < K_out + + # Decode (n, p, q) from linear index + n_idx = offs_m[:, None] // (P * Q) + pq = offs_m[:, None] % (P * Q) + p_idx = pq // Q + q_idx = pq % Q + n_valid = n_idx < N + + # Precompute base positions + base_oh = p_idx * stride_h - pad_h + base_ow = q_idx * stride_w - pad_w + stride_dh = dil_h * stride_x_h + stride_dw = dil_w * stride_x_w + + # x_base for channel-blocked layout: X[n, cblock, h, w, c_local] + # base pointer accounts for n, h, w + x_base = X + n_idx * stride_x_n + base_oh * stride_x_h + base_ow * stride_x_w + + # Weight base: W3[K_out, 9, C_pad] + w_base = W3 + offs_n[None, :] * stride_w3_kout + + # Y pointers + Y_ptrs = ( + Y + + n_idx * stride_y_n + + offs_n[None, :] * stride_y_k + + p_idx * stride_y_p + + q_idx * stride_y_q + ) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + offs_c = tl.arange(0, BLOCK_C) + + for r in tl.static_range(3): + oh = base_oh + r * dil_h + valid_oh = n_valid & (oh >= 0) & (oh < H) + x_off_r = r * stride_dh + for s in tl.static_range(3): + rs_idx = r * 3 + s + ow = base_ow + s * dil_w + valid = valid_oh & (ow >= 0) & (ow < W_in) + + for c0 in range(0, C_pad, BLOCK_C): + c_offs = c0 + offs_c + c_mask = c_offs < C + + # Compute cblock index and local offset within block + cblock_idx = c_offs // Cb + c_local = c_offs % Cb + + x_ptrs = ( + x_base + + cblock_idx[None, :] * stride_x_cblock + + c_local[None, :] + + x_off_r + + s * stride_dw + ) + w_ptrs = w_base + rs_idx * stride_w3_rs + c_offs[:, None] * stride_w3_c + + x_tile = tl.load(x_ptrs, mask=valid & c_mask[None, :], other=0.0) + w_tile = tl.load( + w_ptrs, mask=c_mask[:, None] & kout_mask[None, :], other=0.0 + ) + acc += tl.dot(x_tile, w_tile, out_dtype=tl.float32) + + # Epilogue: bias + activation + store + if HAS_BIAS: + b = tl.load(BIAS + offs_n, mask=offs_n < K_out, other=0.0) + acc += b[None, :] + + if ACT_TYPE == 1: + acc = tl.maximum(acc, 0) + elif ACT_TYPE == 2: + acc = tl.minimum(tl.maximum(acc, 0), 6) + elif ACT_TYPE == 3: + acc = ( + 0.5 * acc * (1.0 + _tanh(0.7978845608 * (acc + 0.044715 * acc * acc * acc))) + ) + + tl.store( + Y_ptrs, + acc, + mask=(n_valid & (p_idx < P) & (q_idx < Q) & kout_mask[None, :]), + ) diff --git a/aiter/ops/triton/_triton_kernels/conv/conv_3x3_winograd_f4x3.py 
b/aiter/ops/triton/_triton_kernels/conv/conv_3x3_winograd_f4x3.py new file mode 100644 index 0000000000..728d4bad59 --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/conv/conv_3x3_winograd_f4x3.py @@ -0,0 +1,1480 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +from __future__ import annotations +import triton +import triton.language as tl +from .helpers import ( + _tanh, + AUTOTUNE_WINO4_INPUT_CONFIGS, + AUTOTUNE_WINO_GEMM_CONFIGS, + AUTOTUNE_WINO4_OUTPUT_CONFIGS, + AUTOTUNE_FUSED_F4X3_CONFIGS, +) + + +@triton.autotune( + configs=AUTOTUNE_WINO4_INPUT_CONFIGS, + key=["T", "C_pad"], + cache_results=True, +) +@triton.jit +def _winograd_f4x3_input_transform_kernel( + X, + V, + N: tl.constexpr, + C: tl.constexpr, + C_pad: tl.constexpr, + H: tl.constexpr, + W_in: tl.constexpr, + tile_H: tl.constexpr, + tile_W: tl.constexpr, + T: tl.constexpr, + pad_h: tl.constexpr, + pad_w: tl.constexpr, + BLOCK_C: tl.constexpr, + INPUT_DTYPE: tl.constexpr = tl.float16, + LAYOUT: tl.constexpr = 0, +): + # X layout: LAYOUT=0 NCHW, LAYOUT=1 NHWC + if LAYOUT == 0: + stride_x_w: tl.constexpr = 1 + stride_x_h: tl.constexpr = W_in + stride_x_c: tl.constexpr = H * W_in + stride_x_n: tl.constexpr = C * H * W_in + else: + stride_x_w: tl.constexpr = C + stride_x_h: tl.constexpr = W_in * C + stride_x_c: tl.constexpr = 1 + stride_x_n: tl.constexpr = H * W_in * C + # V layout: [36, T, C_pad] contiguous + stride_v_c: tl.constexpr = 1 + stride_v_tile: tl.constexpr = C_pad + stride_v_alpha: tl.constexpr = T * C_pad + + tile_idx = tl.program_id(0) + c_block = tl.program_id(1) + + n = tile_idx // (tile_H * tile_W) + rem = tile_idx % (tile_H * tile_W) + th = rem // tile_W + tw = rem % tile_W + + h_start = th * 4 - pad_h + w_start = tw * 4 - pad_w + + offs_c = c_block * BLOCK_C + tl.arange(0, BLOCK_C) + c_mask = offs_c < C + + base = X + n * stride_x_n + offs_c * stride_x_c + n_valid = n < N + + # Load 6x6 patch + d00 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d01 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d02 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d03 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d04 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d05 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d10 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d11 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d12 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d13 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d14 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d15 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d20 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d21 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d22 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d23 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d24 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d25 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d30 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d31 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d32 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d33 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d34 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d35 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d40 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d41 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d42 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d43 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d44 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d45 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d50 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d51 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d52 = tl.zeros((BLOCK_C,), 
dtype=INPUT_DTYPE) + d53 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d54 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d55 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + + for r in tl.static_range(6): + h = h_start + r + h_valid = n_valid & (h >= 0) & (h < H) + for s in tl.static_range(6): + w = w_start + s + valid = h_valid & (w >= 0) & (w < W_in) + ptr = base + h * stride_x_h + w * stride_x_w + val = tl.load(ptr, mask=valid & c_mask, other=0.0) + if r == 0: + if s == 0: + d00 = val + elif s == 1: + d01 = val + elif s == 2: + d02 = val + elif s == 3: + d03 = val + elif s == 4: + d04 = val + else: + d05 = val + elif r == 1: + if s == 0: + d10 = val + elif s == 1: + d11 = val + elif s == 2: + d12 = val + elif s == 3: + d13 = val + elif s == 4: + d14 = val + else: + d15 = val + elif r == 2: + if s == 0: + d20 = val + elif s == 1: + d21 = val + elif s == 2: + d22 = val + elif s == 3: + d23 = val + elif s == 4: + d24 = val + else: + d25 = val + elif r == 3: + if s == 0: + d30 = val + elif s == 1: + d31 = val + elif s == 2: + d32 = val + elif s == 3: + d33 = val + elif s == 4: + d34 = val + else: + d35 = val + elif r == 4: + if s == 0: + d40 = val + elif s == 1: + d41 = val + elif s == 2: + d42 = val + elif s == 3: + d43 = val + elif s == 4: + d44 = val + else: + d45 = val + else: + if s == 0: + d50 = val + elif s == 1: + d51 = val + elif s == 2: + d52 = val + elif s == 3: + d53 = val + elif s == 4: + d54 = val + else: + d55 = val + + # B^T column transform (6x6): + # B^T = [[ 4, 0, -5, 0, 1, 0], + # [ 0, -4, -4, 1, 1, 0], + # [ 0, 4, -4, -1, 1, 0], + # [ 0, -2, -1, 2, 1, 0], + # [ 0, 2, -1, -2, 1, 0], + # [ 0, 4, 0, -5, 0, 1]] + # Apply to each column (compute t[row][col] from d[row][col]) + # Float32 for transform arithmetic to avoid fp16 overflow with multipliers like 4, 5 + d00f = d00.to(tl.float32) + d01f = d01.to(tl.float32) + d02f = d02.to(tl.float32) + d03f = d03.to(tl.float32) + d04f = d04.to(tl.float32) + d05f = d05.to(tl.float32) + d10f = d10.to(tl.float32) + d11f = d11.to(tl.float32) + d12f = d12.to(tl.float32) + d13f = d13.to(tl.float32) + d14f = d14.to(tl.float32) + d15f = d15.to(tl.float32) + d20f = d20.to(tl.float32) + d21f = d21.to(tl.float32) + d22f = d22.to(tl.float32) + d23f = d23.to(tl.float32) + d24f = d24.to(tl.float32) + d25f = d25.to(tl.float32) + d30f = d30.to(tl.float32) + d31f = d31.to(tl.float32) + d32f = d32.to(tl.float32) + d33f = d33.to(tl.float32) + d34f = d34.to(tl.float32) + d35f = d35.to(tl.float32) + d40f = d40.to(tl.float32) + d41f = d41.to(tl.float32) + d42f = d42.to(tl.float32) + d43f = d43.to(tl.float32) + d44f = d44.to(tl.float32) + d45f = d45.to(tl.float32) + d50f = d50.to(tl.float32) + d51f = d51.to(tl.float32) + d52f = d52.to(tl.float32) + d53f = d53.to(tl.float32) + d54f = d54.to(tl.float32) + d55f = d55.to(tl.float32) + + # Column transform: for each column s, t[row][s] = B^T @ d[:,s] + t00 = 4 * d00f - 5 * d20f + d40f + t01 = 4 * d01f - 5 * d21f + d41f + t02 = 4 * d02f - 5 * d22f + d42f + t03 = 4 * d03f - 5 * d23f + d43f + t04 = 4 * d04f - 5 * d24f + d44f + t05 = 4 * d05f - 5 * d25f + d45f + + t10 = -4 * d10f - 4 * d20f + d30f + d40f + t11 = -4 * d11f - 4 * d21f + d31f + d41f + t12 = -4 * d12f - 4 * d22f + d32f + d42f + t13 = -4 * d13f - 4 * d23f + d33f + d43f + t14 = -4 * d14f - 4 * d24f + d34f + d44f + t15 = -4 * d15f - 4 * d25f + d35f + d45f + + t20 = 4 * d10f - 4 * d20f - d30f + d40f + t21 = 4 * d11f - 4 * d21f - d31f + d41f + t22 = 4 * d12f - 4 * d22f - d32f + d42f + t23 = 4 * d13f - 4 * d23f - d33f + d43f + t24 = 4 * d14f - 4 * d24f - d34f + d44f 
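The scalar expressions in this transform (and in the output-transform kernel further down) are the F(4x4, 3x3) matrices B^T and A^T spelled out elementwise. A NumPy check of the 1D identity they implement — the weight-transform G is an assumption here, since the U-precompute lives outside this diff, but it is the standard companion of the B^T/A^T written out in this file:

```python
# Verify the F(4,3) Winograd identity: A^T[(G g) * (B^T d)] == correlate(d, g).
# G is assumed (the weight transform is not in this hunk).
import numpy as np

BT = np.array([[4, 0, -5, 0, 1, 0],
               [0, -4, -4, 1, 1, 0],
               [0, 4, -4, -1, 1, 0],
               [0, -2, -1, 2, 1, 0],
               [0, 2, -1, -2, 1, 0],
               [0, 4, 0, -5, 0, 1]], dtype=np.float64)
AT = np.array([[1, 1, 1, 1, 1, 0],
               [0, 1, -1, 2, -2, 0],
               [0, 1, 1, 4, 4, 0],
               [0, 1, -1, 8, -8, 1]], dtype=np.float64)
G = np.array([[1 / 4, 0, 0],
              [-1 / 6, -1 / 6, -1 / 6],
              [-1 / 6, 1 / 6, -1 / 6],
              [1 / 24, 1 / 12, 1 / 6],
              [1 / 24, -1 / 12, 1 / 6],
              [0, 0, 1]], dtype=np.float64)

d = np.random.randn(6)                       # one 6-wide input tile
g = np.random.randn(3)                       # one 3-tap filter
y = AT @ ((G @ g) * (BT @ d))                # 4 outputs per tile
ref = np.convolve(d, g[::-1], mode="valid")  # direct cross-correlation
assert np.allclose(y, ref)
# The 2D kernels nest this: V = B^T d B, then Y_tile = A^T [U * V] A.
```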
+ t25 = 4 * d15f - 4 * d25f - d35f + d45f + + t30 = -2 * d10f - d20f + 2 * d30f + d40f + t31 = -2 * d11f - d21f + 2 * d31f + d41f + t32 = -2 * d12f - d22f + 2 * d32f + d42f + t33 = -2 * d13f - d23f + 2 * d33f + d43f + t34 = -2 * d14f - d24f + 2 * d34f + d44f + t35 = -2 * d15f - d25f + 2 * d35f + d45f + + t40 = 2 * d10f - d20f - 2 * d30f + d40f + t41 = 2 * d11f - d21f - 2 * d31f + d41f + t42 = 2 * d12f - d22f - 2 * d32f + d42f + t43 = 2 * d13f - d23f - 2 * d33f + d43f + t44 = 2 * d14f - d24f - 2 * d34f + d44f + t45 = 2 * d15f - d25f - 2 * d35f + d45f + + t50 = 4 * d10f - 5 * d30f + d50f + t51 = 4 * d11f - 5 * d31f + d51f + t52 = 4 * d12f - 5 * d32f + d52f + t53 = 4 * d13f - 5 * d33f + d53f + t54 = 4 * d14f - 5 * d34f + d54f + t55 = 4 * d15f - 5 * d35f + d55f + + # Row transform: v[r][col] = B^T applied to row t[r][:] + v00 = 4 * t00 - 5 * t02 + t04 + v01 = -4 * t01 - 4 * t02 + t03 + t04 + v02 = 4 * t01 - 4 * t02 - t03 + t04 + v03 = -2 * t01 - t02 + 2 * t03 + t04 + v04 = 2 * t01 - t02 - 2 * t03 + t04 + v05 = 4 * t01 - 5 * t03 + t05 + + v10 = 4 * t10 - 5 * t12 + t14 + v11 = -4 * t11 - 4 * t12 + t13 + t14 + v12 = 4 * t11 - 4 * t12 - t13 + t14 + v13 = -2 * t11 - t12 + 2 * t13 + t14 + v14 = 2 * t11 - t12 - 2 * t13 + t14 + v15 = 4 * t11 - 5 * t13 + t15 + + v20 = 4 * t20 - 5 * t22 + t24 + v21 = -4 * t21 - 4 * t22 + t23 + t24 + v22 = 4 * t21 - 4 * t22 - t23 + t24 + v23 = -2 * t21 - t22 + 2 * t23 + t24 + v24 = 2 * t21 - t22 - 2 * t23 + t24 + v25 = 4 * t21 - 5 * t23 + t25 + + v30 = 4 * t30 - 5 * t32 + t34 + v31 = -4 * t31 - 4 * t32 + t33 + t34 + v32 = 4 * t31 - 4 * t32 - t33 + t34 + v33 = -2 * t31 - t32 + 2 * t33 + t34 + v34 = 2 * t31 - t32 - 2 * t33 + t34 + v35 = 4 * t31 - 5 * t33 + t35 + + v40 = 4 * t40 - 5 * t42 + t44 + v41 = -4 * t41 - 4 * t42 + t43 + t44 + v42 = 4 * t41 - 4 * t42 - t43 + t44 + v43 = -2 * t41 - t42 + 2 * t43 + t44 + v44 = 2 * t41 - t42 - 2 * t43 + t44 + v45 = 4 * t41 - 5 * t43 + t45 + + v50 = 4 * t50 - 5 * t52 + t54 + v51 = -4 * t51 - 4 * t52 + t53 + t54 + v52 = 4 * t51 - 4 * t52 - t53 + t54 + v53 = -2 * t51 - t52 + 2 * t53 + t54 + v54 = 2 * t51 - t52 - 2 * t53 + t54 + v55 = 4 * t51 - 5 * t53 + t55 + + v_base = V + tile_idx * stride_v_tile + offs_c * stride_v_c + c_store_mask = offs_c < C_pad + + tl.store(v_base + 0 * stride_v_alpha, v00, mask=c_store_mask) + tl.store(v_base + 1 * stride_v_alpha, v01, mask=c_store_mask) + tl.store(v_base + 2 * stride_v_alpha, v02, mask=c_store_mask) + tl.store(v_base + 3 * stride_v_alpha, v03, mask=c_store_mask) + tl.store(v_base + 4 * stride_v_alpha, v04, mask=c_store_mask) + tl.store(v_base + 5 * stride_v_alpha, v05, mask=c_store_mask) + tl.store(v_base + 6 * stride_v_alpha, v10, mask=c_store_mask) + tl.store(v_base + 7 * stride_v_alpha, v11, mask=c_store_mask) + tl.store(v_base + 8 * stride_v_alpha, v12, mask=c_store_mask) + tl.store(v_base + 9 * stride_v_alpha, v13, mask=c_store_mask) + tl.store(v_base + 10 * stride_v_alpha, v14, mask=c_store_mask) + tl.store(v_base + 11 * stride_v_alpha, v15, mask=c_store_mask) + tl.store(v_base + 12 * stride_v_alpha, v20, mask=c_store_mask) + tl.store(v_base + 13 * stride_v_alpha, v21, mask=c_store_mask) + tl.store(v_base + 14 * stride_v_alpha, v22, mask=c_store_mask) + tl.store(v_base + 15 * stride_v_alpha, v23, mask=c_store_mask) + tl.store(v_base + 16 * stride_v_alpha, v24, mask=c_store_mask) + tl.store(v_base + 17 * stride_v_alpha, v25, mask=c_store_mask) + tl.store(v_base + 18 * stride_v_alpha, v30, mask=c_store_mask) + tl.store(v_base + 19 * stride_v_alpha, v31, mask=c_store_mask) + tl.store(v_base + 
20 * stride_v_alpha, v32, mask=c_store_mask) + tl.store(v_base + 21 * stride_v_alpha, v33, mask=c_store_mask) + tl.store(v_base + 22 * stride_v_alpha, v34, mask=c_store_mask) + tl.store(v_base + 23 * stride_v_alpha, v35, mask=c_store_mask) + tl.store(v_base + 24 * stride_v_alpha, v40, mask=c_store_mask) + tl.store(v_base + 25 * stride_v_alpha, v41, mask=c_store_mask) + tl.store(v_base + 26 * stride_v_alpha, v42, mask=c_store_mask) + tl.store(v_base + 27 * stride_v_alpha, v43, mask=c_store_mask) + tl.store(v_base + 28 * stride_v_alpha, v44, mask=c_store_mask) + tl.store(v_base + 29 * stride_v_alpha, v45, mask=c_store_mask) + tl.store(v_base + 30 * stride_v_alpha, v50, mask=c_store_mask) + tl.store(v_base + 31 * stride_v_alpha, v51, mask=c_store_mask) + tl.store(v_base + 32 * stride_v_alpha, v52, mask=c_store_mask) + tl.store(v_base + 33 * stride_v_alpha, v53, mask=c_store_mask) + tl.store(v_base + 34 * stride_v_alpha, v54, mask=c_store_mask) + tl.store(v_base + 35 * stride_v_alpha, v55, mask=c_store_mask) + + +@triton.autotune( + configs=AUTOTUNE_WINO4_INPUT_CONFIGS, + key=["T", "C_pad"], + cache_results=True, +) +@triton.jit +def _winograd_f4x3_cblocked_input_transform_kernel( + X, + V, + N: tl.constexpr, + C: tl.constexpr, + C_pad: tl.constexpr, + H: tl.constexpr, + W_in: tl.constexpr, + tile_H: tl.constexpr, + tile_W: tl.constexpr, + T: tl.constexpr, + pad_h: tl.constexpr, + pad_w: tl.constexpr, + Cb: tl.constexpr, + BLOCK_C: tl.constexpr, + INPUT_DTYPE: tl.constexpr = tl.float16, +): + # X layout: [N, C_blocks, H, W_in, Cb] where C_blocks = C_pad // Cb + stride_x_w: tl.constexpr = Cb + stride_x_h: tl.constexpr = W_in * Cb + stride_x_cblock: tl.constexpr = H * W_in * Cb + stride_x_n: tl.constexpr = (C_pad // Cb) * H * W_in * Cb + # V layout: [36, T, C_pad] contiguous + stride_v_c: tl.constexpr = 1 + stride_v_tile: tl.constexpr = C_pad + stride_v_alpha: tl.constexpr = T * C_pad + + tile_idx = tl.program_id(0) + c_block = tl.program_id(1) + + n = tile_idx // (tile_H * tile_W) + rem = tile_idx % (tile_H * tile_W) + th = rem // tile_W + tw = rem % tile_W + + h_start = th * 4 - pad_h + w_start = tw * 4 - pad_w + + offs_c = c_block * BLOCK_C + tl.arange(0, BLOCK_C) + c_mask = offs_c < C + + # NCHWc addressing: cblock_idx = offs_c // Cb, c_local = offs_c % Cb + cblock_idx = offs_c // Cb + c_local = offs_c % Cb + base = X + n * stride_x_n + cblock_idx * stride_x_cblock + c_local + n_valid = n < N + + # Load 6x6 patch — 36 values per channel + d00 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d01 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d02 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d03 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d04 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d05 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d10 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d11 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d12 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d13 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d14 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d15 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d20 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d21 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d22 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d23 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d24 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d25 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d30 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d31 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d32 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d33 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d34 = 
tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d35 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d40 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d41 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d42 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d43 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d44 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d45 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d50 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d51 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d52 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d53 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d54 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + d55 = tl.zeros((BLOCK_C,), dtype=INPUT_DTYPE) + + for r in tl.static_range(6): + h = h_start + r + h_valid = n_valid & (h >= 0) & (h < H) + for s in tl.static_range(6): + w = w_start + s + valid = h_valid & (w >= 0) & (w < W_in) + ptr = base + h * stride_x_h + w * stride_x_w + val = tl.load(ptr, mask=valid & c_mask, other=0.0) + if r == 0: + if s == 0: + d00 = val + elif s == 1: + d01 = val + elif s == 2: + d02 = val + elif s == 3: + d03 = val + elif s == 4: + d04 = val + else: + d05 = val + elif r == 1: + if s == 0: + d10 = val + elif s == 1: + d11 = val + elif s == 2: + d12 = val + elif s == 3: + d13 = val + elif s == 4: + d14 = val + else: + d15 = val + elif r == 2: + if s == 0: + d20 = val + elif s == 1: + d21 = val + elif s == 2: + d22 = val + elif s == 3: + d23 = val + elif s == 4: + d24 = val + else: + d25 = val + elif r == 3: + if s == 0: + d30 = val + elif s == 1: + d31 = val + elif s == 2: + d32 = val + elif s == 3: + d33 = val + elif s == 4: + d34 = val + else: + d35 = val + elif r == 4: + if s == 0: + d40 = val + elif s == 1: + d41 = val + elif s == 2: + d42 = val + elif s == 3: + d43 = val + elif s == 4: + d44 = val + else: + d45 = val + else: + if s == 0: + d50 = val + elif s == 1: + d51 = val + elif s == 2: + d52 = val + elif s == 3: + d53 = val + elif s == 4: + d54 = val + else: + d55 = val + + d00f = d00.to(tl.float32) + d01f = d01.to(tl.float32) + d02f = d02.to(tl.float32) + d03f = d03.to(tl.float32) + d04f = d04.to(tl.float32) + d05f = d05.to(tl.float32) + d10f = d10.to(tl.float32) + d11f = d11.to(tl.float32) + d12f = d12.to(tl.float32) + d13f = d13.to(tl.float32) + d14f = d14.to(tl.float32) + d15f = d15.to(tl.float32) + d20f = d20.to(tl.float32) + d21f = d21.to(tl.float32) + d22f = d22.to(tl.float32) + d23f = d23.to(tl.float32) + d24f = d24.to(tl.float32) + d25f = d25.to(tl.float32) + d30f = d30.to(tl.float32) + d31f = d31.to(tl.float32) + d32f = d32.to(tl.float32) + d33f = d33.to(tl.float32) + d34f = d34.to(tl.float32) + d35f = d35.to(tl.float32) + d40f = d40.to(tl.float32) + d41f = d41.to(tl.float32) + d42f = d42.to(tl.float32) + d43f = d43.to(tl.float32) + d44f = d44.to(tl.float32) + d45f = d45.to(tl.float32) + d50f = d50.to(tl.float32) + d51f = d51.to(tl.float32) + d52f = d52.to(tl.float32) + d53f = d53.to(tl.float32) + d54f = d54.to(tl.float32) + d55f = d55.to(tl.float32) + + t00 = 4 * d00f - 5 * d20f + d40f + t01 = 4 * d01f - 5 * d21f + d41f + t02 = 4 * d02f - 5 * d22f + d42f + t03 = 4 * d03f - 5 * d23f + d43f + t04 = 4 * d04f - 5 * d24f + d44f + t05 = 4 * d05f - 5 * d25f + d45f + + t10 = -4 * d10f - 4 * d20f + d30f + d40f + t11 = -4 * d11f - 4 * d21f + d31f + d41f + t12 = -4 * d12f - 4 * d22f + d32f + d42f + t13 = -4 * d13f - 4 * d23f + d33f + d43f + t14 = -4 * d14f - 4 * d24f + d34f + d44f + t15 = -4 * d15f - 4 * d25f + d35f + d45f + + t20 = 4 * d10f - 4 * d20f - d30f + d40f + t21 = 4 * d11f - 4 * d21f - d31f + d41f + t22 = 4 * d12f - 4 * d22f - d32f + 
d42f + t23 = 4 * d13f - 4 * d23f - d33f + d43f + t24 = 4 * d14f - 4 * d24f - d34f + d44f + t25 = 4 * d15f - 4 * d25f - d35f + d45f + + t30 = -2 * d10f - d20f + 2 * d30f + d40f + t31 = -2 * d11f - d21f + 2 * d31f + d41f + t32 = -2 * d12f - d22f + 2 * d32f + d42f + t33 = -2 * d13f - d23f + 2 * d33f + d43f + t34 = -2 * d14f - d24f + 2 * d34f + d44f + t35 = -2 * d15f - d25f + 2 * d35f + d45f + + t40 = 2 * d10f - d20f - 2 * d30f + d40f + t41 = 2 * d11f - d21f - 2 * d31f + d41f + t42 = 2 * d12f - d22f - 2 * d32f + d42f + t43 = 2 * d13f - d23f - 2 * d33f + d43f + t44 = 2 * d14f - d24f - 2 * d34f + d44f + t45 = 2 * d15f - d25f - 2 * d35f + d45f + + t50 = 4 * d10f - 5 * d30f + d50f + t51 = 4 * d11f - 5 * d31f + d51f + t52 = 4 * d12f - 5 * d32f + d52f + t53 = 4 * d13f - 5 * d33f + d53f + t54 = 4 * d14f - 5 * d34f + d54f + t55 = 4 * d15f - 5 * d35f + d55f + + v00 = 4 * t00 - 5 * t02 + t04 + v01 = -4 * t01 - 4 * t02 + t03 + t04 + v02 = 4 * t01 - 4 * t02 - t03 + t04 + v03 = -2 * t01 - t02 + 2 * t03 + t04 + v04 = 2 * t01 - t02 - 2 * t03 + t04 + v05 = 4 * t01 - 5 * t03 + t05 + + v10 = 4 * t10 - 5 * t12 + t14 + v11 = -4 * t11 - 4 * t12 + t13 + t14 + v12 = 4 * t11 - 4 * t12 - t13 + t14 + v13 = -2 * t11 - t12 + 2 * t13 + t14 + v14 = 2 * t11 - t12 - 2 * t13 + t14 + v15 = 4 * t11 - 5 * t13 + t15 + + v20 = 4 * t20 - 5 * t22 + t24 + v21 = -4 * t21 - 4 * t22 + t23 + t24 + v22 = 4 * t21 - 4 * t22 - t23 + t24 + v23 = -2 * t21 - t22 + 2 * t23 + t24 + v24 = 2 * t21 - t22 - 2 * t23 + t24 + v25 = 4 * t21 - 5 * t23 + t25 + + v30 = 4 * t30 - 5 * t32 + t34 + v31 = -4 * t31 - 4 * t32 + t33 + t34 + v32 = 4 * t31 - 4 * t32 - t33 + t34 + v33 = -2 * t31 - t32 + 2 * t33 + t34 + v34 = 2 * t31 - t32 - 2 * t33 + t34 + v35 = 4 * t31 - 5 * t33 + t35 + + v40 = 4 * t40 - 5 * t42 + t44 + v41 = -4 * t41 - 4 * t42 + t43 + t44 + v42 = 4 * t41 - 4 * t42 - t43 + t44 + v43 = -2 * t41 - t42 + 2 * t43 + t44 + v44 = 2 * t41 - t42 - 2 * t43 + t44 + v45 = 4 * t41 - 5 * t43 + t45 + + v50 = 4 * t50 - 5 * t52 + t54 + v51 = -4 * t51 - 4 * t52 + t53 + t54 + v52 = 4 * t51 - 4 * t52 - t53 + t54 + v53 = -2 * t51 - t52 + 2 * t53 + t54 + v54 = 2 * t51 - t52 - 2 * t53 + t54 + v55 = 4 * t51 - 5 * t53 + t55 + + v_base = V + tile_idx * stride_v_tile + offs_c * stride_v_c + c_store_mask = offs_c < C_pad + + tl.store(v_base + 0 * stride_v_alpha, v00, mask=c_store_mask) + tl.store(v_base + 1 * stride_v_alpha, v01, mask=c_store_mask) + tl.store(v_base + 2 * stride_v_alpha, v02, mask=c_store_mask) + tl.store(v_base + 3 * stride_v_alpha, v03, mask=c_store_mask) + tl.store(v_base + 4 * stride_v_alpha, v04, mask=c_store_mask) + tl.store(v_base + 5 * stride_v_alpha, v05, mask=c_store_mask) + tl.store(v_base + 6 * stride_v_alpha, v10, mask=c_store_mask) + tl.store(v_base + 7 * stride_v_alpha, v11, mask=c_store_mask) + tl.store(v_base + 8 * stride_v_alpha, v12, mask=c_store_mask) + tl.store(v_base + 9 * stride_v_alpha, v13, mask=c_store_mask) + tl.store(v_base + 10 * stride_v_alpha, v14, mask=c_store_mask) + tl.store(v_base + 11 * stride_v_alpha, v15, mask=c_store_mask) + tl.store(v_base + 12 * stride_v_alpha, v20, mask=c_store_mask) + tl.store(v_base + 13 * stride_v_alpha, v21, mask=c_store_mask) + tl.store(v_base + 14 * stride_v_alpha, v22, mask=c_store_mask) + tl.store(v_base + 15 * stride_v_alpha, v23, mask=c_store_mask) + tl.store(v_base + 16 * stride_v_alpha, v24, mask=c_store_mask) + tl.store(v_base + 17 * stride_v_alpha, v25, mask=c_store_mask) + tl.store(v_base + 18 * stride_v_alpha, v30, mask=c_store_mask) + tl.store(v_base + 19 * stride_v_alpha, v31, 
mask=c_store_mask) + tl.store(v_base + 20 * stride_v_alpha, v32, mask=c_store_mask) + tl.store(v_base + 21 * stride_v_alpha, v33, mask=c_store_mask) + tl.store(v_base + 22 * stride_v_alpha, v34, mask=c_store_mask) + tl.store(v_base + 23 * stride_v_alpha, v35, mask=c_store_mask) + tl.store(v_base + 24 * stride_v_alpha, v40, mask=c_store_mask) + tl.store(v_base + 25 * stride_v_alpha, v41, mask=c_store_mask) + tl.store(v_base + 26 * stride_v_alpha, v42, mask=c_store_mask) + tl.store(v_base + 27 * stride_v_alpha, v43, mask=c_store_mask) + tl.store(v_base + 28 * stride_v_alpha, v44, mask=c_store_mask) + tl.store(v_base + 29 * stride_v_alpha, v45, mask=c_store_mask) + tl.store(v_base + 30 * stride_v_alpha, v50, mask=c_store_mask) + tl.store(v_base + 31 * stride_v_alpha, v51, mask=c_store_mask) + tl.store(v_base + 32 * stride_v_alpha, v52, mask=c_store_mask) + tl.store(v_base + 33 * stride_v_alpha, v53, mask=c_store_mask) + tl.store(v_base + 34 * stride_v_alpha, v54, mask=c_store_mask) + tl.store(v_base + 35 * stride_v_alpha, v55, mask=c_store_mask) + + +@triton.autotune( + configs=AUTOTUNE_WINO_GEMM_CONFIGS, + key=["T", "K_out", "C_pad"], + cache_results=True, +) +@triton.jit +def _winograd_f4x3_batched_gemm_kernel( + V, + U, + M_out, + T: tl.constexpr, + K_out: tl.constexpr, + C_pad: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """Batched GEMM: M[alpha] = V[alpha] @ U[alpha]^T, alpha in [0..36)""" + # V layout: [36, T, C_pad] contiguous + stride_v_c: tl.constexpr = 1 + stride_v_tile: tl.constexpr = C_pad + stride_v_alpha: tl.constexpr = T * C_pad + # U layout: [36, K_out, C_pad] contiguous + stride_u_c: tl.constexpr = 1 + stride_u_k: tl.constexpr = C_pad + stride_u_alpha: tl.constexpr = K_out * C_pad + # M layout: [36, T, K_out] contiguous + stride_m_k: tl.constexpr = 1 + stride_m_tile: tl.constexpr = K_out + stride_m_alpha: tl.constexpr = T * K_out + + pid = tl.program_id(0) + alpha = tl.program_id(1) + + num_pid_m = tl.cdiv(T, BLOCK_M) + num_pid_n = tl.cdiv(K_out, BLOCK_N) + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + if pid_m >= num_pid_m or pid_n >= num_pid_n: + return + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_K) + + v_base = V + alpha * stride_v_alpha + u_base = U + alpha * stride_u_alpha + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + for k0 in range(0, C_pad, BLOCK_K): + k_offs = k0 + offs_k + + v_ptrs = v_base + offs_m[:, None] * stride_v_tile + k_offs[None, :] * stride_v_c + v_mask = (offs_m[:, None] < T) & (k_offs[None, :] < C_pad) + v_tile = tl.load(v_ptrs, mask=v_mask, other=0.0) + + u_ptrs = u_base + offs_n[:, None] * stride_u_k + k_offs[None, :] * stride_u_c + u_mask = (offs_n[:, None] < K_out) & (k_offs[None, :] < C_pad) + u_tile = tl.load(u_ptrs, mask=u_mask, other=0.0) + + acc += tl.dot(v_tile, tl.trans(u_tile), out_dtype=tl.float32) + + m_ptrs = ( + M_out + + alpha * stride_m_alpha + + offs_m[:, None] * stride_m_tile + + offs_n[None, :] * stride_m_k + ) + m_mask = (offs_m[:, None] < T) & (offs_n[None, :] < K_out) + tl.store(m_ptrs, acc, mask=m_mask) + + +@triton.autotune( + configs=AUTOTUNE_WINO4_OUTPUT_CONFIGS, + key=["T", "K_out"], + 
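At the shape level, the batched GEMM stage above is a single einsum over the 36 transform points. A reference comparison (hypothetical sizes; torch used only for checking):

```python
# Shape-level reference for _winograd_f4x3_batched_gemm_kernel:
# M[alpha] = V[alpha] @ U[alpha]^T for each of the 36 alphas.
import torch

T, K_out, C_pad = 128, 64, 96  # hypothetical sizes
V = torch.randn(36, T, C_pad)
U = torch.randn(36, K_out, C_pad)
M = torch.einsum("atc,akc->atk", V, U)
assert M.shape == (36, T, K_out)
assert torch.allclose(M[7], V[7] @ U[7].T, atol=1e-4)
```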
cache_results=True, +) +@triton.jit +def _winograd_f4x3_output_transform_kernel( + M_in, + BIAS, + Y, + N: tl.constexpr, + K_out: tl.constexpr, + P: tl.constexpr, + Q: tl.constexpr, + tile_H: tl.constexpr, + tile_W: tl.constexpr, + T: tl.constexpr, + BLOCK_K: tl.constexpr, + HAS_BIAS: tl.constexpr, + ACT_TYPE: tl.constexpr, + LAYOUT: tl.constexpr = 0, +): + # M layout: [36, T, K_out] contiguous + stride_m_k: tl.constexpr = 1 + stride_m_tile: tl.constexpr = K_out + stride_m_alpha: tl.constexpr = T * K_out + # Y layout: LAYOUT=0 NCHW, LAYOUT=1 NHWC + if LAYOUT == 0: + stride_y_q: tl.constexpr = 1 + stride_y_p: tl.constexpr = Q + stride_y_k: tl.constexpr = P * Q + stride_y_n: tl.constexpr = K_out * P * Q + else: + stride_y_q: tl.constexpr = K_out + stride_y_p: tl.constexpr = Q * K_out + stride_y_k: tl.constexpr = 1 + stride_y_n: tl.constexpr = P * Q * K_out + + tile_idx = tl.program_id(0) + k_block = tl.program_id(1) + + n = tile_idx // (tile_H * tile_W) + rem = tile_idx % (tile_H * tile_W) + th = rem // tile_W + tw = rem % tile_W + + p_start = th * 4 + q_start = tw * 4 + + offs_k = k_block * BLOCK_K + tl.arange(0, BLOCK_K) + k_mask = offs_k < K_out + + # Load 36 values from M[alpha, tile_idx, k] + m_base = M_in + tile_idx * stride_m_tile + offs_k * stride_m_k + + m00 = tl.load(m_base + 0 * stride_m_alpha, mask=k_mask, other=0.0) + m01 = tl.load(m_base + 1 * stride_m_alpha, mask=k_mask, other=0.0) + m02 = tl.load(m_base + 2 * stride_m_alpha, mask=k_mask, other=0.0) + m03 = tl.load(m_base + 3 * stride_m_alpha, mask=k_mask, other=0.0) + m04 = tl.load(m_base + 4 * stride_m_alpha, mask=k_mask, other=0.0) + m05 = tl.load(m_base + 5 * stride_m_alpha, mask=k_mask, other=0.0) + m10 = tl.load(m_base + 6 * stride_m_alpha, mask=k_mask, other=0.0) + m11 = tl.load(m_base + 7 * stride_m_alpha, mask=k_mask, other=0.0) + m12 = tl.load(m_base + 8 * stride_m_alpha, mask=k_mask, other=0.0) + m13 = tl.load(m_base + 9 * stride_m_alpha, mask=k_mask, other=0.0) + m14 = tl.load(m_base + 10 * stride_m_alpha, mask=k_mask, other=0.0) + m15 = tl.load(m_base + 11 * stride_m_alpha, mask=k_mask, other=0.0) + m20 = tl.load(m_base + 12 * stride_m_alpha, mask=k_mask, other=0.0) + m21 = tl.load(m_base + 13 * stride_m_alpha, mask=k_mask, other=0.0) + m22 = tl.load(m_base + 14 * stride_m_alpha, mask=k_mask, other=0.0) + m23 = tl.load(m_base + 15 * stride_m_alpha, mask=k_mask, other=0.0) + m24 = tl.load(m_base + 16 * stride_m_alpha, mask=k_mask, other=0.0) + m25 = tl.load(m_base + 17 * stride_m_alpha, mask=k_mask, other=0.0) + m30 = tl.load(m_base + 18 * stride_m_alpha, mask=k_mask, other=0.0) + m31 = tl.load(m_base + 19 * stride_m_alpha, mask=k_mask, other=0.0) + m32 = tl.load(m_base + 20 * stride_m_alpha, mask=k_mask, other=0.0) + m33 = tl.load(m_base + 21 * stride_m_alpha, mask=k_mask, other=0.0) + m34 = tl.load(m_base + 22 * stride_m_alpha, mask=k_mask, other=0.0) + m35 = tl.load(m_base + 23 * stride_m_alpha, mask=k_mask, other=0.0) + m40 = tl.load(m_base + 24 * stride_m_alpha, mask=k_mask, other=0.0) + m41 = tl.load(m_base + 25 * stride_m_alpha, mask=k_mask, other=0.0) + m42 = tl.load(m_base + 26 * stride_m_alpha, mask=k_mask, other=0.0) + m43 = tl.load(m_base + 27 * stride_m_alpha, mask=k_mask, other=0.0) + m44 = tl.load(m_base + 28 * stride_m_alpha, mask=k_mask, other=0.0) + m45 = tl.load(m_base + 29 * stride_m_alpha, mask=k_mask, other=0.0) + m50 = tl.load(m_base + 30 * stride_m_alpha, mask=k_mask, other=0.0) + m51 = tl.load(m_base + 31 * stride_m_alpha, mask=k_mask, other=0.0) + m52 = tl.load(m_base + 32 * 
stride_m_alpha, mask=k_mask, other=0.0) + m53 = tl.load(m_base + 33 * stride_m_alpha, mask=k_mask, other=0.0) + m54 = tl.load(m_base + 34 * stride_m_alpha, mask=k_mask, other=0.0) + m55 = tl.load(m_base + 35 * stride_m_alpha, mask=k_mask, other=0.0) + + # A^T column transform (4x6): + # A^T = [[ 1, 1, 1, 1, 1, 0], + # [ 0, 1, -1, 2, -2, 0], + # [ 0, 1, 1, 4, 4, 0], + # [ 0, 1, -1, 8, -8, 1]] + # Apply to each column s: s_col = A^T @ m_col + + # Column 0 + s00 = m00 + m10 + m20 + m30 + m40 + s10 = m10 - m20 + 2 * m30 - 2 * m40 + s20 = m10 + m20 + 4 * m30 + 4 * m40 + s30 = m10 - m20 + 8 * m30 - 8 * m40 + m50 + # Column 1 + s01 = m01 + m11 + m21 + m31 + m41 + s11 = m11 - m21 + 2 * m31 - 2 * m41 + s21 = m11 + m21 + 4 * m31 + 4 * m41 + s31 = m11 - m21 + 8 * m31 - 8 * m41 + m51 + # Column 2 + s02 = m02 + m12 + m22 + m32 + m42 + s12 = m12 - m22 + 2 * m32 - 2 * m42 + s22 = m12 + m22 + 4 * m32 + 4 * m42 + s32 = m12 - m22 + 8 * m32 - 8 * m42 + m52 + # Column 3 + s03 = m03 + m13 + m23 + m33 + m43 + s13 = m13 - m23 + 2 * m33 - 2 * m43 + s23 = m13 + m23 + 4 * m33 + 4 * m43 + s33 = m13 - m23 + 8 * m33 - 8 * m43 + m53 + # Column 4 + s04 = m04 + m14 + m24 + m34 + m44 + s14 = m14 - m24 + 2 * m34 - 2 * m44 + s24 = m14 + m24 + 4 * m34 + 4 * m44 + s34 = m14 - m24 + 8 * m34 - 8 * m44 + m54 + # Column 5 + s05 = m05 + m15 + m25 + m35 + m45 + s15 = m15 - m25 + 2 * m35 - 2 * m45 + s25 = m15 + m25 + 4 * m35 + 4 * m45 + s35 = m15 - m25 + 8 * m35 - 8 * m45 + m55 + + # A^T row transform + y00 = s00 + s01 + s02 + s03 + s04 + y01 = s01 - s02 + 2 * s03 - 2 * s04 + y02 = s01 + s02 + 4 * s03 + 4 * s04 + y03 = s01 - s02 + 8 * s03 - 8 * s04 + s05 + + y10 = s10 + s11 + s12 + s13 + s14 + y11 = s11 - s12 + 2 * s13 - 2 * s14 + y12 = s11 + s12 + 4 * s13 + 4 * s14 + y13 = s11 - s12 + 8 * s13 - 8 * s14 + s15 + + y20 = s20 + s21 + s22 + s23 + s24 + y21 = s21 - s22 + 2 * s23 - 2 * s24 + y22 = s21 + s22 + 4 * s23 + 4 * s24 + y23 = s21 - s22 + 8 * s23 - 8 * s24 + s25 + + y30 = s30 + s31 + s32 + s33 + s34 + y31 = s31 - s32 + 2 * s33 - 2 * s34 + y32 = s31 + s32 + 4 * s33 + 4 * s34 + y33 = s31 - s32 + 8 * s33 - 8 * s34 + s35 + + # Bias + if HAS_BIAS: + bias = tl.load(BIAS + offs_k, mask=k_mask, other=0.0) + y00 += bias + y01 += bias + y02 += bias + y03 += bias + y10 += bias + y11 += bias + y12 += bias + y13 += bias + y20 += bias + y21 += bias + y22 += bias + y23 += bias + y30 += bias + y31 += bias + y32 += bias + y33 += bias + + # Activation + if ACT_TYPE == 1: + y00 = tl.maximum(y00, 0) + y01 = tl.maximum(y01, 0) + y02 = tl.maximum(y02, 0) + y03 = tl.maximum(y03, 0) + y10 = tl.maximum(y10, 0) + y11 = tl.maximum(y11, 0) + y12 = tl.maximum(y12, 0) + y13 = tl.maximum(y13, 0) + y20 = tl.maximum(y20, 0) + y21 = tl.maximum(y21, 0) + y22 = tl.maximum(y22, 0) + y23 = tl.maximum(y23, 0) + y30 = tl.maximum(y30, 0) + y31 = tl.maximum(y31, 0) + y32 = tl.maximum(y32, 0) + y33 = tl.maximum(y33, 0) + elif ACT_TYPE == 2: + y00 = tl.minimum(tl.maximum(y00, 0), 6) + y01 = tl.minimum(tl.maximum(y01, 0), 6) + y02 = tl.minimum(tl.maximum(y02, 0), 6) + y03 = tl.minimum(tl.maximum(y03, 0), 6) + y10 = tl.minimum(tl.maximum(y10, 0), 6) + y11 = tl.minimum(tl.maximum(y11, 0), 6) + y12 = tl.minimum(tl.maximum(y12, 0), 6) + y13 = tl.minimum(tl.maximum(y13, 0), 6) + y20 = tl.minimum(tl.maximum(y20, 0), 6) + y21 = tl.minimum(tl.maximum(y21, 0), 6) + y22 = tl.minimum(tl.maximum(y22, 0), 6) + y23 = tl.minimum(tl.maximum(y23, 0), 6) + y30 = tl.minimum(tl.maximum(y30, 0), 6) + y31 = tl.minimum(tl.maximum(y31, 0), 6) + y32 = tl.minimum(tl.maximum(y32, 0), 6) + y33 = 
tl.minimum(tl.maximum(y33, 0), 6) + elif ACT_TYPE == 3: + y00 = ( + 0.5 * y00 * (1.0 + _tanh(0.7978845608 * (y00 + 0.044715 * y00 * y00 * y00))) + ) + y01 = ( + 0.5 * y01 * (1.0 + _tanh(0.7978845608 * (y01 + 0.044715 * y01 * y01 * y01))) + ) + y02 = ( + 0.5 * y02 * (1.0 + _tanh(0.7978845608 * (y02 + 0.044715 * y02 * y02 * y02))) + ) + y03 = ( + 0.5 * y03 * (1.0 + _tanh(0.7978845608 * (y03 + 0.044715 * y03 * y03 * y03))) + ) + y10 = ( + 0.5 * y10 * (1.0 + _tanh(0.7978845608 * (y10 + 0.044715 * y10 * y10 * y10))) + ) + y11 = ( + 0.5 * y11 * (1.0 + _tanh(0.7978845608 * (y11 + 0.044715 * y11 * y11 * y11))) + ) + y12 = ( + 0.5 * y12 * (1.0 + _tanh(0.7978845608 * (y12 + 0.044715 * y12 * y12 * y12))) + ) + y13 = ( + 0.5 * y13 * (1.0 + _tanh(0.7978845608 * (y13 + 0.044715 * y13 * y13 * y13))) + ) + y20 = ( + 0.5 * y20 * (1.0 + _tanh(0.7978845608 * (y20 + 0.044715 * y20 * y20 * y20))) + ) + y21 = ( + 0.5 * y21 * (1.0 + _tanh(0.7978845608 * (y21 + 0.044715 * y21 * y21 * y21))) + ) + y22 = ( + 0.5 * y22 * (1.0 + _tanh(0.7978845608 * (y22 + 0.044715 * y22 * y22 * y22))) + ) + y23 = ( + 0.5 * y23 * (1.0 + _tanh(0.7978845608 * (y23 + 0.044715 * y23 * y23 * y23))) + ) + y30 = ( + 0.5 * y30 * (1.0 + _tanh(0.7978845608 * (y30 + 0.044715 * y30 * y30 * y30))) + ) + y31 = ( + 0.5 * y31 * (1.0 + _tanh(0.7978845608 * (y31 + 0.044715 * y31 * y31 * y31))) + ) + y32 = ( + 0.5 * y32 * (1.0 + _tanh(0.7978845608 * (y32 + 0.044715 * y32 * y32 * y32))) + ) + y33 = ( + 0.5 * y33 * (1.0 + _tanh(0.7978845608 * (y33 + 0.044715 * y33 * y33 * y33))) + ) + + # Store 4x4 output tile + n_valid = n < N + y_base = Y + n * stride_y_n + offs_k * stride_y_k + + if n_valid: + for r in tl.static_range(4): + p = p_start + r + if p < P: + for s in tl.static_range(4): + q = q_start + s + if q < Q: + if r == 0: + if s == 0: + val = y00 + elif s == 1: + val = y01 + elif s == 2: + val = y02 + else: + val = y03 + elif r == 1: + if s == 0: + val = y10 + elif s == 1: + val = y11 + elif s == 2: + val = y12 + else: + val = y13 + elif r == 2: + if s == 0: + val = y20 + elif s == 1: + val = y21 + elif s == 2: + val = y22 + else: + val = y23 + else: + if s == 0: + val = y30 + elif s == 1: + val = y31 + elif s == 2: + val = y32 + else: + val = y33 + tl.store( + y_base + p * stride_y_p + q * stride_y_q, val, mask=k_mask + ) + + +@triton.autotune( + configs=AUTOTUNE_FUSED_F4X3_CONFIGS, + key=["T", "K_out", "C_pad"], + cache_results=True, +) +@triton.jit +def _winograd_f4x3_fused_gemm_output_kernel( + V, + U, + BIAS, + Y, + N: tl.constexpr, + K_out: tl.constexpr, + P: tl.constexpr, + Q: tl.constexpr, + C_pad: tl.constexpr, + tile_H: tl.constexpr, + tile_W: tl.constexpr, + T: tl.constexpr, + BLOCK_T: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_C: tl.constexpr, + HAS_BIAS: tl.constexpr, + ACT_TYPE: tl.constexpr, + LAYOUT: tl.constexpr = 0, +): + """Fused GEMM + output transform for Winograd F(4x4,3x3). 
+ Processes 6 alphas per column (6 columns = 36 total) sequentially, + accumulating column-transform results to reduce register pressure.""" + # V layout: [36, T, C_pad] contiguous + stride_v_c: tl.constexpr = 1 + stride_v_tile: tl.constexpr = C_pad + stride_v_alpha: tl.constexpr = T * C_pad + # U layout: [36, K_out, C_pad] contiguous + stride_u_c: tl.constexpr = 1 + stride_u_k: tl.constexpr = C_pad + stride_u_alpha: tl.constexpr = K_out * C_pad + # Y layout: LAYOUT=0 NCHW, LAYOUT=1 NHWC + if LAYOUT == 0: + stride_y_q: tl.constexpr = 1 + stride_y_p: tl.constexpr = Q + stride_y_k: tl.constexpr = P * Q + stride_y_n: tl.constexpr = K_out * P * Q + else: + stride_y_q: tl.constexpr = K_out + stride_y_p: tl.constexpr = Q * K_out + stride_y_k: tl.constexpr = 1 + stride_y_n: tl.constexpr = P * Q * K_out + + pid_t = tl.program_id(0) + pid_k = tl.program_id(1) + + offs_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) + offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K) + t_mask = offs_t < T + k_mask = offs_k < K_out + + offs_c = tl.arange(0, BLOCK_C) + + # Process column by column to reduce register pressure. + # For each column s (0..5), accumulate 6 alpha GEMMs (rows 0..5), + # then apply A^T column transform (6→4), storing s[row][col] results. + # A^T = [[1,1,1,1,1,0],[0,1,-1,2,-2,0],[0,1,1,4,4,0],[0,1,-1,8,-8,1]] + # We need s[0..3][0..5] = 24 accumulators after column transform. + # Then row transform: s[row][0..5] → y[row][0..3] = 16 output values. + + # Accumulate all 24 s-values (4 rows × 6 columns) + s00 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32) + s01 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32) + s02 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32) + s03 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32) + s04 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32) + s05 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32) + s10 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32) + s11 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32) + s12 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32) + s13 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32) + s14 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32) + s15 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32) + s20 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32) + s21 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32) + s22 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32) + s23 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32) + s24 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32) + s25 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32) + s30 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32) + s31 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32) + s32 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32) + s33 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32) + s34 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32) + s35 = tl.zeros((BLOCK_T, BLOCK_K), dtype=tl.float32) + + v_mask_2d = t_mask[:, None] & (offs_c < C_pad)[None, :] + u_mask_2d = k_mask[:, None] & (offs_c < C_pad)[None, :] + + for c0 in range(0, C_pad, BLOCK_C): + c_offs = c0 + offs_c + if c0 > 0: + v_mask_2d = t_mask[:, None] & (c_offs < C_pad)[None, :] + u_mask_2d = k_mask[:, None] & (c_offs < C_pad)[None, :] + + # Process all 36 alphas, applying column transform on-the-fly + for col in tl.static_range(6): + # Load 6 V tiles (rows 0-5 for this column) and 6 U tiles + # alpha = row * 6 + col + v_base_c = ( + V + offs_t[:, None] * stride_v_tile + c_offs[None, :] * stride_v_c + ) + u_base_c = U + offs_k[:, None] * stride_u_k + c_offs[None, :] * stride_u_c + + # Load V and U for 6 rows, compute GEMM, apply column 
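The reason the column transform can be applied per C-block and summed into the 24 accumulators set up above is linearity: A^T distributes over the partial GEMM sums, so transforming each channel block's contribution and accumulating equals transforming the fully reduced result once. A short NumPy check of that commutation (illustrative only):

```python
# Linearity check: applying A^T per channel-block and accumulating is
# identical to applying it once to the fully reduced GEMM result.
import numpy as np

AT = np.array([[1, 1, 1, 1, 1, 0],
               [0, 1, -1, 2, -2, 0],
               [0, 1, 1, 4, 4, 0],
               [0, 1, -1, 8, -8, 1]], dtype=np.float64)
m_blocks = [np.random.randn(6) for _ in range(4)]  # partial sums per C-block
assert np.allclose(AT @ sum(m_blocks), sum(AT @ m for m in m_blocks))
```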
transform + v0 = tl.load( + v_base_c + (0 * 6 + col) * stride_v_alpha, mask=v_mask_2d, other=0.0 + ) + u0 = tl.load( + u_base_c + (0 * 6 + col) * stride_u_alpha, mask=u_mask_2d, other=0.0 + ) + m0 = tl.dot(v0, tl.trans(u0), out_dtype=tl.float32) + + v1 = tl.load( + v_base_c + (1 * 6 + col) * stride_v_alpha, mask=v_mask_2d, other=0.0 + ) + u1 = tl.load( + u_base_c + (1 * 6 + col) * stride_u_alpha, mask=u_mask_2d, other=0.0 + ) + m1 = tl.dot(v1, tl.trans(u1), out_dtype=tl.float32) + + v2 = tl.load( + v_base_c + (2 * 6 + col) * stride_v_alpha, mask=v_mask_2d, other=0.0 + ) + u2 = tl.load( + u_base_c + (2 * 6 + col) * stride_u_alpha, mask=u_mask_2d, other=0.0 + ) + m2 = tl.dot(v2, tl.trans(u2), out_dtype=tl.float32) + + v3 = tl.load( + v_base_c + (3 * 6 + col) * stride_v_alpha, mask=v_mask_2d, other=0.0 + ) + u3 = tl.load( + u_base_c + (3 * 6 + col) * stride_u_alpha, mask=u_mask_2d, other=0.0 + ) + m3 = tl.dot(v3, tl.trans(u3), out_dtype=tl.float32) + + v4 = tl.load( + v_base_c + (4 * 6 + col) * stride_v_alpha, mask=v_mask_2d, other=0.0 + ) + u4 = tl.load( + u_base_c + (4 * 6 + col) * stride_u_alpha, mask=u_mask_2d, other=0.0 + ) + m4 = tl.dot(v4, tl.trans(u4), out_dtype=tl.float32) + + v5 = tl.load( + v_base_c + (5 * 6 + col) * stride_v_alpha, mask=v_mask_2d, other=0.0 + ) + u5 = tl.load( + u_base_c + (5 * 6 + col) * stride_u_alpha, mask=u_mask_2d, other=0.0 + ) + m5 = tl.dot(v5, tl.trans(u5), out_dtype=tl.float32) + + # Column transform: A^T @ [m0..m5] → [s0, s1, s2, s3] + sc0 = m0 + m1 + m2 + m3 + m4 + sc1 = m1 - m2 + 2 * m3 - 2 * m4 + sc2 = m1 + m2 + 4 * m3 + 4 * m4 + sc3 = m1 - m2 + 8 * m3 - 8 * m4 + m5 + + # Accumulate into the appropriate column + if col == 0: + s00 += sc0 + s10 += sc1 + s20 += sc2 + s30 += sc3 + elif col == 1: + s01 += sc0 + s11 += sc1 + s21 += sc2 + s31 += sc3 + elif col == 2: + s02 += sc0 + s12 += sc1 + s22 += sc2 + s32 += sc3 + elif col == 3: + s03 += sc0 + s13 += sc1 + s23 += sc2 + s33 += sc3 + elif col == 4: + s04 += sc0 + s14 += sc1 + s24 += sc2 + s34 += sc3 + elif col == 5: + s05 += sc0 + s15 += sc1 + s25 += sc2 + s35 += sc3 + + # Row transform: A^T applied to rows + y00 = s00 + s01 + s02 + s03 + s04 + y01 = s01 - s02 + 2 * s03 - 2 * s04 + y02 = s01 + s02 + 4 * s03 + 4 * s04 + y03 = s01 - s02 + 8 * s03 - 8 * s04 + s05 + + y10 = s10 + s11 + s12 + s13 + s14 + y11 = s11 - s12 + 2 * s13 - 2 * s14 + y12 = s11 + s12 + 4 * s13 + 4 * s14 + y13 = s11 - s12 + 8 * s13 - 8 * s14 + s15 + + y20 = s20 + s21 + s22 + s23 + s24 + y21 = s21 - s22 + 2 * s23 - 2 * s24 + y22 = s21 + s22 + 4 * s23 + 4 * s24 + y23 = s21 - s22 + 8 * s23 - 8 * s24 + s25 + + y30 = s30 + s31 + s32 + s33 + s34 + y31 = s31 - s32 + 2 * s33 - 2 * s34 + y32 = s31 + s32 + 4 * s33 + 4 * s34 + y33 = s31 - s32 + 8 * s33 - 8 * s34 + s35 + + # Bias + if HAS_BIAS: + bias = tl.load(BIAS + offs_k, mask=k_mask, other=0.0) + b = bias[None, :] + y00 += b + y01 += b + y02 += b + y03 += b + y10 += b + y11 += b + y12 += b + y13 += b + y20 += b + y21 += b + y22 += b + y23 += b + y30 += b + y31 += b + y32 += b + y33 += b + + # Activation + if ACT_TYPE == 1: + y00 = tl.maximum(y00, 0) + y01 = tl.maximum(y01, 0) + y02 = tl.maximum(y02, 0) + y03 = tl.maximum(y03, 0) + y10 = tl.maximum(y10, 0) + y11 = tl.maximum(y11, 0) + y12 = tl.maximum(y12, 0) + y13 = tl.maximum(y13, 0) + y20 = tl.maximum(y20, 0) + y21 = tl.maximum(y21, 0) + y22 = tl.maximum(y22, 0) + y23 = tl.maximum(y23, 0) + y30 = tl.maximum(y30, 0) + y31 = tl.maximum(y31, 0) + y32 = tl.maximum(y32, 0) + y33 = tl.maximum(y33, 0) + elif ACT_TYPE == 2: + y00 = 
tl.minimum(tl.maximum(y00, 0), 6) + y01 = tl.minimum(tl.maximum(y01, 0), 6) + y02 = tl.minimum(tl.maximum(y02, 0), 6) + y03 = tl.minimum(tl.maximum(y03, 0), 6) + y10 = tl.minimum(tl.maximum(y10, 0), 6) + y11 = tl.minimum(tl.maximum(y11, 0), 6) + y12 = tl.minimum(tl.maximum(y12, 0), 6) + y13 = tl.minimum(tl.maximum(y13, 0), 6) + y20 = tl.minimum(tl.maximum(y20, 0), 6) + y21 = tl.minimum(tl.maximum(y21, 0), 6) + y22 = tl.minimum(tl.maximum(y22, 0), 6) + y23 = tl.minimum(tl.maximum(y23, 0), 6) + y30 = tl.minimum(tl.maximum(y30, 0), 6) + y31 = tl.minimum(tl.maximum(y31, 0), 6) + y32 = tl.minimum(tl.maximum(y32, 0), 6) + y33 = tl.minimum(tl.maximum(y33, 0), 6) + elif ACT_TYPE == 3: + y00 = ( + 0.5 * y00 * (1.0 + _tanh(0.7978845608 * (y00 + 0.044715 * y00 * y00 * y00))) + ) + y01 = ( + 0.5 * y01 * (1.0 + _tanh(0.7978845608 * (y01 + 0.044715 * y01 * y01 * y01))) + ) + y02 = ( + 0.5 * y02 * (1.0 + _tanh(0.7978845608 * (y02 + 0.044715 * y02 * y02 * y02))) + ) + y03 = ( + 0.5 * y03 * (1.0 + _tanh(0.7978845608 * (y03 + 0.044715 * y03 * y03 * y03))) + ) + y10 = ( + 0.5 * y10 * (1.0 + _tanh(0.7978845608 * (y10 + 0.044715 * y10 * y10 * y10))) + ) + y11 = ( + 0.5 * y11 * (1.0 + _tanh(0.7978845608 * (y11 + 0.044715 * y11 * y11 * y11))) + ) + y12 = ( + 0.5 * y12 * (1.0 + _tanh(0.7978845608 * (y12 + 0.044715 * y12 * y12 * y12))) + ) + y13 = ( + 0.5 * y13 * (1.0 + _tanh(0.7978845608 * (y13 + 0.044715 * y13 * y13 * y13))) + ) + y20 = ( + 0.5 * y20 * (1.0 + _tanh(0.7978845608 * (y20 + 0.044715 * y20 * y20 * y20))) + ) + y21 = ( + 0.5 * y21 * (1.0 + _tanh(0.7978845608 * (y21 + 0.044715 * y21 * y21 * y21))) + ) + y22 = ( + 0.5 * y22 * (1.0 + _tanh(0.7978845608 * (y22 + 0.044715 * y22 * y22 * y22))) + ) + y23 = ( + 0.5 * y23 * (1.0 + _tanh(0.7978845608 * (y23 + 0.044715 * y23 * y23 * y23))) + ) + y30 = ( + 0.5 * y30 * (1.0 + _tanh(0.7978845608 * (y30 + 0.044715 * y30 * y30 * y30))) + ) + y31 = ( + 0.5 * y31 * (1.0 + _tanh(0.7978845608 * (y31 + 0.044715 * y31 * y31 * y31))) + ) + y32 = ( + 0.5 * y32 * (1.0 + _tanh(0.7978845608 * (y32 + 0.044715 * y32 * y32 * y32))) + ) + y33 = ( + 0.5 * y33 * (1.0 + _tanh(0.7978845608 * (y33 + 0.044715 * y33 * y33 * y33))) + ) + + # Decode tile indices and compute per-tile base pointers + n_idx = offs_t // (tile_H * tile_W) + rem = offs_t % (tile_H * tile_W) + th = rem // tile_W + tw = rem % tile_W + p_start = th * 4 + q_start = tw * 4 + + y_base = Y + n_idx * stride_y_n + y_ptrs_base = y_base[:, None] + offs_k[None, :] * stride_y_k + full_mask = t_mask[:, None] & k_mask[None, :] + + # Store 4x4 output tile using 2D scatter stores + for di in tl.static_range(4): + for dj in tl.static_range(4): + mask_ij = ( + full_mask & (p_start + di < P)[:, None] & (q_start + dj < Q)[:, None] + ) + ptrs_ij = ( + y_ptrs_base + + (p_start[:, None] + di) * stride_y_p + + (q_start[:, None] + dj) * stride_y_q + ) + if di == 0: + if dj == 0: + tl.store(ptrs_ij, y00, mask=mask_ij) + elif dj == 1: + tl.store(ptrs_ij, y01, mask=mask_ij) + elif dj == 2: + tl.store(ptrs_ij, y02, mask=mask_ij) + else: + tl.store(ptrs_ij, y03, mask=mask_ij) + elif di == 1: + if dj == 0: + tl.store(ptrs_ij, y10, mask=mask_ij) + elif dj == 1: + tl.store(ptrs_ij, y11, mask=mask_ij) + elif dj == 2: + tl.store(ptrs_ij, y12, mask=mask_ij) + else: + tl.store(ptrs_ij, y13, mask=mask_ij) + elif di == 2: + if dj == 0: + tl.store(ptrs_ij, y20, mask=mask_ij) + elif dj == 1: + tl.store(ptrs_ij, y21, mask=mask_ij) + elif dj == 2: + tl.store(ptrs_ij, y22, mask=mask_ij) + else: + tl.store(ptrs_ij, y23, mask=mask_ij) + else: + if dj == 0: + 
tl.store(ptrs_ij, y30, mask=mask_ij) + elif dj == 1: + tl.store(ptrs_ij, y31, mask=mask_ij) + elif dj == 2: + tl.store(ptrs_ij, y32, mask=mask_ij) + else: + tl.store(ptrs_ij, y33, mask=mask_ij) diff --git a/aiter/ops/triton/_triton_kernels/conv/conv_general.py b/aiter/ops/triton/_triton_kernels/conv/conv_general.py new file mode 100644 index 0000000000..9c5ade7fd6 --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/conv/conv_general.py @@ -0,0 +1,163 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +from __future__ import annotations +import triton +import triton.language as tl +from .helpers import _tanh, AUTOTUNE_GENERAL_CONFIGS + + +@triton.autotune( + configs=AUTOTUNE_GENERAL_CONFIGS, + key=["M_total", "K_out", "K_pad"], + reset_to_zero=["Y"], + warmup=50, + rep=200, + cache_results=True, +) +@triton.jit +def _conv2d_general_kernel( + X, + WK, + BIAS, + Y, + N: tl.constexpr, + C: tl.constexpr, + H: tl.constexpr, + W_in: tl.constexpr, + K_out: tl.constexpr, + R: tl.constexpr, + S: tl.constexpr, + P: tl.constexpr, + Q: tl.constexpr, + K_pad: tl.constexpr, + stride_h: tl.constexpr, + stride_w: tl.constexpr, + pad_h: tl.constexpr, + pad_w: tl.constexpr, + dil_h: tl.constexpr, + dil_w: tl.constexpr, + M_total: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + HAS_BIAS: tl.constexpr, + ACT_TYPE: tl.constexpr, + LAYOUT: tl.constexpr, +): + """General conv kernel with precomputed bases. + LAYOUT: 0=NCHW, 1=NHWC + """ + # WK is always [K_out, K_pad] contiguous + stride_wk_kout: tl.constexpr = K_pad + stride_wk_kred: tl.constexpr = 1 + if LAYOUT == 0: + # NCHW: X[N, C, H, W_in], Y[N, K_out, P, Q] + stride_x_n: tl.constexpr = C * H * W_in + stride_x_c: tl.constexpr = H * W_in + stride_x_h: tl.constexpr = W_in + stride_x_w: tl.constexpr = 1 + stride_y_n: tl.constexpr = K_out * P * Q + stride_y_k: tl.constexpr = P * Q + stride_y_p: tl.constexpr = Q + stride_y_q: tl.constexpr = 1 + else: + # NHWC: X[N, H, W_in, C], Y[N, P, Q, K_out] + stride_x_n: tl.constexpr = H * W_in * C + stride_x_c: tl.constexpr = 1 + stride_x_h: tl.constexpr = W_in * C + stride_x_w: tl.constexpr = C + stride_y_n: tl.constexpr = P * Q * K_out + stride_y_k: tl.constexpr = 1 + stride_y_p: tl.constexpr = Q * K_out + stride_y_q: tl.constexpr = K_out + + pid = tl.program_id(axis=0) + + num_pid_m = tl.cdiv(M_total, BLOCK_M) + num_pid_n = tl.cdiv(K_out, BLOCK_N) + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + if pid_m >= num_pid_m: + return + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + kout_mask = offs_n < K_out + + # Decode offs_m -> (n_idx, p_idx, q_idx) + n_idx = offs_m[:, None] // (P * Q) + pq = offs_m[:, None] % (P * Q) + p_idx = pq // Q + q_idx = pq % Q + + n_valid = n_idx < N + + # Precompute base positions + base_oh = p_idx * stride_h - pad_h + base_ow = q_idx * stride_w - pad_w + stride_dh = dil_h * stride_x_h + stride_dw = dil_w * stride_x_w + x_base = X + n_idx * stride_x_n + base_oh * stride_x_h + base_ow * stride_x_w + wk_base = WK + offs_n[None, :] * stride_wk_kout + + Y_ptrs = ( + Y + + n_idx * stride_y_n + + offs_n[None, :] * stride_y_k + + p_idx * stride_y_p + + q_idx 
* stride_y_q + ) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + offs_k = tl.arange(0, BLOCK_K) + rs_stride = R * S + + for k0 in range(0, K_pad, BLOCK_K): + kred = k0 + offs_k + + WK_ptrs = wk_base + kred[:, None] * stride_wk_kred + w_tile = tl.load(WK_ptrs, mask=kout_mask[None, :], other=0.0) + + c = kred // rs_stride + rs = kred % rs_stride + r = rs // S + s = rs % S + + oh = base_oh + r * dil_h + ow = base_ow + s * dil_w + + X_ptrs = x_base + c * stride_x_c + r * stride_dh + s * stride_dw + x_mask = ( + n_valid & (oh >= 0) & (ow >= 0) & (oh < H) & (ow < W_in) & (c[None, :] < C) + ) + x_tile = tl.load(X_ptrs, mask=x_mask, other=0.0) + + acc += tl.dot(x_tile, w_tile, out_dtype=tl.float32) + + if HAS_BIAS: + b = tl.load(BIAS + offs_n, mask=offs_n < K_out, other=0.0) + acc += b[None, :] + + if ACT_TYPE == 1: + acc = tl.maximum(acc, 0) + elif ACT_TYPE == 2: + acc = tl.minimum(tl.maximum(acc, 0), 6) + elif ACT_TYPE == 3: + acc = ( + 0.5 * acc * (1.0 + _tanh(0.7978845608 * (acc + 0.044715 * acc * acc * acc))) + ) + + tl.store( + Y_ptrs, + acc, + mask=(n_valid & (p_idx < P) & (q_idx < Q) & kout_mask[None, :]), + ) diff --git a/aiter/ops/triton/_triton_kernels/conv/helpers.py b/aiter/ops/triton/_triton_kernels/conv/helpers.py new file mode 100644 index 0000000000..e80f81e052 --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/conv/helpers.py @@ -0,0 +1,188 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +from __future__ import annotations +import triton +import triton.language as tl + + +@triton.jit +def _tanh(x): + x = tl.minimum(tl.maximum(x, -10.0), 10.0) + e2x = tl.exp(2 * x) + return (e2x - 1) / (e2x + 1) + + +# ======================================================================== +# AUTOTUNE CONFIGS — centralized to avoid duplication and cross-imports +# ======================================================================== + +# -- 1x1 kernel -- +AUTOTUNE_1x1_CONFIGS = [ + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 64, "GROUP_SIZE_M": 8}, + num_warps=8, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=8, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_SIZE_M": 4}, + num_warps=4, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_SIZE_M": 8}, + num_warps=8, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP_SIZE_M": 4}, + num_warps=4, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64, "GROUP_SIZE_M": 8}, + num_warps=8, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_SIZE_M": 8}, + num_warps=4, + num_stages=1, + ), +] + +# -- General conv kernel -- +AUTOTUNE_GENERAL_CONFIGS = [ + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_SIZE_M": 8}, + num_warps=8, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP_SIZE_M": 8}, + num_warps=8, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_SIZE_M": 4}, + num_warps=8, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP_SIZE_M": 4}, + num_warps=4, + num_stages=1, + ), +] + +# -- 3x3 NHWC kernel -- +AUTOTUNE_3x3_NHWC_CONFIGS = [ + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_C": 64, "GROUP_SIZE_M": 8}, + 
num_warps=8, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_C": 64, "GROUP_SIZE_M": 8}, + num_warps=4, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_C": 64, "GROUP_SIZE_M": 8}, + num_warps=4, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_C": 64, "GROUP_SIZE_M": 4}, + num_warps=4, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_C": 32, "GROUP_SIZE_M": 4}, + num_warps=4, + num_stages=1, + ), +] + +# -- 3x3 channel-blocked kernel -- +AUTOTUNE_3x3_CBLOCKED_CONFIGS = [ + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_C": 64, "GROUP_SIZE_M": 8}, + num_warps=4, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_C": 64, "GROUP_SIZE_M": 4}, + num_warps=4, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_C": 128, "GROUP_SIZE_M": 4}, + num_warps=4, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_C": 64, "GROUP_SIZE_M": 8}, + num_warps=4, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_C": 64, "GROUP_SIZE_M": 8}, + num_warps=8, + num_stages=1, + ), +] + +# -- Winograd F(4,3) GEMM -- +AUTOTUNE_WINO_GEMM_CONFIGS = [ + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_SIZE_M": 8}, + num_warps=8, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP_SIZE_M": 8}, + num_warps=4, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_SIZE_M": 8}, + num_warps=4, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP_SIZE_M": 4}, + num_warps=4, + num_stages=1, + ), +] + +# -- Winograd F(4,3) input transform -- +AUTOTUNE_WINO4_INPUT_CONFIGS = [ + triton.Config({"BLOCK_C": 64}, num_warps=4, num_stages=1), + triton.Config({"BLOCK_C": 32}, num_warps=4, num_stages=1), +] + +# -- Winograd F(4,3) output transform -- +AUTOTUNE_WINO4_OUTPUT_CONFIGS = [ + triton.Config({"BLOCK_K": 64}, num_warps=4, num_stages=1), + triton.Config({"BLOCK_K": 128}, num_warps=4, num_stages=1), +] + +# -- Winograd F(4,3) fused GEMM+output -- +AUTOTUNE_FUSED_F4X3_CONFIGS = [ + triton.Config( + {"BLOCK_T": 16, "BLOCK_K": 64, "BLOCK_C": 64}, num_warps=4, num_stages=1 + ), + triton.Config( + {"BLOCK_T": 16, "BLOCK_K": 128, "BLOCK_C": 64}, num_warps=8, num_stages=1 + ), + triton.Config( + {"BLOCK_T": 32, "BLOCK_K": 64, "BLOCK_C": 64}, num_warps=4, num_stages=1 + ), +] diff --git a/aiter/ops/triton/conv/DESIGN.md b/aiter/ops/triton/conv/DESIGN.md new file mode 100644 index 0000000000..cdab45ede4 --- /dev/null +++ b/aiter/ops/triton/conv/DESIGN.md @@ -0,0 +1,707 @@ +# Design + +This document explains how `aiter.ops.triton.conv` is structured, what each kernel does +and why, the math behind the Winograd path, the heuristic that picks between +methods, the memory layouts and repacking that connect them, the numerical +tolerance model used by the test suite, and how to add a new kernel. + +> **Audience.** Anyone planning to extend the library, debug a performance +> regression, or evaluate whether to take a dependency on it. + +--- + +## 1. Goals and non-goals + +**In scope.** + +- Forward 2-D convolution on AMD ROCm/RDNA (WMMA-capable). +- `fp16` and `bf16` inputs; configurable output dtype. +- NCHW and NHWC input layouts. +- Optional fused bias add and activation (`relu`, `relu6`, `gelu`). +- Drop-in replacement for `nn.Conv2d` for inference. 
+ +**Out of scope (today).** + +- **Backward / training.** Inference only. No grad support. +- **`groups > 1`** (depthwise / grouped). Detected at the example layer and + left to PyTorch/MIOpen. +- **`padding_mode != "zeros"`** (reflect / replicate / circular). Same fall-back. +- **fp32 / fp8 inputs not supported.** +- Tuning, autotune configs, and the `_select_3x3_method` + routing table are RDNA4-specific for now. + +--- + +## 2. Architecture at a glance
+ +```mermaid +flowchart TD + A([conv2d - user entry]):::entry + + A --> NCHW[layout = nchw<br/>conv2d_nchw]:::router + A --> NHWC[layout = nhwc<br/>conv2d_nhwc]:::router + + NCHW --> N1[1x1]:::shape + NCHW --> N3[3x3]:::shape + NCHW --> NG[other]:::shape + + NHWC1[1x1]:::shape + NHWC3[3x3]:::shape + NHWCG[other]:::shape + NHWC --> NHWC1 + NHWC --> NHWC3 + NHWC --> NHWCG + + N1 --> K1[/_conv2d_1x1/]:::kernel + N3 --> S1{_select_3x3_method}:::sel + NG --> KG[/_conv2d_general/]:::kernel + + NHWC1 --> K1n[/_conv2d_1x1<br/>NHWC/]:::kernel + NHWC3 --> S2{_select_3x3_method}:::sel + NHWCG --> KGn[/_conv2d_general/]:::kernel + + S1 --> P1[repack<br/>NCHW → NCHWc]:::pack + S1 --> KWF[/winograd_f4x3/]:::wino + S1 --> P3[repack<br/>NCHW → NCHWc]:::pack + + P1 --> KCB[/cblocked/]:::kernel + P3 --> KWFC[/winograd_f4x3_cblocked/]:::wino + + S2 --> KNH[/nhwc_3x3/]:::kernel + S2 --> KWFN[/winograd_f4x3<br/>NHWC/]:::wino + + classDef entry fill:#1e3a5f,stroke:#4a90e2,stroke-width:2px,color:#fff + classDef router fill:#2c5282,stroke:#63b3ed,color:#fff + classDef shape fill:#4a5568,stroke:#a0aec0,color:#fff + classDef sel fill:#744210,stroke:#f6ad55,color:#fff + classDef pack fill:#7b341e,stroke:#fc8181,color:#fff + classDef kernel fill:#22543d,stroke:#68d391,color:#fff + classDef wino fill:#553c9a,stroke:#b794f4,color:#fff +```
+ +The Winograd kernels (`winograd_f4x3*`) themselves are a 3-stage pipeline: + +```mermaid +flowchart LR + X([X<br/>NCHW or NHWC]):::data + X --> IT[input transform<br/>V = BᵀXB]:::stage + IT --> V([V : 36 × T × C_pad]):::data + + V --> GE[batched GEMM<br/>36 of T × K_out<br/>reducing over C_pad]:::stage + GE --> M([M : 36 × T × K_out]):::data + M --> OT[output transform<br/>Y = AᵀMA]:::stage + OT --> Y([Y<br/>NCHW or NHWC]):::data + + V -.fused path.-> FU[fused GEMM + output<br/>V → Y, skips M]:::stage + FU -.-> Y + + classDef data fill:#1e3a5f,stroke:#4a90e2,color:#fff + classDef stage fill:#553c9a,stroke:#b794f4,color:#fff +```
+ +The fused kernel (`_winograd_f4x3_fused_gemm_output_kernel`) reads `V` and the +transformed filter `U` directly, accumulates the column transform's 24 partial +outputs in registers across the channel-tile loop, applies the final row +transform once at the end, and writes `Y` without ever materializing the +intermediate `M[36, T, K_out]` tensor. + +**The fused path is opt-in, not router-selected.** `_select_3x3_method` only +ever returns `winograd_f4x3` or `winograd_f4x3_cblocked` — both the standard +3-kernel pipeline. The fused variant is reachable explicitly via +`--method winograd_f4x3_fused` in the bench tool (`op_tests/op_benchmarks/triton/bench_conv2d.py`) +or by importing `conv2d_winograd_f4x3_fused` directly. It wins on a narrow +band of small/medium VAE shapes but hits register pressure at the larger +ResNet 3×3 tile shapes, so it isn't worth a special-case branch in the +production router.
+ +### Layered Python modules + +``` +aiter/ops/triton/conv/ + __init__.py empty marker (consumers import directly from conv2d.py) + conv2d.py public functions + smart routing in conv2d_nchw / conv2d_nhwc + _launch.py grid setup, _select_3x3_method, dtype mapping + _prepack.py weight/input repacks + LRU caches + _utils.py shape math, _is_*_conv predicates, tolerance model, activation + +aiter/ops/triton/_triton_kernels/conv/ + __init__.py empty marker (kernels are imported by full path) + helpers.py shared autotune config lists (consumed by @triton.autotune + in each kernel file) + _tanh helper + conv_1x1.py 1×1 GEMM kernel (NCHW + NHWC via LAYOUT constexpr) + conv_3x3.py 3×3 NHWC kernel + 3×3 cblocked (NCHWc) kernel + conv_general.py K-major reduction with on-the-fly (c, r, s) decoding + conv_3x3_winograd_f4x3.py + 5 kernels: input transform (NCHW & NHWC), cblocked input + transform, batched GEMM, output transform, fused GEMM+output +``` + +The split mirrors *responsibility*, not LOC: `conv2d.py` is "what does the user +ask for", `_launch.py` is "how do we set up the grid", `_prepack.py` is "what +shape does the kernel actually want to read", and the `_triton_kernels/conv/` folder is +"the math". `_utils.py` sits underneath all four as a shared helper layer +(shape math, eligibility predicates, tolerance model, activation enum).
+ +--- + +## 3. Public API + +The headline is **`conv2d`**: + +```python +def conv2d(x, w_oihw, bias=None, stride=(1,1), padding=(0,0), dilation=(1,1), + activation="none", out_dtype=None, layout="nchw"): +``` + +| Argument | Meaning | +|---|---| +| `x` | Input. NCHW shape, optionally with `channels_last` strides for NHWC mode. | +| `w_oihw` | Weight in PyTorch's canonical `[K_out, C, R, S]` layout. | +| `bias` | Optional 1-D bias of length `K_out`, cast to fp32 once at entry. | +| `stride`, `padding`, `dilation` | Standard `Conv2d` semantics; tuples of ints. | +| `activation` | `"none" / "relu" / "relu6" / "gelu"` — fused into the kernel epilogue. | +| `out_dtype` | `None` (default — match input dtype, mirrors `nn.Conv2d`) or one of `torch.float16` / `torch.bfloat16` to override. The fp32 accumulator is downcast at store. | +| `layout` | `"nchw"` or `"nhwc"` (case-insensitive — passed through `.lower()`). Selects which top-level routing function runs.
| + +The semi-public method-specific functions (`conv2d_nchw_cblocked`, +`conv2d_winograd_f4x3_cblocked`, …) take an internal `block_k=64` channel-pack +tile size used by the prepack caches. It is intentionally not surfaced on the +public `conv2d` — every autotune config in `helpers.py` assumes 64, so the +parameter has no good user story. + +Both `x.dtype` and (when explicitly passed) `out_dtype` are validated at +entry: anything other than `torch.float16` / `torch.bfloat16` raises +`ValueError`. `w_oihw.dtype` is trusted to the kernel. + +Everything else (`conv2d_nchw`, `conv2d_1x1`, `conv2d_winograd_f4x3_cblocked`, +…) is **semi-public**: not in `__all__` but importable by name. The benchmark +harness uses these directly to compare methods on the same shape. + +`conv2d.py` also exposes `_last_triton_kernel: Optional[str]` — set by +`conv2d_nchw` / `conv2d_nhwc` after each routing decision, read by the bench +to label rows in the per-layer table. + +--- + +## 4. Smart routing (`_select_3x3_method`) + +Only 3×3 has a real choice — 1×1 always uses the 1×1 kernel, 5×5/7×7 always +fall to `general`. For 3×3, `_launch.py:_select_3x3_method` decides: + +```python +def _select_3x3_method(N, C, H, W, K_out, stride, dilation): + # 1) Non-Winograd-eligible (stride>1, dilation>1, or C<4) -> cblocked + if not _is_winograd_eligible(3, 3, stride, dilation, C): + return "cblocked" + # 2) Tile count: F(4,3) emits one 4x4 output tile from each 6x6 input patch, + # overlap = 2 (input tiles step by 4 with 6-wide windows). + P, Q = _out_hw(H, W, 3, 3, stride, (1,1), dilation) + tile_H, tile_W = (P + 3) // 4, (Q + 3) // 4 + T = N * tile_H * tile_W + # 3) Winograd only wins when both C and K are large enough + # AND there are enough tiles to amortize transform overhead. + if C >= 512 and K_out >= 512 and T >= 98: + return "winograd_f4x3_cblocked" if T >= 392 else "winograd_f4x3" + return "cblocked" +``` + +The thresholds (`C/K ≥ 512`, `T ≥ 98`, `T ≥ 392`) come from a sweep on +RDNA4 — see the comment block in `_launch.py`. Two implications worth knowing: + +1. **The router uses padding `(1,1)` regardless of the caller's padding.** The + tile count it computes is approximate by design: the router only needs to + pick a method, not the exact `(P, Q)`. The actual padding flows through the + chosen kernel unchanged. (This is intentional — see + `memory/project_select_3x3_padding_heuristic.md`.) +2. **Below `C/K = 512`, the cblocked direct kernel is faster than Winograd.** + The Winograd transform overhead (read 6×6 patches, two 6×6 matrix + multiplies by Bᵀ and B, ~70 fp32 FMAs per (tile, channel) for the input + transform alone) dominates the FMA savings until the GEMM body is large + enough. + +Stride > 1 or dilation > 1 disqualifies Winograd entirely — the F(4,3) +algorithm is built around a 6-wide window stepping by 4, and that geometry +breaks under non-unit stride. + +NHWC reuses the same router but collapses its output to two destinations: +the standard NHWC 3×3 kernel, or the **non-cblocked** Winograd path with its +input transform kernel switched to NHWC reading via the `LAYOUT=1` constexpr. +Even when the router returns `winograd_f4x3_cblocked`, NHWC mode falls back to +plain `conv2d_winograd_f4x3` because there is no NHWC cblocked input repack — +cblocked is an NCHW-only optimization. The GEMM and output transform are +layout-independent because they operate on `V[36, T, C_pad]` and emit directly +into the user's chosen output layout. + +--- + +## 5. 
Per-kernel deep-dive + +### 5.0 A platform note on `num_stages` + +Every autotune config in `helpers.py` pins **`num_stages=1`**. That is +deliberate. + +The `num_stages > 1` Triton knob is meant to lower to a software-pipelined +loop: the compiler hoists global → LDS loads for iteration *i+1* across the +matmul of iteration *i* and rotates through `num_stages` LDS buffers so +memory and compute overlap. On RDNA there is no asynchronous global-to-LDS +copy instruction for the compiler to schedule ahead — global loads are +issued and waited on in-order through `s_waitcnt`. Triton's RDNA backend +therefore does not currently produce a pipelined loop for `num_stages > 1`; +the extra LDS buffers are allocated but the load hoisting never +materializes. Empirically `num_stages > 1` on RDNA either matches or loses +to `num_stages=1` (more LDS pressure, no overlap gained), so we don't sweep +it. + +This is why every tile choice in the sections below is paired with +`num_warps` only — there is no stage axis to tune. Memory-latency hiding on +RDNA comes from the hardware: wave-level scheduling across CUs and +interleaving arithmetic between a load's issue and its `s_waitcnt`, rather +than from compiler-issued software pipelining. + +### 5.1 `_conv2d_1x1_kernel` + +A 1×1 conv is mathematically a GEMM: + +``` +y[n,k,p,q] = Σ_c x[n,c,p,q] · w[k,c] (stride/padding still apply to (p,q)) +``` + +The kernel fuses this with the index unwrap `m → (n, p, q)` and a +`LAYOUT` constexpr (0 = NCHW, 1 = NHWC) that selects the strides on read/write. +Highlights: + +- **Tile shape:** `BLOCK_M × BLOCK_N × BLOCK_K`, autotuned over the + `AUTOTUNE_1x1_CONFIGS` grid in `helpers.py`. On RDNA4 the winners cluster + around `BM=128, BN=128`, `BK ∈ {32, 64}`, 8 warps (`num_stages=1` — + see 5.0). +- **L2 cache swizzle.** Tiles are reordered into super-groups of + `GROUP_SIZE_M` along the `M` axis so each weight (`N`-axis) tile is reused + across `GROUP_SIZE_M` consecutive workgroups before moving on — the same + weight columns stay hot in L2 for the duration of a group. Standard MM + trick. +- **Stride > 1 / non-zero padding** are handled by mapping the tile's + `(p, q)` back to `(h, w) = (sh·p − ph, sw·q − pw)` and checking bounds. The + hot path (stride 1, padding 0) and the slow path go through the same code + but the mask collapses to "true" for the hot path so the compiler should + hoist it. + +### 5.2 `_conv2d_3x3_nhwc_kernel` + +NHWC-native 3×3 with **K-major weight layout** `W3[K_out, 9, C_pad]`: + +- For each `(p, q)` output column the kernel walks the 9 spatial taps, loads + the corresponding `[BLOCK_M, BLOCK_C]` slice of input and the matching + `[BLOCK_C, BLOCK_N]` slice of W3, and accumulates + `tl.dot(...)` into `[BLOCK_M, BLOCK_N]`. +- Channel padding `C_pad` is rounded up to a multiple of `BLOCK_C` and the + trailing weight lanes are pre-zeroed in the prepack. The spatial validity + mask is hoisted out of the inner C loop; the only per-iteration mask is + the channel-bound check (`c_offs < C`). +- Hardcoded `stride_x_c=1` (channels are contiguous in NHWC) means the + load addressing collapses to a single base + linear stride; the compiler + emits one vectorized load per row. +- **Same L2 swizzle as 5.1**: workgroups are reordered into super-groups of + `GROUP_SIZE_M` along the `M` axis (autotuned to 4 or 8 in + `AUTOTUNE_3x3_NHWC_CONFIGS`) so each weight (`N`-axis) tile stays hot in + L2 across `GROUP_SIZE_M` consecutive workgroups. 
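+ +The `W3[K_out, 9, C_pad]` weight layout above is easiest to see as plain
+tensor ops. A minimal sketch of the pack — not the actual
+`prepack_oihw_to_3x3`, which also carries the LRU caching described in §7;
+the row-major tap order (`tap = r·3 + s`) is assumed here for illustration:
+
+```python
+import torch
+
+def pack_oihw_to_3x3_sketch(w_oihw: torch.Tensor, c_pad: int) -> torch.Tensor:
+    """[K_out, C, 3, 3] -> [K_out, 9, C_pad]; trailing channel lanes stay zero."""
+    k_out, c, r, s = w_oihw.shape
+    assert (r, s) == (3, 3) and c <= c_pad
+    w3 = w_oihw.new_zeros((k_out, 9, c_pad))   # pre-zeroed C_pad tail
+    # flatten the 9 taps (row-major), then move channels innermost
+    w3[:, :, :c] = w_oihw.reshape(k_out, c, 9).permute(0, 2, 1)
+    return w3.contiguous()
+```
+
+The kernel then reads a `[BLOCK_C, BLOCK_N]` slice of this tensor per tap;
+the pre-zeroed tail is what makes the `C_pad` over-read contribute nothing.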
+ +### 5.3 `_conv2d_3x3_cblocked_kernel` + +Same 3×3 math and same `M = N·P·Q` GEMM decomposition as the NHWC kernel, +but with **NCHWc input layout** `X[N, C_blocks, H, W_in, Cb]` (`Cb=64` by +default; the kernel takes it as a constexpr). Why two kernels? + +- The NHWC kernel keeps the user's input layout intact (no input repack) + and gets coalesced channel loads "for free" because channels are already + the inner contiguous axis. +- The user's NCHW input has channels on the *outer* stride (`H·W` apart), + so a direct read from NCHW would scatter the channel loads. cblocked + **repacks the input once** into `[N, C_blocks, H, W_in, Cb]` to restore a + small inner contiguous channel axis (`Cb` wide) that the same `M = N·P·Q` + GEMM walk can read coalesced. Output is plain NCHW so there's no output + repack. + +Both kernels share the GEMM order; the difference is purely in the input +addressing.
+ +Highlights: + +- **`stride_c_local=1` is hardcoded** in the load path (`c_local[None, :]` + is added directly to the pointer without a stride multiplier), so the + compiler emits one vectorized load per channel chunk — structurally the + same hardcoded-stride trick as `stride_x_c=1` in the NHWC kernel. +- **Channel addressing.** Inside the C loop, `c_offs // Cb` selects the + block index and `c_offs % Cb` the offset within the block. This stays + coalesced **only when `BLOCK_C ≤ Cb`** — at the boundary, `cblock_idx` + jumps and the load addresses become discontiguous. This is an implicit + constraint on `AUTOTUNE_3x3_CBLOCKED_CONFIGS`; new configs must respect it. +- **Same L2 swizzle as 5.1/5.2** — workgroups are reordered into + super-groups of `GROUP_SIZE_M` along the `M` axis so each weight + (`N`-axis) tile stays hot in L2 across `GROUP_SIZE_M` consecutive + workgroups. + +The cblocked kernel is the default for NCHW 3×3 below the Winograd threshold +because, even after paying the input repack cost, it consistently beats both +"general with K-major" and the NHWC kernel with implicit transposes. The +kernel+repack number in the bench includes the per-batch input repack; +the user-facing `conv2d_nchw` calls `get_or_make_input_pack_cblocked`, which +caches by `(storage_ptr, shape, dtype, ...)` so back-to-back calls with the +same input tensor only repack once.
+ +### 5.4 `_conv2d_general_kernel` + +The fallback for everything that isn't 1×1, 3×3, or Winograd-eligible — +i.e. any kernel size other than 1×1 or 3×3 (5×5, 7×7, dilated 5×5, etc.). +Note: dilated 3×3 still routes to cblocked, not here. Strategy: + +- Pack weights once into K-major `W_K[K_out, K_pad]` where + `K_pad = pad(C·R·S, block_k)`. The trailing `K_pad − C·R·S` weight lanes + are zero so any reduction step in the tail contributes 0. +- Tile the output as a GEMM over `M = N·P·Q` rows × `N = K_out` cols × + `K = K_pad` reduction. +- **No im2col buffer.** The kernel walks `K_pad` and decodes + `k → (c, r, s)` on the fly: `c = k // (R·S)`, `rs = k mod (R·S)`, + `r = rs // S`, `s = rs mod S`. Each `(c, r, s)` triple yields the input + coordinate `(h, w) = (sh·p − ph + dh·r, sw·q − pw + dw·s)` and a bounds + mask `(oh, ow) ∈ [0,H)×[0,W)` **plus a `c < C` channel-bound check** + (the K_pad tail decodes `c ≥ C`, so the input-side mask prevents OOB + reads — the zero-padded weight side just guarantees those phantom + contributions are 0). fp32 accumulator, downcast at store. +- **`LAYOUT` constexpr** (0=NCHW, 1=NHWC) selects input/output strides; the + same kernel serves both layouts.
The router calls it with `layout="nhwc"` + for non-1×1/3×3 NHWC shapes. +- **Same L2 swizzle** as 5.1/5.2/5.3 — `GROUP_SIZE_M` super-grouping along M. + +### 5.5 The Winograd path — five kernels + +Winograd F(4×4, 3×3) reduces the **144** multiplies per 4×4 output tile of +a direct 3×3 (16 outputs × 9 taps) to 36 multiplies *total* (one batched +36-way GEMM) — **4× fewer multiplies** in the inner GEMM, modulo +floating-point ordering. We pay for it in: (a) input/output transforms +that touch every element with small constants, (b) increased numerical +sensitivity in fp16/bf16. The math is in section 6; here are the five kernels: + +| Kernel | What it does | Output shape | +|---|---|---| +| `_winograd_f4x3_input_transform_kernel` | One 6×6 patch per `(t, channel-tile)`, computes `BᵀXB` in fp32, stores to V | `V[36, T, C_pad]` | +| `_winograd_f4x3_cblocked_input_transform_kernel` | Same, but reads from NCHWc `[N, C_blocks, H, W, Cb]` for coalesced loads | `V[36, T, C_pad]` | +| `_winograd_f4x3_batched_gemm_kernel` | 36 independent GEMMs `T × C_pad × K_out`, output `M[36, T, K_out]` | `M[36, T, K_out]` | +| `_winograd_f4x3_output_transform_kernel` | One 6×6 `M` slice per `(t, k)`, computes `AᵀMA` (4×4), bias + activation, scatter to NCHW or NHWC output | `Y[N, K_out, P, Q]` | +| `_winograd_f4x3_fused_gemm_output_kernel` | Streams the 36 GEMMs in 6 column-groups; applies the column transform `Aᵀ · M[:,col]` on-the-fly per group to accumulate 24 `s` values (skipping materialization of `M`), then applies the row transform once at the end. Replaces kernels 3 + 4 with a single launch at the cost of higher per-CTA register pressure (24 `BLOCK_T × BLOCK_K` fp32 accumulators). | `Y[N, K_out, P, Q]` | + +`conv2d_winograd_f4x3` and `conv2d_winograd_f4x3_cblocked` use the +3-kernel pipeline (input transform + GEMM + output transform). +`conv2d_winograd_f4x3_fused` uses the fused 2-kernel pipeline (input +transform + fused GEMM/output). The fused variant trades a smaller +intermediate (skips writing/reading `M[36, T, K_out]`) for higher per-CTA +register pressure; on RDNA4 it's slightly faster on large `(T, K_out)` and +slightly slower on small ones, which is why both exist in the registry but +the auto-router picks only the unfused variants. + +The **filter transform** `U = G g Gᵀ` runs once on the host in +`prepack_winograd_filter_f4x3` (in fp32), then is cast to the activation +dtype and stored as `U[36, K_out, C_pad]`. The transform is per-weight, not +per-call, so it amortizes across every forward pass. + +--- + +## 6. Winograd F(4×4, 3×3) — full derivation + +Notation: `g` = 3×3 filter, `d` = 6×6 input tile, `Y` = 4×4 output tile. + +### 6.1 Why Winograd + +A direct 3×3 conv produces 4 output values from 6 input values along one +axis using `4 × 3 = 12` multiplies. Winograd's minimal 1-D algorithm +F(4, 3) needs only **6** multiplies for the same 4 outputs. Tensoring along +both spatial dims gives F(4×4, 3×3): `4×4 = 16` outputs from a 6×6 input, +in `6×6 = 36` element-wise products instead of `16 × 9 = 144`. + +Per output element: 144 / 16 = 9 muls direct, 36 / 16 = 2.25 muls Winograd +→ **4× fewer multiplies** in the inner tensor product. The headline 4× +ignores the input/output transforms (`BᵀdB`, `AᵀMA`), which aren't free +and are why the realized speedup is smaller than 4×. + +**Reporting convention.** Both the bench harness and the published charts +divide measured wall time by the *direct-convolution* FLOP count +`2·N·K·C·R·S·H·W`, regardless of which algorithm actually ran. 
This is +the universal convention — it gives a single yardstick on which different +algorithms (direct, im2col-GEMM, Winograd, FFT) are comparable, since +literal hardware-MAC counts are not. The implication: a Winograd row's +TFLOPS overstates the literal MACs/sec the hardware executes by ~4× for +F(4×4, 3×3) and ~2.25× for F(2×2, 3×3). MIOpen's Winograd solvers +(`ConvBinWinogradRxSf3x2`, `ConvWinoFuryRxS<2-3>`) are reported the same +way, so Triton-vs-MIOpen comparisons remain apples-to-apples. + +### 6.2 The transform matrices + +For F(4, 3) the standard Cook-Toom matrices (with +`{0, ±1, ±2, ±½, ∞}` interpolation points) are: + +``` + ┌ 6 0 0 ┐ ┌ 4 0 -5 0 1 0 ┐ + │ -4 -4 -4 │ │ 0 -4 -4 1 1 0 │ +G = │ -4 4 -4 │ × (1/24) Bᵀ =│ 0 4 -4 -1 1 0 │ + │ 1 2 4 │ │ 0 -2 -1 2 1 0 │ + │ 1 -2 4 │ │ 0 2 -1 -2 1 0 │ + └ 0 0 24┘ └ 0 4 0 -5 0 1 │ (×1, no scale) + + ┌ 1 1 1 1 1 0 ┐ + Aᵀ = │ 0 1 -1 2 -2 0 │ + │ 0 1 1 4 4 0 │ + └ 0 1 -1 8 -8 1 ┘ (4×6, scale 1) +``` + +(`_prepack.py:prepack_winograd_filter_f4x3` writes G with the `1/24` already +absorbed: rows scaled by `1/4, -1/6, -1/6, 1/24, 1/24, 1` so a host-side +`G g Gᵀ` produces `U` directly.) + +The math is then: + +``` + U = G g Gᵀ (6×6, computed once per weight, host-side, fp32) + V = Bᵀ d B (6×6, computed per input tile, on-chip, fp32) + M = U ⊙ V (6×6, element-wise) + Y = Aᵀ M A (4×4, output tile) +``` + +**Batched.** With `T = N · ⌈P/4⌉ · ⌈Q/4⌉` tiles and `C` channels per tile, +the per-tile element-wise product becomes a tile-batched GEMM: + +``` + M[α, t, k] = Σ_c U[α, k, c] · V[α, t, c] α ∈ [0, 36) +``` + +— **36 independent GEMMs** of shape `(T, C) × (C, K_out)`. This is what the +batched GEMM kernel does: `tl.program_id(1)` selects one of the 36 `α` +slices and `tl.program_id(0)` selects a `(t, k)` tile (with the standard +`GROUP_SIZE_M` L2 swizzle), then runs a normal Triton MM over `c`. + +### 6.3 What the kernels actually look like + +The transform matrices are **embedded in the JIT'd kernel as Python +constants** (not loaded from memory), so the compiler can constant-fold and +strength-reduce. The input transform `V = Bᵀ d B` is computed as the +standard two-pass factorization — column transform first (`t = Bᵀ d`), +then row transform (`V = t B`): + +``` + # column transform (t = Bᵀ d), 36 values, each 4–6 fp32 ops + t00 = 4*d00 - 5*d20 + d40 + t01 = 4*d01 - 5*d21 + d41 + ... + # row transform (V = t B), 36 values, each 4–6 fp32 ops + v00 = 4*t00 - 5*t02 + t04 + v01 = -4*t01 - 4*t02 + t03 + t04 + ... +``` + +Each row of `t` is reused across all 6 columns of `V` in that row, so the +two-pass form is ~72 small fp32 expressions total (36 `t` + 36 `v`) versus +the much larger sum-of-products that a fully-expanded `Bᵀ d B` would +unroll to. Input is loaded once per `(t, channel-tile)` as a 6×6 spatial +block over `BLOCK_C` channels, cast to fp32, transformed, stored to +`V[36, T, C_pad]`. + +The output transform is symmetric: read `M[36, t, k]` (6×6 per tile), +compute `AᵀMA` (4×4), add bias, apply activation, scatter to `Y` (NCHW or +NHWC, picked via a `LAYOUT` constexpr). The scatter has to write 16 values +per tile to a strided output; this is where the kernel does the most index +arithmetic. + +The **fused GEMM+output** kernel processes `M` column-by-column over its 6 +columns (not the 4 output columns). For each `col ∈ [0, 6)`, it runs 6 +mini-GEMMs (one per α-row at that column) and applies the column transform +`Aᵀ` on-the-fly to compress the 6 GEMM results into 4 `s` values for that +column. 
These accumulate into 24 (= 4 × 6) `BLOCK_T × BLOCK_K` fp32 +accumulators across the channel-tile loop. After the channel loop, a +single row transform produces the 4×4 output. This skips the round-trip +of `M` through global memory but needs many more live registers per +program — hence the two variants. + +### 6.4 Numerical caveat + +The output transform `AᵀMA` has elements as large as `±8`. Rounding errors +in `M` from the GEMM accumulation compound through the transform on both +sides of `AᵀMA`. Row 3 of `Aᵀ` (`[0, 1, -1, 8, -8, 1]`) has L1 norm `19`; +applied on both sides for the worst element `Y[3, 3]`, this bounds the +error amplification at `19 × 19 = 361×`. On bf16 (mantissa = 7 bits, +ε ≈ 2⁻⁷) that's enough to typically drop **3–4 bits of precision** vs the +direct conv. Two consequences: + +1. The accumulators are **fp32 throughout** (Triton's `tl.dot` produces + fp32; we keep it fp32 through the output transform and only downcast at + store). This is non-negotiable. +2. We disable Winograd for `C < 4` in `_is_winograd_eligible`, because at + `C=3` (RGB conv1) there aren't enough sum terms to absorb the + amplified rounding within the test tolerance. + +The test harness applies a **6× tolerance multiplier** to Winograd kernels +(`_winograd_tolerances`); see §8. + +--- + +## 7. Memory layouts and repacking + +There are four "shapes" the kernels actually consume, plus the user's NCHW +input: + +| Layout | Kernel | Where the repack happens | Cached? | +|---|---|---|---| +| `[K_out, K_pad]` (K-major weight) | `general` | `_prepack.py:prepack_oihw_to_kmajor` | LRU 256, `_PACK_CACHE` | +| `[K_out, 9, C_pad]` (3×3 weight) | `3x3_nhwc`, `3x3_cblocked` | `prepack_oihw_to_3x3` | LRU 256, `_PACK_CACHE_3x3` | +| `[N, C_blocks, H, W, Cb]` (NCHWc input, `Cb=64`) | `3x3_cblocked`, `winograd_f4x3_cblocked` (input transform) | `prepack_nchw_to_cblocked` | **Single-entry dict** by design | +| `[36, K_out, C_pad]` (Winograd weight) | `winograd_f4x3_*` | `prepack_winograd_filter_f4x3` | LRU 256, `_PACK_CACHE_WINOGRAD_F4X3` | + +### Why three weight caches are LRU but the input cache is single-entry + +Weights are reused every forward pass. A 53-layer ResNet has 53 unique +weight tensors; with a 256-entry LRU we keep them all warm and the second +forward pass through the model has zero repack cost. + +Input activations are *unique intermediate tensors* — the output of layer N +is consumed exactly once, by layer N+1, and never seen again. Caching them +LRU would just bloat memory; a single-entry dict is enough to dedup +repeated calls with the same tensor (which the bench loop *does* do — it +runs each layer with `warmup=15, rep=50` over the same input to get a +stable timing). The `.clear()` in the bench's per-call setup deliberately +models the per-batch repack cost in real inference, which is reported as +the "kernel + repack" row. + +### Cache key safety + +Each cache key includes `(storage_ptr, shape, dtype, block_k, _version)`. +The cache also stores a strong reference to the source tensor in the value, +so its storage cannot be freed and re-used by a different tensor while the +entry lives — this prevents `storage_ptr` collisions from producing a +false hit. The single-entry input cache also re-checks the source pointer +on hit (defense in depth). + +`AITER_TRITON_CONV_PACK_CACHE_SIZE=N` (env var) overrides the LRU bound. +Default 256 is enough for the models exercised by the bench harness without +eviction; larger models need the env override. + +--- + +## 8. 
Numerical model and tolerances + +`_utils.py:dynamic_conv_tolerances` is the formula used by the test +harness: + +```python +eps = {fp16: 2**-10, bf16: 2**-7, fp32: 2**-23}[dtype] +rtol = 6e-3 if K_red < 1024 else 8e-3 if K_red < 4096 else 1.2e-2 +atol = max(eps * 8, 10.0 * eps * sqrt(K_red)) +``` + +Where `K_red = C · R · S`. Two facts behind this: + +1. **Multiplication of two ε-rounded inputs has relative error ε.** +2. **Summing N independent rounding errors in fp32 grows as `ε · √N`** (random + walk), not `ε · N` (worst-case). Real kernels are not adversarial: the + `√N` bound matches what we observe across thousands of fuzzer shapes. + +The `10×` multiplier covers worst-case ordering differences between our +Triton accumulation order and PyTorch's (different tile shape → different +sum order → different rounding). + +`_winograd_tolerances` then bumps `rtol` by **6×** and `atol` by +`max(6×, 0.6)` for any Winograd kernel — the `0.6` floor catches +small-`K_red` cases where the √N-scaled base atol is too tight to absorb +the F(4,3) amplification. The 6× factor is empirical: the analytical +bound from §6.4 is 361× worst-case, but the *typical* amplification on +natural images is much smaller because the worst-case tile structure is +rare. 6× catches the long tail without flagging healthy fp16 ordering +noise. + +If a new kernel needs the Winograd bump, mark it `is_winograd=True` in +`op_tests/triton_tests/conv/_helpers.py` (`METHOD_REGISTRY`); the +`_get_tolerances` dispatch picks the right tolerance automatically. + +--- + +## 9. Method registry + +`op_tests/triton_tests/conv/_helpers.py` (`METHOD_REGISTRY`) is the **single source of truth** +for kernel dispatch in the test harness. Adding a new method takes one entry: + +```python +METHOD_REGISTRY = { + "default": MethodEntry(conv2d_nchw, None, False, "", "default"), + "cblocked": MethodEntry(conv2d_nchw_cblocked, _3x3_guard, False, "[cblocked]", "cblocked"), + "winograd_f4x3": MethodEntry(conv2d_winograd_f4x3, _wino_guard, True, "[winograd_f4x3]", "WF(4,3)"), + "winograd_f4x3_fused": MethodEntry(conv2d_winograd_f4x3_fused, _wino_guard, True, "[wino_f4x3_fused]", "WF4fused"), + "winograd_f4x3_cblocked": MethodEntry(conv2d_winograd_f4x3_cblocked, _wino_guard, True, "[winograd_f4x3_cblocked]", "WF4cb"), +} +``` + +| Field | Purpose | +|---|---| +| `kernel_fn` | The public `conv2d_*` wrapper (not the raw kernel) — handles repack and grid setup, or routes to a specialized variant (`conv2d_nchw` does smart routing). | +| `guard_fn(R, S, stride, dilation, C)` | Returns `True` if the method is applicable. `None` means "always". | +| `is_winograd` | Selects the 6× tolerance bump in `_helpers._get_tolerances`. | +| `bench_tag` | Suffix added to the per-test result name, e.g. `"[cblocked]"`. | +| `short_name` | Reserved for future bench output; not used by `bench_conv2d.py` today. | + +Wiring downstream of the registry — the parametrized pytest tests, the +tolerance dispatch — is automatic. You do not edit any of those. + +--- + +## 10. Adding a new kernel — checklist + +Concretely, to add (say) a `winograd_f6x3` variant: + +1. **Implement the kernel.** New file under `aiter/ops/triton/_triton_kernels/conv/`. + Use the autotune config sets in `helpers.py` if a standard one fits; only + write a new one if a different shape really needs it. +2. **Add a launch wrapper** in `_launch.py` (`_launch_winograd_f6x3`) that + sets up the grid and turns Python ints into Triton constexprs. +3. 
**Add a public function** in `conv2d.py` + (`conv2d_winograd_f6x3(...)`). Compute output shape, allocate the + output tensor, fetch any prepacked weights from `_prepack.py`, call + the launch wrapper. Consumers import it directly from this module — + no `__init__.py` re-export needed. +4. **Add a prepack** in `_prepack.py` if the kernel needs a different + weight layout. Use `_LRUPackCache` with a key that includes + `(storage_ptr, shape, dtype, block_k, _version)` and store a strong + ref to the source tensor in the cached value. For input-side prepacks + (rare — only `_PACK_CACHE_CBLOCKED` does this today), use a + single-entry plain dict cleared per-call instead — see §7. +5. **Register** in `op_tests/triton_tests/conv/_helpers.py` (`METHOD_REGISTRY`). Set + `is_winograd=True` if the kernel uses Winograd-style transforms + (you'll need the 6× tolerance bump). If the kernel uses a different + Winograd tile (e.g. F(6,3)), also add a new `variant=` branch in + `_utils._winograd_tolerances` and route to it from + `_helpers._get_tolerances` — F(4,3)'s 6× bump + is calibrated specifically for that tile size. Also add the kernel to + `op_tests/op_benchmarks/triton/bench_conv2d.py`'s `METHODS` dict so + `--method ` works. +6. **Optionally route** from `_select_3x3_method` if the new kernel + should be auto-selected. Update the heuristic comment block with the + shape range where it wins. If you do route from `_select_3x3_method`, + also add a `_last_triton_kernel = "..."` line in the matching branch + of `conv2d_nchw` / `conv2d_nhwc` so the bench labels the row + correctly. +7. **Run the suite** — `pytest op_tests/triton_tests/conv/` parametrizes + `test_no_bias` and `test_activations` over every entry in + `METHOD_REGISTRY`, so the new method gets correctness coverage + automatically. Bench it via + `python -m op_tests.op_benchmarks.triton.bench_conv2d --method ...`. + +--- + +## 11. Known limitations and future work + +- **`groups > 1`.** Kernels are `groups=1` only. Implementation entry + points: a new `aiter/ops/triton/_triton_kernels/conv/conv_depthwise.py` + for the depthwise case, K-major-per-group prepack in `_prepack.py`, + route from `_select_3x3_method`. +- **`padding_mode != "zeros"`.** Reflect / replicate / circular padding + is unsupported. Adding it is mostly a per-kernel change to the + bounds-mask logic — replace `mask = (h ≥ 0) & (h < H) & ...` with a + reflected/wrapped index computation. +- **Backward / training.** No gradients. Adding training would mean a + second set of kernels (`grad_input`, `grad_weight`) and is out of + scope for this library's intended use as an inference replacement. 
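+ +As a last sanity check on the §6 derivation, the transform matrices can be
+verified numerically against a direct convolution in a few lines — a
+minimal, self-contained sketch (fp64, one tile, one channel; not part of
+the library, and `F.conv2d` computes the same cross-correlation the
+kernels do):
+
+```python
+import torch
+import torch.nn.functional as F
+
+# Matrices from section 6.2; G carries the 1/24 scaling.
+G = torch.tensor([[6., 0, 0], [-4, -4, -4], [-4, 4, -4],
+                  [1, 2, 4], [1, -2, 4], [0, 0, 24]], dtype=torch.float64) / 24
+Bt = torch.tensor([[4., 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0],
+                   [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0],
+                   [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]], dtype=torch.float64)
+At = torch.tensor([[1., 1, 1, 1, 1, 0], [0, 1, -1, 2, -2, 0],
+                   [0, 1, 1, 4, 4, 0], [0, 1, -1, 8, -8, 1]], dtype=torch.float64)
+
+g = torch.randn(3, 3, dtype=torch.float64)   # one 3x3 filter
+d = torch.randn(6, 6, dtype=torch.float64)   # one 6x6 input tile
+
+U = G @ g @ G.T            # filter transform (6x6)
+V = Bt @ d @ Bt.T          # input transform  (6x6)
+Y = At @ (U * V) @ At.T    # output transform (4x4)
+
+ref = F.conv2d(d.view(1, 1, 6, 6), g.view(1, 1, 3, 3)).view(4, 4)
+assert torch.allclose(Y, ref, atol=1e-9)
+```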
diff --git a/aiter/ops/triton/conv/README.md b/aiter/ops/triton/conv/README.md new file mode 100644 index 0000000000..4ba1bc59f1 --- /dev/null +++ b/aiter/ops/triton/conv/README.md @@ -0,0 +1,229 @@ +# conv2d (Triton, AMD ROCm) + +> **`Conv2d` for AMD ROCm — a drop-in replacement for `torch.nn.Conv2d`, +> optimized for AMD RDNA GPUs.** + +[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE) +[![Python](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org) +[![PyTorch](https://img.shields.io/badge/PyTorch-2.9.1-ee4c2c.svg)](https://pytorch.org) +[![ROCm](https://img.shields.io/badge/ROCm-7.2-ED1C24.svg)](https://www.amd.com/en/developer/resources/rocm-hub.html) +[![Triton](https://img.shields.io/badge/Triton-3.7-orange.svg)](https://github.com/triton-lang/triton) + +A hand-written Triton 2-D convolution library optimized for AMD RDNA +GPUs. Five kernel families (1×1, 3×3 cblocked, 3×3 NHWC, Winograd +F(4×4, 3×3), general) behind one shape-driven router and one entry +point. Drop-in for the forward path of `nn.Conv2d`. + +--- + +## Why this op exists + +PyTorch on AMD goes through MIOpen, which ships hand-tuned solvers per +architecture, per dtype, per layout. That works well on the combinations +the solvers were specifically tuned for, but every new dtype × layout × +architecture combination needs its own tuning pass — so coverage is +uneven across the matrix (e.g. on RDNA4 the fp16 path is well-served, +while bf16 falls back to direct/GEMM solvers that are noticeably slower +at large channel counts; most modern checkpoints — LLMs, diffusion VAEs +— ship in bf16). + +This op takes the opposite approach: a single set of Triton kernels +that runs **fp16 and bf16 through the same code path**, supports +**both NCHW and NHWC end-to-end** (NHWC inputs run on an NHWC kernel — +no NHWC↔NCHW conversion), and gets reasonable performance across the +full matrix **without per-architecture hand tuning**. A shape-driven +router picks between five kernel families (1×1, 3×3 cblocked, 3×3 NHWC, +Winograd F(4×4, 3×3), general) so the right kernel runs per layer +automatically. Some kernels do repack inputs/weights into kernel-local +formats (channel-blocked tiles for cblocked, G/Bᵀ transforms for +Winograd) — these packs are LRU-cached so steady-state cost is +negligible. + +--- + +## Performance + +Designed to deliver strong throughput on AMD RDNA4 across both fp16 and +bf16. The bench harness logs which MIOpen solver was selected per layer +and reports aggregate TFLOPS for both backends, so you can verify the +behavior on your stack. + +To measure on your stack: + +```bash +python -m op_tests.op_benchmarks.triton.bench_conv2d \ + --model \ + --dtype \ + [--miopen-solvers] # opt-in; ~60-120s upfront subprocess +``` + +The bench harness produces three box-drawn tables: LAYER-BY-LAYER (per-layer +Triton vs MIOpen TFLOPS, Triton kernel name, optionally MIOpen solver, +kernel+repack column for shapes that prepack), MIOpen SOLVER SUMMARY (only +with `--miopen-solvers`), and OVERALL PERFORMANCE (mean/median/aggregate +TFLOPS, total time, layer wins, correctness). + +> **Note on TFLOPS**: numbers are *direct-convolution-equivalent* throughput +> (the standard convention used by cuDNN, MIOpen, and the Winograd +> literature), applied identically to both backends. Winograd kernels — +> Triton's F(4×4, 3×3) and MIOpen's F(2×2, 3×3) / Fury alike — execute +> fewer literal hardware MACs than this denominator counts (≈4× fewer for +> F(4,3), ≈2.25× for F(2,3)). 
The comparison is apples-to-apples. + +--- + +## Quick start + +### Use the function directly + +```python +import torch +from aiter.ops.triton.conv.conv2d import conv2d + +x = torch.randn(4, 256, 56, 56, device="cuda", dtype=torch.float16) +w = torch.randn(512, 256, 3, 3, device="cuda", dtype=torch.float16) + +y = conv2d( + x, w, bias=None, + stride=(1, 1), padding=(1, 1), dilation=(1, 1), + activation="relu", # "none" | "relu" | "relu6" | "gelu" + out_dtype=None, # None → match input dtype (default) + layout="nchw", # "nchw" or "nhwc" +) +``` + +A shape-driven router picks one of five kernel families: + +| Family | When it runs | +|---|---| +| 1×1 GEMM | `R==1, S==1` | +| 3×3 cblocked (NCHW) | 3×3, channel-blocked input for coalesced loads | +| 3×3 NHWC | 3×3 with channels-last input — no input repack | +| Winograd F(4×4, 3×3) | 3×3, stride=1, dilation=1, `C ≥ 512`, `K ≥ 512`, enough output tiles | +| General | anything not 1×1 or 3×3 (5×5, 7×7, dilated, strided) | + +### Use as `nn.Conv2d` drop-in + +The kernel families above are functional; wrapping them in an `nn.Module` +(walk a model, swap each `nn.Conv2d` for a Triton-backed module that +calls `conv2d(...)` in its `forward`) works as expected and produces +images visually indistinguishable from the PyTorch / MIOpen reference. + +Pixel-level agreement on FLUX.2-klein-9B (50 diffusion steps, same prompt +and seed under both backends, only VAE convs swapped to Triton): max diff +**6 / 255**, mean diff **0.17 / 255**. See `examples/flux2_inference.py` +to reproduce and inspect the generated images. + +--- + +## Constraints + +- `groups` must equal 1 (depthwise / grouped not yet implemented). +- `padding_mode` must be `"zeros"`. The pad *amount* (`padding=`, e.g. + `(1, 1)` or asymmetric `(0, 2)`) is unrestricted; only the pad *value* + is — `"reflect"`, `"replicate"`, and `"circular"` fall back to PyTorch / + MIOpen. +- Inputs must be `fp16` or `bf16`. +- Forward only (no backward / training). + +--- + +## Reproducing the tests and benchmarks + +Run from the AITER repo root (`/app/aiter` in this tree, or `PYTHONPATH=/app/aiter`). + +### Correctness (CDNA CI runs this) + +```bash +pytest op_tests/triton_tests/conv/ # full matrix, 74 tests +pytest op_tests/triton_tests/conv/ -k "no_bias and fp16_nchw" # subset +pytest op_tests/triton_tests/conv/ -k "test_edge" # one test family +``` + +Tests are parametrized over `(dtype, layout, method)`. Every kernel in +`_helpers.ORDERED_METHODS` is exercised against fp16 and bf16 on NCHW. +NHWC is single-dispatch (only `conv2d_nhwc`), so each NHWC test runs once +per dtype. + +### Benchmark + +Three modes, all in `bench_conv2d.py`. 
+ +**Single shape** (one parseable result line — for ad-hoc measurements): + +```bash +python -m op_tests.op_benchmarks.triton.bench_conv2d \ + --N 1 --C 64 --H 56 --W 56 --K 64 --R 3 --S 3 --pad-h 1 --pad-w 1 +``` + +**Built-in 12-shape sweep** (no model required, three box-drawn tables): + +```bash +python -m op_tests.op_benchmarks.triton.bench_conv2d --dtype fp16 +``` + +**Real-model sweep** (reads shapes from `model_shapes.json`): + +```bash +python -m op_tests.op_benchmarks.triton.bench_conv2d --model resnet50 +python -m op_tests.op_benchmarks.triton.bench_conv2d --model "FLUX.2" --miopen-solvers +``` + +Cross-axis flags: + +``` +--dtype {fp16,bf16} # default fp16 +--layout {nchw,nhwc} # default nchw +--method {auto,default,cblocked,nhwc,winograd_f4x3,winograd_f4x3_fused,winograd_f4x3_cblocked} +--metric {time,throughput} # default throughput +--no-bias # bench the bias=None code path +--miopen-solvers # detect MIOpen solver names (sweep mode; ~60-120s subprocess) +--show-kernel-name # include routed kernel name in single-shape output +``` + +Real-model shapes are pre-extracted (no torchvision/diffusers needed at +bench time). To add a new model, run `extract_conv_shapes.py` once and +merge its JSON output into `model_shapes.json`: + +```bash +python -m op_tests.op_benchmarks.triton.model_benchmarking_tool.extract_conv_shapes \ + --model resnet50 # or sd35_vae / flux2_vae with --model-path +``` + +Tested on ROCm 7.2 / PyTorch `2.9.1+gitff65f5b` / Triton 3.7 (commit `23f4e522d`). + +--- + +## Documentation + +- **[`DESIGN.md`](DESIGN.md)** — architecture, per-kernel deep-dive, full + Winograd F(4,3) derivation (G/Bᵀ/Aᵀ matrices, 361× amplification + analysis, why Winograd is disabled for `C < 4`), the + `_select_3x3_method` heuristic, memory layouts and repacking, + numerical model, extension guide. + +--- + +## Repository layout + +``` +aiter/ops/triton/conv/ Kernel library + conv2d.py Public API + smart routing + _launch.py Grid setup + _select_3x3_method + _prepack.py Weight/input repack caches (LRU) + _utils.py Shape math, tolerance model + README.md, DESIGN.md + +aiter/ops/triton/_triton_kernels/conv/ @triton.jit kernels + (1x1, 3x3 cblocked, 3x3 NHWC, general, 5 Winograd kernels) + +op_tests/triton_tests/conv/ Pytest unit tests (CDNA CI runs this) + test_conv2d.py The only collected test file + _helpers.py TestSuite, registry, shape generators + +op_tests/op_benchmarks/triton/ + bench_conv2d.py Self-contained bench tool (single + sweep) + model_benchmarking_tool/ + extract_conv_shapes.py One-time offline shape extraction + model_shapes.json Pre-extracted conv shapes (resnet50, SD3.5, FLUX2) +``` diff --git a/aiter/ops/triton/conv/__init__.py b/aiter/ops/triton/conv/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/aiter/ops/triton/conv/_launch.py b/aiter/ops/triton/conv/_launch.py new file mode 100644 index 0000000000..e71b1d5171 --- /dev/null +++ b/aiter/ops/triton/conv/_launch.py @@ -0,0 +1,579 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. 
+ +from __future__ import annotations +import torch + +try: + import triton + import triton.language as tl +except Exception: + triton = None + tl = None + +from aiter.ops.triton.conv._utils import _out_hw, _is_winograd_eligible +from aiter.ops.triton._triton_kernels.conv.conv_1x1 import _conv2d_1x1_kernel +from aiter.ops.triton._triton_kernels.conv.conv_general import _conv2d_general_kernel +from aiter.ops.triton._triton_kernels.conv.conv_3x3 import ( + _conv2d_3x3_nhwc_kernel, + _conv2d_3x3_cblocked_kernel, +) +from aiter.ops.triton._triton_kernels.conv.conv_3x3_winograd_f4x3 import ( + _winograd_f4x3_input_transform_kernel, + _winograd_f4x3_cblocked_input_transform_kernel, + _winograd_f4x3_batched_gemm_kernel, + _winograd_f4x3_output_transform_kernel, + _winograd_f4x3_fused_gemm_output_kernel, +) + + +def _torch_dtype_to_tl(dtype): + """Map torch dtype to triton dtype for constexpr params.""" + if dtype == torch.float16: + return tl.float16 + elif dtype == torch.bfloat16: + return tl.bfloat16 + raise ValueError(f"Unsupported dtype: {dtype}") + + +def _select_3x3_method(N, C, H, W, K_out, stride, dilation): + """Pick the best 3x3 kernel method based on shape heuristics. + + Decision tree (from benchmark sweep on RDNA4): + 1. Non-Winograd-eligible (stride>1, dilation>1, or C<4) -> cblocked + 2. Winograd only wins when BOTH C and K >= 512 with enough tiles (T >= 98). + At 256x256 channels, cblocked is tied or slightly better. + 3. Among Winograd variants: WF4cb (NCHWc input) beats WF4 (NCHW input) + when T >= 392 (large batch * spatial gives more coalescing benefit). + Below that, WF4 is slightly faster (less repacking overhead). + """ + if not _is_winograd_eligible(3, 3, stride, dilation, C): + return "cblocked" + P, Q = _out_hw(H, W, 3, 3, stride, (1, 1), dilation) + tile_H = (P + 3) // 4 + tile_W = (Q + 3) // 4 + T = N * tile_H * tile_W + if C >= 512 and K_out >= 512 and T >= 98: + if T >= 392: + return "winograd_f4x3_cblocked" + return "winograd_f4x3" + return "cblocked" + + +def _layout_to_int(layout): + """Convert layout string to kernel int: 0=NCHW, 1=NHWC.""" + layout = layout.lower() + if layout not in ("nchw", "nhwc"): + raise ValueError(f"layout must be 'nchw' or 'nhwc', got '{layout}'") + return 0 if layout == "nchw" else 1 + + +def _launch_1x1( + x, + w_oihw, + bias_fp32, + y, + N, + C, + H, + W_in, + K_out, + P, + Q, + stride, + padding, + activation, + layout="nchw", +): + """Launch specialized 1x1 kernel. + layout: "nchw" or "nhwc" (case-insensitive). 
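+    Grid: 1-D, cdiv(N*P*Q, BLOCK_M) * cdiv(K_out, BLOCK_N) programs; BLOCK_M and BLOCK_N are supplied by the @triton.autotune configs (see the grid closure below).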
+ """ + if triton is None: + raise RuntimeError("Triton not available") + + sh, sw = stride + ph, pw = padding + + w = w_oihw.squeeze(-1).squeeze(-1).contiguous() # [K_out, C] + layout = _layout_to_int(layout) + + def grid(meta): + BM = meta["BLOCK_M"] + BN = meta["BLOCK_N"] + return (triton.cdiv(N * P * Q, BM) * triton.cdiv(K_out, BN),) + + ACT_MAP = {"none": 0, "relu": 1, "relu6": 2, "gelu": 3} + bias_arg = bias_fp32 if bias_fp32 is not None else w.new_empty(1) + + M_total = N * P * Q + + _conv2d_1x1_kernel[grid]( + x, + w, + bias_arg, + y, + N, + C, + H, + W_in, + K_out, + P, + Q, + sh, + sw, + ph, + pw, + M_total, + HAS_BIAS=1 if bias_fp32 is not None else 0, + ACT_TYPE=ACT_MAP.get(activation, 0), + LAYOUT=layout, + ) + + +def _launch_3x3_nhwc( + x, + w_3x3, + bias_fp32, + y, + N, + C, + H, + W_in, + K_out, + P, + Q, + C_pad, + stride, + padding, + dilation, + activation, +): + """Launch specialized 3x3 NHWC kernel (hardcoded stride_c=1, stride_k=1).""" + if triton is None: + raise RuntimeError("Triton not available") + + sh, sw = stride + ph, pw = padding + dh, dw = dilation + + def grid(meta): + BM = meta["BLOCK_M"] + BN = meta["BLOCK_N"] + return (triton.cdiv(N * P * Q, BM) * triton.cdiv(K_out, BN),) + + ACT_MAP = {"none": 0, "relu": 1, "relu6": 2, "gelu": 3} + bias_arg = bias_fp32 if bias_fp32 is not None else w_3x3.new_empty(1) + + M_total = N * P * Q + + _conv2d_3x3_nhwc_kernel[grid]( + x, + w_3x3, + bias_arg, + y, + N, + C, + H, + W_in, + K_out, + P, + Q, + C_pad, + sh, + sw, + ph, + pw, + dh, + dw, + M_total, + HAS_BIAS=1 if bias_fp32 is not None else 0, + ACT_TYPE=ACT_MAP.get(activation, 0), + ) + + +def _launch_3x3_cblocked( + x_blocked, + w_3x3, + bias_fp32, + y, + N, + C, + H, + W_in, + K_out, + P, + Q, + C_pad, + Cb, + stride, + padding, + dilation, + activation, +): + """Launch specialized 3x3 kernel for channel-blocked input.""" + if triton is None: + raise RuntimeError("Triton not available") + + sh, sw = stride + ph, pw = padding + dh, dw = dilation + + def grid(meta): + BM = meta["BLOCK_M"] + BN = meta["BLOCK_N"] + return (triton.cdiv(N * P * Q, BM) * triton.cdiv(K_out, BN),) + + ACT_MAP = {"none": 0, "relu": 1, "relu6": 2, "gelu": 3} + bias_arg = bias_fp32 if bias_fp32 is not None else w_3x3.new_empty(1) + + M_total = N * P * Q + + _conv2d_3x3_cblocked_kernel[grid]( + x_blocked, + w_3x3, + bias_arg, + y, + N, + C, + H, + W_in, + K_out, + P, + Q, + C_pad, + Cb, + sh, + sw, + ph, + pw, + dh, + dw, + M_total, + HAS_BIAS=1 if bias_fp32 is not None else 0, + ACT_TYPE=ACT_MAP.get(activation, 0), + ) + + +def _launch_general( + x, + w_k, + bias_fp32, + y, + N, + C, + H, + W_in, + K_out, + R, + S, + P, + Q, + K_pad, + stride, + padding, + dilation, + block_k, + activation, + layout="nchw", +): + """Launch general conv kernel. + layout: "nchw" or "nhwc" (case-insensitive). 
+ """ + if triton is None: + raise RuntimeError("Triton not available") + + sh, sw = stride + ph, pw = padding + dh, dw = dilation + layout = _layout_to_int(layout) + + def grid(meta): + BM = meta["BLOCK_M"] + BN = meta["BLOCK_N"] + return (triton.cdiv(N * P * Q, BM) * triton.cdiv(K_out, BN),) + + ACT_MAP = {"none": 0, "relu": 1, "relu6": 2, "gelu": 3} + bias_arg = bias_fp32 if bias_fp32 is not None else w_k.new_empty(1) + + M_total = N * P * Q + + _conv2d_general_kernel[grid]( + x, + w_k, + bias_arg, + y, + N, + C, + H, + W_in, + K_out, + R, + S, + P, + Q, + K_pad, + sh, + sw, + ph, + pw, + dh, + dw, + M_total, + HAS_BIAS=1 if bias_fp32 is not None else 0, + ACT_TYPE=ACT_MAP.get(activation, 0), + LAYOUT=layout, + ) + + +def _launch_winograd_f4x3_fused( + x, + U, + bias_fp32, + y, + N, + C, + H, + W_in, + K_out, + P, + Q, + C_pad, + padding, + activation, + layout="nchw", +): + """Launch Winograd F(4x4,3x3) with fused GEMM+output transform (2 kernels instead of 3).""" + if triton is None: + raise RuntimeError("Triton not available") + ph, pw = padding + tile_H = (P + 3) // 4 + tile_W = (Q + 3) // 4 + T = N * tile_H * tile_W + layout_int = _layout_to_int(layout) + + input_dtype = x.dtype + V = torch.empty((36, T, C_pad), device=x.device, dtype=input_dtype) + + # 1. Input transform + def input_grid_f4(meta): + return (T, triton.cdiv(C_pad, meta["BLOCK_C"])) + + _winograd_f4x3_input_transform_kernel[input_grid_f4]( + x, + V, + N, + C, + C_pad, + H, + W_in, + tile_H, + tile_W, + T, + ph, + pw, + INPUT_DTYPE=_torch_dtype_to_tl(input_dtype), + LAYOUT=layout_int, + ) + + # 2. Fused GEMM + output transform + ACT_MAP = {"none": 0, "relu": 1, "relu6": 2, "gelu": 3} + bias_arg = bias_fp32 if bias_fp32 is not None else x.new_empty(1) + + def fused_grid_f4(meta): + return (triton.cdiv(T, meta["BLOCK_T"]), triton.cdiv(K_out, meta["BLOCK_K"])) + + _winograd_f4x3_fused_gemm_output_kernel[fused_grid_f4]( + V, + U, + bias_arg, + y, + N, + K_out, + P, + Q, + C_pad, + tile_H, + tile_W, + T, + HAS_BIAS=1 if bias_fp32 is not None else 0, + ACT_TYPE=ACT_MAP.get(activation, 0), + LAYOUT=layout_int, + ) + + +def _launch_winograd_f4x3( + x, + U, + bias_fp32, + y, + N, + C, + H, + W_in, + K_out, + P, + Q, + C_pad, + padding, + activation, + layout="nchw", +): + """Launch Winograd F(4x4,3x3) pipeline: input transform -> batched GEMM -> output transform.""" + if triton is None: + raise RuntimeError("Triton not available") + ph, pw = padding + tile_H = (P + 3) // 4 + tile_W = (Q + 3) // 4 + T = N * tile_H * tile_W + layout_int = _layout_to_int(layout) + + input_dtype = x.dtype + V = torch.empty((36, T, C_pad), device=x.device, dtype=input_dtype) + M = torch.empty((36, T, K_out), device=x.device, dtype=torch.float32) + + # 1. Input transform + def input_grid_f4(meta): + return (T, triton.cdiv(C_pad, meta["BLOCK_C"])) + + _winograd_f4x3_input_transform_kernel[input_grid_f4]( + x, + V, + N, + C, + C_pad, + H, + W_in, + tile_H, + tile_W, + T, + ph, + pw, + INPUT_DTYPE=_torch_dtype_to_tl(input_dtype), + LAYOUT=layout_int, + ) + + # 2. Batched GEMM + def gemm_grid(meta): + BM = meta["BLOCK_M"] + BN = meta["BLOCK_N"] + return (triton.cdiv(T, BM) * triton.cdiv(K_out, BN), 36) + + _winograd_f4x3_batched_gemm_kernel[gemm_grid]( + V, + U, + M, + T, + K_out, + C_pad, + ) + + # 3. 
Output transform + ACT_MAP = {"none": 0, "relu": 1, "relu6": 2, "gelu": 3} + bias_arg = bias_fp32 if bias_fp32 is not None else x.new_empty(1) + + def output_grid_f4(meta): + return (T, triton.cdiv(K_out, meta["BLOCK_K"])) + + _winograd_f4x3_output_transform_kernel[output_grid_f4]( + M, + bias_arg, + y, + N, + K_out, + P, + Q, + tile_H, + tile_W, + T, + HAS_BIAS=1 if bias_fp32 is not None else 0, + ACT_TYPE=ACT_MAP.get(activation, 0), + LAYOUT=layout_int, + ) + + +def _launch_winograd_f4x3_cblocked( + x_blocked, + C_pad_blocked, + U, + bias_fp32, + y, + N, + C, + H, + W_in, + K_out, + P, + Q, + C_pad, + padding, + activation, + block_k, +): + """Launch Winograd F(4x4,3x3) with NCHWc input layout: cblocked input transform -> batched GEMM -> output transform.""" + if triton is None: + raise RuntimeError("Triton not available") + ph, pw = padding + tile_H = (P + 3) // 4 + tile_W = (Q + 3) // 4 + T = N * tile_H * tile_W + + Cb = block_k + input_dtype = x_blocked.dtype + V = torch.empty((36, T, C_pad), device=x_blocked.device, dtype=input_dtype) + M = torch.empty((36, T, K_out), device=x_blocked.device, dtype=torch.float32) + + # 1. Cblocked input transform + def input_grid_f4(meta): + return (T, triton.cdiv(C_pad, meta["BLOCK_C"])) + + _winograd_f4x3_cblocked_input_transform_kernel[input_grid_f4]( + x_blocked, + V, + N, + C, + C_pad, + H, + W_in, + tile_H, + tile_W, + T, + ph, + pw, + Cb, + INPUT_DTYPE=_torch_dtype_to_tl(input_dtype), + ) + + def gemm_grid(meta): + BM = meta["BLOCK_M"] + BN = meta["BLOCK_N"] + return (triton.cdiv(T, BM) * triton.cdiv(K_out, BN), 36) + + _winograd_f4x3_batched_gemm_kernel[gemm_grid]( + V, + U, + M, + T, + K_out, + C_pad, + ) + + ACT_MAP = {"none": 0, "relu": 1, "relu6": 2, "gelu": 3} + bias_arg = bias_fp32 if bias_fp32 is not None else x_blocked.new_empty(1) + + def output_grid_f4(meta): + return (T, triton.cdiv(K_out, meta["BLOCK_K"])) + + _winograd_f4x3_output_transform_kernel[output_grid_f4]( + M, + bias_arg, + y, + N, + K_out, + P, + Q, + tile_H, + tile_W, + T, + HAS_BIAS=1 if bias_fp32 is not None else 0, + ACT_TYPE=ACT_MAP.get(activation, 0), + ) diff --git a/aiter/ops/triton/conv/_prepack.py b/aiter/ops/triton/conv/_prepack.py new file mode 100644 index 0000000000..d586034a56 --- /dev/null +++ b/aiter/ops/triton/conv/_prepack.py @@ -0,0 +1,197 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +from __future__ import annotations +import os +from collections import OrderedDict +from typing import Dict +import torch + +from aiter.ops.triton.conv._utils import BLOCK_K, _storage_ptr + +_PACK_CACHE_MAXSIZE = int(os.environ.get("AITER_TRITON_CONV_PACK_CACHE_SIZE", "256")) + + +class _LRUPackCache: + """Bounded LRU for weight prepacks. 
Stores (src_tensor, item) — the + strong ref to src keeps storage alive so the storage_ptr in the key + cannot be reused by a different tensor while this entry lives.""" + + def __init__(self, maxsize: int = _PACK_CACHE_MAXSIZE): + self._d: "OrderedDict[tuple, tuple]" = OrderedDict() + self._max = max(1, maxsize) + + def get(self, key): + entry = self._d.get(key) + if entry is None: + return None + self._d.move_to_end(key) + return entry + + def put(self, key, src, item): + self._d[key] = (src, item) + self._d.move_to_end(key) + while len(self._d) > self._max: + self._d.popitem(last=False) + + def clear(self): + self._d.clear() + + def __len__(self): + return len(self._d) + + +_PACK_CACHE = _LRUPackCache() +_PACK_CACHE_3x3 = _LRUPackCache() +# Input pack — kept as single-entry dict by design: in real inference each +# layer's input is a unique intermediate activation that won't be reused, +# and the bench clears this per-call to model per-batch repack cost. +_PACK_CACHE_CBLOCKED: Dict = {} +_PACK_CACHE_WINOGRAD_F4X3 = _LRUPackCache() + + +def prepack_oihw_to_kmajor(w_oihw: torch.Tensor, block_k: int = BLOCK_K): + K_out, C, R, S = w_oihw.shape + K_red = C * R * S + K_pad = ((K_red + block_k - 1) // block_k) * block_k + w_rs = w_oihw.reshape(K_out, K_red) + if K_pad != K_red: + pad = torch.zeros( + (K_out, K_pad - K_red), device=w_oihw.device, dtype=w_oihw.dtype + ) + w_rs = torch.cat([w_rs, pad], dim=1) + return w_rs.contiguous(), (K_out, K_pad) + + +def get_or_make_weight_pack(w_oihw: torch.Tensor, block_k: int = BLOCK_K): + key = ( + _storage_ptr(w_oihw), + tuple(w_oihw.shape), + w_oihw.dtype, + block_k, + int(getattr(w_oihw, "_version", 0)), + ) + entry = _PACK_CACHE.get(key) + if entry is not None: + return entry[1] + item = prepack_oihw_to_kmajor(w_oihw, block_k) + _PACK_CACHE.put(key, w_oihw, item) + return item + + +def prepack_oihw_to_3x3(w_oihw: torch.Tensor, block_c: int = BLOCK_K): + """Pack weights as [K_out, 9, C_pad] for 3x3 specialized kernel.""" + K_out, C, R, S = w_oihw.shape + assert R == 3 and S == 3 + C_pad = ((C + block_c - 1) // block_c) * block_c + w_rs = w_oihw.reshape(K_out, C, 9).permute(0, 2, 1).contiguous() # [K_out, 9, C] + if C_pad != C: + pad = torch.zeros( + (K_out, 9, C_pad - C), device=w_oihw.device, dtype=w_oihw.dtype + ) + w_rs = torch.cat([w_rs, pad], dim=2) + return w_rs.contiguous(), (K_out, C_pad) + + +def get_or_make_weight_pack_3x3(w_oihw: torch.Tensor, block_c: int = BLOCK_K): + key = ( + _storage_ptr(w_oihw), + tuple(w_oihw.shape), + w_oihw.dtype, + block_c, + int(getattr(w_oihw, "_version", 0)), + ) + cached = _PACK_CACHE_3x3.get(key) + if cached is not None: + return cached[1] + item = prepack_oihw_to_3x3(w_oihw, block_c) + _PACK_CACHE_3x3.put(key, w_oihw, item) + return item + + +def prepack_nchw_to_cblocked(x: torch.Tensor, block_c: int = BLOCK_K): + """Pack NCHW input into channel-blocked layout [N, C_blocks, H, W, Cb]. + + Within each block of Cb channels, data is contiguous (stride=1). 
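+
+    Example (illustrative shapes): x of shape (1, 96, 8, 8) with
+    block_c=64 is zero-padded to 128 channels and packed into x_blocked
+    of shape (1, 2, 8, 8, 64), i.e. C_blocks=2 and C_pad=128.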
+ """ + N, C, H, W = x.shape + Cb = block_c + C_blocks = (C + Cb - 1) // Cb + C_pad = C_blocks * Cb + + if C_pad != C: + x_padded = torch.zeros((N, C_pad, H, W), device=x.device, dtype=x.dtype) + x_padded[:, :C, :, :] = x + else: + x_padded = x + + x_blocked = ( + x_padded.reshape(N, C_blocks, Cb, H, W).permute(0, 1, 3, 4, 2).contiguous() + ) + return x_blocked, C_pad + + +def get_or_make_input_pack_cblocked(x: torch.Tensor, block_c: int = BLOCK_K): + key = ( + _storage_ptr(x), + tuple(x.shape), + x.dtype, + block_c, + int(getattr(x, "_version", 0)), + ) + cached = _PACK_CACHE_CBLOCKED.get(key) + if cached is not None: + src_ref, item = cached + if src_ref is not None and _storage_ptr(src_ref) == key[0]: + return item + item = prepack_nchw_to_cblocked(x, block_c) + _PACK_CACHE_CBLOCKED.clear() + _PACK_CACHE_CBLOCKED[key] = (x, item) + return item + + +def prepack_winograd_filter_f4x3(w_oihw: torch.Tensor, block_c: int = BLOCK_K): + """Transform 3x3 filters for Winograd F(4x4,3x3). G @ g @ G^T for each (k,c). + Input: [K_out, C, 3, 3] fp16. Output: [36, K_out, C_pad] fp16.""" + K_out, C, R, S = w_oihw.shape + assert R == 3 and S == 3 + C_pad = ((C + block_c - 1) // block_c) * block_c + # G matrix (6x3) + G = torch.tensor( + [ + [1.0 / 4, 0.0, 0.0], + [-1.0 / 6, -1.0 / 6, -1.0 / 6], + [-1.0 / 6, 1.0 / 6, -1.0 / 6], + [1.0 / 24, 1.0 / 12, 1.0 / 6], + [1.0 / 24, -1.0 / 12, 1.0 / 6], + [0.0, 0.0, 1.0], + ], + dtype=torch.float32, + device=w_oihw.device, + ) + + g = w_oihw.float() # [K_out, C, 3, 3] + u = torch.einsum("ij,kcjl,lm->kcim", G, g, G.t()) + u = u.reshape(K_out, C, 36).permute(2, 0, 1).contiguous() + if C_pad != C: + pad = torch.zeros( + (36, K_out, C_pad - C), device=w_oihw.device, dtype=torch.float32 + ) + u = torch.cat([u, pad], dim=2) + return u.to(w_oihw.dtype).contiguous(), (K_out, C_pad) + + +def get_or_make_winograd_filter_f4x3(w_oihw: torch.Tensor, block_c: int = BLOCK_K): + key = ( + _storage_ptr(w_oihw), + tuple(w_oihw.shape), + w_oihw.dtype, + block_c, + int(getattr(w_oihw, "_version", 0)), + ) + cached = _PACK_CACHE_WINOGRAD_F4X3.get(key) + if cached is not None: + return cached[1] + item = prepack_winograd_filter_f4x3(w_oihw, block_c) + _PACK_CACHE_WINOGRAD_F4X3.put(key, w_oihw, item) + return item diff --git a/aiter/ops/triton/conv/_utils.py b/aiter/ops/triton/conv/_utils.py new file mode 100644 index 0000000000..9d0850e8cf --- /dev/null +++ b/aiter/ops/triton/conv/_utils.py @@ -0,0 +1,88 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +from __future__ import annotations +import torch +import torch.nn.functional as F + +# Channel padding granularity for prepacked weights/inputs. Must align with the +# BLOCK_K autotune candidates in _triton_kernels/conv/helpers.py — change with care. +BLOCK_K = 64 + + +def dynamic_conv_tolerances(dtype: torch.dtype, K_red: int, ref: torch.Tensor): + eps = { + torch.float16: 2**-10, + torch.bfloat16: 2**-7, + torch.float32: 2**-23, + }.get(dtype, 2**-10) + rtol = 6e-3 if K_red < 1024 else (8e-3 if K_red < 4096 else 1.2e-2) + # Error model: fp16 inputs multiplied pairwise have eps relative error per product. + # Accumulated in fp32 over K_red terms, max absolute error grows as ~eps * sqrt(K_red). + # The 10x multiplier covers worst-case accumulation ordering differences + # between our Triton kernels and PyTorch reference. 
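+    # Worked instance (illustrative): fp16 (eps = 2**-10) with K_red = 576
+    # gives atol = max(8/1024, 10 * 24/1024) ~= max(0.008, 0.234) = 0.234.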
+ atol = max(eps * 8, 10.0 * eps * (K_red**0.5)) + return rtol, atol + + +def flops_conv(N, C, K_out, R, S, P, Q): + return 2.0 * N * P * Q * K_out * C * R * S + + +def _out_hw(H, W, R, S, stride, padding, dilation): + sh, sw = stride + ph, pw = padding + dh, dw = dilation + P = (H + 2 * ph - dh * (R - 1) - 1) // sh + 1 + Q = (W + 2 * pw - dw * (S - 1) - 1) // sw + 1 + return P, Q + + +def _storage_ptr(t: torch.Tensor) -> int: + return ( + t.untyped_storage().data_ptr() + if hasattr(t, "untyped_storage") + else t.storage().data_ptr() + ) + + +def _is_1x1_conv(R, S, dilation): + """Check if this is a 1x1 convolution (no spatial reduction in kernel).""" + return R == 1 and S == 1 and dilation == (1, 1) + + +def _is_3x3_conv(R, S): + """Check if this is a 3x3 convolution.""" + return R == 3 and S == 3 + + +def _is_winograd_eligible(R, S, stride, dilation, C=None): + if not (R == 3 and S == 3 and stride == (1, 1) and dilation == (1, 1)): + return False + # F(4,3) output transform amplifies bf16 rounding by up to 361x (AT row3 L1=19). + # With very few input channels the tolerance budget is too small to absorb this. + if C is not None and C < 4: + return False + return True + + +def _winograd_tolerances(dtype, K_red, ref, variant="f4x3"): + """Return (rtol, atol) for Winograd F(4x4,3x3) correctness checks. + Winograd transforms amplify fp16 rounding errors: + - F(4x4,3x3): coefficients up to ±8, significant amplification + """ + rtol, atol = dynamic_conv_tolerances(dtype, K_red, ref) + if variant == "f4x3": + rtol *= 6.0 + atol = max(atol * 6.0, 0.6) + return rtol, atol + + +def apply_activation(y: torch.Tensor, activation: str): + if activation == "relu": + return F.relu(y) + if activation == "relu6": + return torch.clamp(y, 0, 6) + if activation == "gelu": + return F.gelu(y, approximate="tanh") + return y diff --git a/aiter/ops/triton/conv/conv2d.py b/aiter/ops/triton/conv/conv2d.py new file mode 100644 index 0000000000..afda01cbf9 --- /dev/null +++ b/aiter/ops/triton/conv/conv2d.py @@ -0,0 +1,640 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +from __future__ import annotations +from typing import Optional +import torch + +from aiter.ops.triton.utils.logger import AiterTritonLogger +from aiter.ops.triton.conv._utils import ( + BLOCK_K, + _out_hw, + _is_1x1_conv, + _is_3x3_conv, + _is_winograd_eligible, +) +from aiter.ops.triton.conv._prepack import ( + get_or_make_weight_pack, + get_or_make_weight_pack_3x3, + get_or_make_input_pack_cblocked, + get_or_make_winograd_filter_f4x3, +) +from aiter.ops.triton.conv._launch import ( + _launch_1x1, + _launch_3x3_nhwc, + _launch_3x3_cblocked, + _launch_general, + _launch_winograd_f4x3, + _launch_winograd_f4x3_cblocked, + _launch_winograd_f4x3_fused, + _select_3x3_method, +) + +_LOGGER = AiterTritonLogger() + +# Tracks the last Triton kernel selected by conv2d_nchw smart routing. +_last_triton_kernel: Optional[str] = None + + +def conv2d( + x, + w_oihw, + bias=None, + stride=(1, 1), + padding=(0, 0), + dilation=(1, 1), + activation="none", + out_dtype: Optional[torch.dtype] = None, + layout="nchw", +): + """Forward 2-D conv on AMD ROCm via Triton. Drop-in for the forward of + ``torch.nn.functional.conv2d`` (no backward). + + A shape-driven router picks among five kernel families (1x1, 3x3 cblocked, + 3x3 NHWC, Winograd F(4x4,3x3), general) per call. + + Inputs must be fp16 or bf16. ``layout="nhwc"`` runs an NHWC-native kernel + with no internal layout conversion. 
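+
+    Example (illustrative shapes)::
+
+        x = torch.randn(2, 64, 56, 56, device="cuda", dtype=torch.float16)
+        w = torch.randn(128, 64, 3, 3, device="cuda", dtype=torch.float16)
+        y = conv2d(x, w, padding=(1, 1), activation="relu")
+        # the router picks a 3x3 kernel family; y.shape == (2, 128, 56, 56)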
+ + ``out_dtype=None`` (default) returns output in the input dtype, matching + ``torch.nn.Conv2d`` semantics. + + Notes + ----- + - Only ``groups=1`` (depthwise/grouped raises ``AssertionError``). + - Only ``padding_mode="zeros"`` (no reflect/replicate/circular). + - ``bias=None`` skips the with-bias kernel path; passing a zero tensor + instead routes through the with-bias kernel and times differently. + """ + if x.dtype not in (torch.float16, torch.bfloat16): + raise ValueError(f"conv2d only supports fp16 and bf16 inputs, got {x.dtype}") + if out_dtype is None: + out_dtype = x.dtype + elif out_dtype not in (torch.float16, torch.bfloat16): + raise ValueError( + f"out_dtype must be torch.float16 or torch.bfloat16, got {out_dtype}" + ) + layout = layout.lower() + if layout not in ("nchw", "nhwc"): + raise ValueError(f"layout must be 'nchw' or 'nhwc', got '{layout}'") + + _LOGGER.info( + f"CONV2D: x={tuple(x.shape)} w={tuple(w_oihw.shape)} stride={stride} " + f"padding={padding} dilation={dilation} layout={layout} " + f"dtype={x.dtype} out_dtype={out_dtype} bias={'yes' if bias is not None else 'no'} " + f"act={activation}" + ) + + if layout == "nhwc": + return conv2d_nhwc( + x, w_oihw, bias, stride, padding, dilation, activation, out_dtype + ) + else: + return conv2d_nchw( + x, w_oihw, bias, stride, padding, dilation, activation, out_dtype + ) + + +def conv2d_winograd_f4x3( + x, + w_oihw, + bias=None, + stride=(1, 1), + padding=(0, 0), + dilation=(1, 1), + activation="none", + out_dtype: Optional[torch.dtype] = None, + block_k=BLOCK_K, + layout="nchw", +): + """NCHW/NHWC conv2d using Winograd F(4x4,3x3). Raises ValueError for non-eligible convs.""" + assert x.is_cuda and w_oihw.is_cuda + if out_dtype is None: + out_dtype = x.dtype + N, C, H, W_in = x.shape + K_out, Cw, R, S = w_oihw.shape + assert Cw == C + P, Q = _out_hw(H, W_in, R, S, stride, padding, dilation) + + if not _is_winograd_eligible(R, S, stride, dilation, C): + raise ValueError( + f"conv2d_winograd_f4x3 requires 3x3 kernel with stride=1, dilation=1, " + f"and C >= 4 (F(4,3) output transform amplifies rounding by up to " + f"361x; C<4 has too few reduction terms to absorb it), " + f"got {R}x{S} stride={stride} dilation={dilation} C={C}" + ) + + if layout == "nhwc": + y = torch.empty((N, K_out, P, Q), device=x.device, dtype=out_dtype).to( + memory_format=torch.channels_last + ) + else: + y = torch.empty((N, K_out, P, Q), device=x.device, dtype=out_dtype) + bias_fp32 = bias.float().contiguous() if bias is not None else None + U, (_, C_pad) = get_or_make_winograd_filter_f4x3(w_oihw.contiguous(), block_k) + _launch_winograd_f4x3( + x, + U, + bias_fp32, + y, + N, + C, + H, + W_in, + K_out, + P, + Q, + C_pad, + padding, + activation, + layout=layout, + ) + return y + + +def conv2d_winograd_f4x3_cblocked( + x, + w_oihw, + bias=None, + stride=(1, 1), + padding=(0, 0), + dilation=(1, 1), + activation="none", + out_dtype: Optional[torch.dtype] = None, + block_k=BLOCK_K, +): + """NCHW conv2d using Winograd F(4x4,3x3) with NCHWc input layout for coalesced loads. 
+ Raises ValueError for non-eligible convs.""" + assert x.is_cuda and w_oihw.is_cuda + if out_dtype is None: + out_dtype = x.dtype + N, C, H, W_in = x.shape + K_out, Cw, R, S = w_oihw.shape + assert Cw == C + P, Q = _out_hw(H, W_in, R, S, stride, padding, dilation) + + if not _is_winograd_eligible(R, S, stride, dilation, C): + raise ValueError( + f"conv2d_winograd_f4x3_cblocked requires 3x3 kernel with stride=1, dilation=1, " + f"and C >= 4 (F(4,3) output transform amplifies rounding by up to " + f"361x; C<4 has too few reduction terms to absorb it), " + f"got {R}x{S} stride={stride} dilation={dilation} C={C}" + ) + + y = torch.empty((N, K_out, P, Q), device=x.device, dtype=out_dtype) + bias_fp32 = bias.float().contiguous() if bias is not None else None + U, (_, C_pad) = get_or_make_winograd_filter_f4x3(w_oihw.contiguous(), block_k) + x_blocked, C_pad_blocked = get_or_make_input_pack_cblocked(x, block_k) + _launch_winograd_f4x3_cblocked( + x_blocked, + C_pad_blocked, + U, + bias_fp32, + y, + N, + C, + H, + W_in, + K_out, + P, + Q, + C_pad, + padding, + activation, + block_k, + ) + return y + + +def conv2d_winograd_f4x3_fused( + x, + w_oihw, + bias=None, + stride=(1, 1), + padding=(0, 0), + dilation=(1, 1), + activation="none", + out_dtype: Optional[torch.dtype] = None, + block_k=BLOCK_K, +): + """NCHW conv2d using Winograd F(4x4,3x3) with fused GEMM+output transform. + Raises ValueError for non-eligible convs.""" + assert x.is_cuda and w_oihw.is_cuda + if out_dtype is None: + out_dtype = x.dtype + N, C, H, W_in = x.shape + K_out, Cw, R, S = w_oihw.shape + assert Cw == C + P, Q = _out_hw(H, W_in, R, S, stride, padding, dilation) + + if not _is_winograd_eligible(R, S, stride, dilation, C): + raise ValueError( + f"conv2d_winograd_f4x3_fused requires 3x3 kernel with stride=1, dilation=1, " + f"and C >= 4 (F(4,3) output transform amplifies rounding by up to " + f"361x; C<4 has too few reduction terms to absorb it), " + f"got {R}x{S} stride={stride} dilation={dilation} C={C}" + ) + + y = torch.empty((N, K_out, P, Q), device=x.device, dtype=out_dtype) + bias_fp32 = bias.float().contiguous() if bias is not None else None + U, (_, C_pad) = get_or_make_winograd_filter_f4x3(w_oihw.contiguous(), block_k) + _launch_winograd_f4x3_fused( + x, + U, + bias_fp32, + y, + N, + C, + H, + W_in, + K_out, + P, + Q, + C_pad, + padding, + activation, + ) + return y + + +def conv2d_1x1( + x, + w_oihw, + bias=None, + stride=(1, 1), + padding=(0, 0), + dilation=(1, 1), + activation="none", + out_dtype: Optional[torch.dtype] = None, + block_k=BLOCK_K, + layout="nchw", +): + """NCHW/NHWC conv2d for 1x1 kernels. 
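+
+    A 1x1 conv is a per-output-pixel matmul over channels, with the weight
+    viewed as (K_out, C). Illustrative call: ``y = conv2d_1x1(x, w)`` with
+    ``w`` of shape (K_out, C, 1, 1).
+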
Raises ValueError for non-1x1.""" + assert x.is_cuda and w_oihw.is_cuda + if out_dtype is None: + out_dtype = x.dtype + N, C, H, W_in = x.shape + K_out, Cw, R, S = w_oihw.shape + assert Cw == C + if not _is_1x1_conv(R, S, dilation): + raise ValueError(f"conv2d_1x1 requires 1x1 kernel, got {R}x{S}") + P, Q = _out_hw(H, W_in, R, S, stride, padding, dilation) + + if layout == "nhwc": + y = torch.empty((N, K_out, P, Q), device=x.device, dtype=out_dtype).to( + memory_format=torch.channels_last + ) + else: + y = torch.empty((N, K_out, P, Q), device=x.device, dtype=out_dtype) + bias_fp32 = bias.float().contiguous() if bias is not None else None + _launch_1x1( + x, + w_oihw.contiguous(), + bias_fp32, + y, + N, + C, + H, + W_in, + K_out, + P, + Q, + stride, + padding, + activation, + layout=layout, + ) + return y + + +def conv2d_general( + x, + w_oihw, + bias=None, + stride=(1, 1), + padding=(0, 0), + dilation=(1, 1), + activation="none", + out_dtype: Optional[torch.dtype] = None, + block_k=BLOCK_K, + layout="nchw", +): + """NCHW/NHWC conv2d using general kernel with prepacked weights (5x5, 7x7, etc.).""" + assert x.is_cuda and w_oihw.is_cuda + if out_dtype is None: + out_dtype = x.dtype + N, C, H, W_in = x.shape + K_out, Cw, R, S = w_oihw.shape + assert Cw == C + P, Q = _out_hw(H, W_in, R, S, stride, padding, dilation) + + if layout == "nhwc": + y = torch.empty((N, K_out, P, Q), device=x.device, dtype=out_dtype).to( + memory_format=torch.channels_last + ) + else: + y = torch.empty((N, K_out, P, Q), device=x.device, dtype=out_dtype) + bias_fp32 = bias.float().contiguous() if bias is not None else None + w_k, (_, K_pad) = get_or_make_weight_pack(w_oihw.contiguous(), block_k) + _launch_general( + x, + w_k, + bias_fp32, + y, + N, + C, + H, + W_in, + K_out, + R, + S, + P, + Q, + K_pad, + stride, + padding, + dilation, + block_k, + activation, + layout=layout, + ) + return y + + +def conv2d_nhwc_3x3( + x, + w_oihw, + bias=None, + stride=(1, 1), + padding=(0, 0), + dilation=(1, 1), + activation="none", + out_dtype: Optional[torch.dtype] = None, + block_k=BLOCK_K, +): + """NHWC conv2d for 3x3 kernels. 
Raises ValueError for non-3x3.""" + assert x.is_cuda and w_oihw.is_cuda + if out_dtype is None: + out_dtype = x.dtype + N, C, H, W_in = x.shape + K_out, Cw, R, S = w_oihw.shape + assert Cw == C + if not _is_3x3_conv(R, S): + raise ValueError(f"conv2d_nhwc_3x3 requires 3x3 kernel, got {R}x{S}") + P, Q = _out_hw(H, W_in, R, S, stride, padding, dilation) + + y = torch.empty((N, K_out, P, Q), device=x.device, dtype=out_dtype).to( + memory_format=torch.channels_last + ) + bias_fp32 = bias.float().contiguous() if bias is not None else None + w_3x3, (_, C_pad) = get_or_make_weight_pack_3x3(w_oihw.contiguous(), block_k) + _launch_3x3_nhwc( + x, + w_3x3, + bias_fp32, + y, + N, + C, + H, + W_in, + K_out, + P, + Q, + C_pad, + stride, + padding, + dilation, + activation, + ) + return y + + +def conv2d_nchw( + x, + w_oihw, + bias=None, + stride=(1, 1), + padding=(0, 0), + dilation=(1, 1), + activation="none", + out_dtype: Optional[torch.dtype] = None, + block_k=BLOCK_K, +): + """Hybrid NCHW conv2d: routes to specialized 1x1, 3x3, or general kernel.""" + assert x.is_cuda and w_oihw.is_cuda + if out_dtype is None: + out_dtype = x.dtype + N, C, H, W_in = x.shape + K_out, Cw, R, S = w_oihw.shape + assert Cw == C + + global _last_triton_kernel + if _is_1x1_conv(R, S, dilation): + _last_triton_kernel = "_conv2d_1x1_kernel" + return conv2d_1x1( + x, + w_oihw, + bias, + stride, + padding, + dilation, + activation, + out_dtype, + block_k, + layout="nchw", + ) + elif _is_3x3_conv(R, S): + method = _select_3x3_method(N, C, H, W_in, K_out, stride, dilation) + if method == "winograd_f4x3_cblocked": + _last_triton_kernel = "_winograd_f4x3_cblocked_* (3 kern)" + return conv2d_winograd_f4x3_cblocked( + x, + w_oihw, + bias, + stride, + padding, + dilation, + activation, + out_dtype, + block_k, + ) + elif method == "winograd_f4x3": + _last_triton_kernel = "_winograd_f4x3_* (3 kernels)" + return conv2d_winograd_f4x3( + x, + w_oihw, + bias, + stride, + padding, + dilation, + activation, + out_dtype, + block_k, + ) + else: + _last_triton_kernel = "_conv2d_3x3_cblocked_kernel" + return conv2d_nchw_cblocked( + x, + w_oihw, + bias, + stride, + padding, + dilation, + activation, + out_dtype, + block_k, + ) + else: + _last_triton_kernel = "_conv2d_general_kernel" + return conv2d_general( + x, + w_oihw, + bias, + stride, + padding, + dilation, + activation, + out_dtype, + block_k, + layout="nchw", + ) + + +def conv2d_nhwc( + x, + w_oihw, + bias=None, + stride=(1, 1), + padding=(0, 0), + dilation=(1, 1), + activation="none", + out_dtype: Optional[torch.dtype] = None, + block_k=BLOCK_K, +): + """Conv2d with NHWC (channels-last) input and output. + + Input x can be NCHW or NHWC — it will be converted to channels_last. + Output y is allocated as channels_last (NHWC-contiguous) and returned + in logical NCHW shape with channels_last strides. 
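+
+    For example (illustrative), a result ``y`` of logical shape (N, K, P, Q)
+    satisfies ``y.is_contiguous(memory_format=torch.channels_last)`` and has
+    strides (K*P*Q, 1, Q*K, K), so the channel axis is fastest-varying in
+    memory.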
+ """ + assert x.is_cuda and w_oihw.is_cuda + if out_dtype is None: + out_dtype = x.dtype + x = x.to(memory_format=torch.channels_last) + N, C, H, W_in = x.shape + K_out, Cw, R, S = w_oihw.shape + assert Cw == C + + global _last_triton_kernel + if _is_1x1_conv(R, S, dilation): + _last_triton_kernel = "_conv2d_1x1_kernel" + return conv2d_1x1( + x, + w_oihw, + bias, + stride, + padding, + dilation, + activation, + out_dtype, + block_k, + layout="nhwc", + ) + elif _is_3x3_conv(R, S): + method = _select_3x3_method(N, C, H, W_in, K_out, stride, dilation) + if method in ("winograd_f4x3", "winograd_f4x3_cblocked"): + _last_triton_kernel = "_winograd_f4x3_* (3 kernels)" + return conv2d_winograd_f4x3( + x, + w_oihw, + bias, + stride, + padding, + dilation, + activation, + out_dtype, + block_k, + layout="nhwc", + ) + else: + _last_triton_kernel = "_conv2d_3x3_nhwc_kernel" + return conv2d_nhwc_3x3( + x, + w_oihw, + bias, + stride, + padding, + dilation, + activation, + out_dtype, + block_k, + ) + else: + _last_triton_kernel = "_conv2d_general_kernel" + return conv2d_general( + x, + w_oihw, + bias, + stride, + padding, + dilation, + activation, + out_dtype, + block_k, + layout="nhwc", + ) + + +def conv2d_nchw_cblocked( + x, + w_oihw, + bias=None, + stride=(1, 1), + padding=(0, 0), + dilation=(1, 1), + activation="none", + out_dtype: Optional[torch.dtype] = None, + block_k=BLOCK_K, +): + """NCHW conv2d with channel-blocked input packing for 3x3 kernels. + Raises ValueError for non-3x3.""" + assert x.is_cuda and w_oihw.is_cuda + if out_dtype is None: + out_dtype = x.dtype + N, C, H, W_in = x.shape + K_out, Cw, R, S = w_oihw.shape + assert Cw == C + P, Q = _out_hw(H, W_in, R, S, stride, padding, dilation) + + if not _is_3x3_conv(R, S): + raise ValueError(f"conv2d_nchw_cblocked requires 3x3 kernel, got {R}x{S}") + + y = torch.empty((N, K_out, P, Q), device=x.device, dtype=out_dtype) + bias_fp32 = bias.float().contiguous() if bias is not None else None + w_3x3, (_, C_pad) = get_or_make_weight_pack_3x3(w_oihw.contiguous(), block_k) + Cb = block_k # packing block size matches weight padding block + x_blocked, C_pad_x = get_or_make_input_pack_cblocked(x, Cb) + # Ensure channel padding is consistent + assert ( + C_pad_x == C_pad + ), f"Channel padding mismatch: input {C_pad_x} vs weight {C_pad}" + _launch_3x3_cblocked( + x_blocked, + w_3x3, + bias_fp32, + y, + N, + C, + H, + W_in, + K_out, + P, + Q, + C_pad, + Cb, + stride, + padding, + dilation, + activation, + ) + return y diff --git a/op_tests/op_benchmarks/triton/bench_conv2d.py b/op_tests/op_benchmarks/triton/bench_conv2d.py new file mode 100644 index 0000000000..10cd8254ef --- /dev/null +++ b/op_tests/op_benchmarks/triton/bench_conv2d.py @@ -0,0 +1,889 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +"""Benchmark aiter.ops.triton.conv.conv2d. + +Two modes: +- Single-shape (pass --N --C --H --W --K --R --S [--stride ...] etc.) + Bench one shape and emit a single key=value result line. Useful for + ad-hoc one-off measurements or scripting around a specific shape. +- Sweep (no --N). Iterates either the built-in default shape list or the + conv2d shapes for a model in model_shapes.json (--model NAME), and + prints three box-drawn tables at the end: + 1. LAYER-BY-LAYER BENCHMARK (per-layer Tri vs Torch + correctness) + 2. MIOpen SOLVER SUMMARY (only when --miopen-solvers is passed) + 3. 
OVERALL PERFORMANCE (mean/median/aggregate TFLOPS, layer wins) + +Each shape is timed with triton.testing.do_bench against +torch.nn.functional.conv2d as the reference backend (MIOpen on AMD), +and a correctness check compares the Triton output against F.conv2d +within the same tolerance model the test suite uses. + +The selected Triton kernel name is captured for every shape. The MIOpen +solver name is captured only when --miopen-solvers is passed, because +detection requires a separate subprocess with MIOPEN_LOG_LEVEL=6 (~60s +fixed startup cost). + +For NCHW non-1x1 shapes, kernel+repack timing is also captured: the input +prepack cache (_PACK_CACHE_CBLOCKED) is cleared before each timing call, +giving the steady-state inference cost when input layout changes per call. + +No model loading at runtime — model shapes come from the pre-extracted +model_shapes.json (see extract_conv_shapes.py for how to regenerate it). +""" + +from __future__ import annotations + +import argparse +import json +import os +import re +import statistics +import subprocess +import sys +from typing import Optional + +import torch +import torch.nn.functional as F +import triton + +import aiter.ops.triton.conv.conv2d as _ops_module +from aiter.ops.triton.conv._utils import ( + flops_conv, + _out_hw, + _is_1x1_conv, + _is_3x3_conv, + dynamic_conv_tolerances, + _winograd_tolerances, +) +from aiter.ops.triton.conv._prepack import _PACK_CACHE_CBLOCKED +from aiter.ops.triton.conv.conv2d import ( + conv2d, + conv2d_nchw, + conv2d_nchw_cblocked, + conv2d_nhwc, + conv2d_winograd_f4x3, + conv2d_winograd_f4x3_fused, + conv2d_winograd_f4x3_cblocked, +) + +METHODS = { + "auto": conv2d, + "default": conv2d_nchw, + "cblocked": conv2d_nchw_cblocked, + "nhwc": conv2d_nhwc, + "winograd_f4x3": conv2d_winograd_f4x3, + "winograd_f4x3_fused": conv2d_winograd_f4x3_fused, + "winograd_f4x3_cblocked": conv2d_winograd_f4x3_cblocked, +} + + +# Default sweep shapes — same set as the test edge cases. Kept in sync by hand +# (12 tuples; trivial duplication, see _helpers.get_edge_case_shapes). +DEFAULT_SHAPES = [ + # (N, C, H, W, K, R, S, stride, padding, dilation, desc) + (1, 3, 7, 7, 8, 3, 3, (1, 1), (1, 1), (1, 1), "3x3 same padding"), + (1, 3, 8, 8, 16, 1, 1, (1, 1), (0, 0), (1, 1), "1x1 stride1"), + (2, 16, 32, 32, 32, 3, 3, (2, 2), (1, 1), (1, 1), "stride2"), + (2, 32, 17, 23, 64, 5, 5, (2, 2), (2, 2), (1, 1), "odd dims + pad"), + (4, 64, 28, 28, 128, 3, 3, (1, 1), (0, 0), (2, 2), "dilation2"), + (2, 512, 7, 7, 1024, 1, 1, (1, 1), (0, 0), (1, 1), "1x1 large channels"), + (1, 3, 112, 112, 64, 7, 7, (2, 2), (3, 3), (1, 1), "7x7 large spatial"), + (1, 1, 16, 16, 16, 3, 3, (1, 1), (1, 1), (1, 1), "single input channel"), + (2, 64, 8, 8, 64, 3, 3, (1, 1), (1, 1), (1, 1), "small spatial 3x3"), + (1, 128, 4, 4, 256, 1, 1, (1, 1), (0, 0), (1, 1), "1x1 tiny spatial"), + (2, 32, 32, 32, 32, 3, 3, (1, 1), (0, 0), (1, 1), "3x3 no padding"), + (2, 64, 28, 28, 128, 3, 3, (2, 2), (1, 1), (1, 1), "3x3 stride2 standard"), +] + + +# MIOpen solver names → human-readable algorithm types (matches old suite.py). 
+MIOPEN_ALGO_MAP = { + "ConvWinoFuryRxS<2-3>": "Winograd Fury F(2,3)", + "ConvBinWinogradRxSf3x2": "Winograd F(3x3,2x2) binary", + "GemmFwd1x1_0_1": "GEMM (no workspace)", + "GemmFwdRest": "GEMM fallback", +} + + +# ---------------------------------------------------------------------------- +# Helpers +# ---------------------------------------------------------------------------- + + +def _torch_dtype(s: str) -> torch.dtype: + if s == "fp16": + return torch.float16 + if s == "bf16": + return torch.bfloat16 + raise ValueError(f"unsupported dtype: {s}") + + +def _check_close(got, ref, dtype, K_red, is_winograd: bool) -> bool: + """Tolerance-aware correctness check (same model as the pytest suite).""" + if is_winograd: + rtol, atol = _winograd_tolerances(dtype, K_red, ref, "f4x3") + else: + rtol, atol = dynamic_conv_tolerances(dtype, K_red, ref) + try: + torch.testing.assert_close(got.float(), ref.float(), rtol=rtol, atol=atol) + return True + except AssertionError: + return False + + +def _kernel_type_tag(R: int, S: int, dilation: tuple) -> str: + if _is_1x1_conv(R, S, dilation): + return "[1x1]" + if _is_3x3_conv(R, S): + return "[3x3]" + return "[general]" + + +def _shape_str(N, C, H, W, K, R, S) -> str: + return f"({N},{C},{H},{W})→{K}/{R}x{S}" + + +# ---------------------------------------------------------------------------- +# MIOpen solver detection (subprocess-based, opt-in via --miopen-solvers) +# ---------------------------------------------------------------------------- + + +_miopen_solver_cache: dict = {} + + +def precompute_miopen_solvers(shapes, dtype: torch.dtype) -> None: + """Detect MIOpen solver per shape via a single subprocess. + + Spawns Python with MIOPEN_LOG_LEVEL=6, runs F.conv2d for each shape, + and parses stderr for "Chosen Algorithm:" lines. SHAPE_DONE markers + on stderr disambiguate which "Chosen Algorithm" line belongs to which + shape (positional alignment is unreliable when MIOpen logs vary). + + Cache populated as a side effect. Use _get_miopen_solver to read. 
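+
+    Illustrative stderr fragment being parsed (the exact MIOpen log text
+    varies by version)::
+
+        ... Chosen Algorithm: GemmFwd1x1_0_1 ...
+        SHAPE_DONE:0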
+ """ + global _miopen_solver_cache + + unique = [] + seen = set() + for entry in shapes: + N, C, H, W, K, R, S, stride, padding, dilation = entry[:10] + s_h, s_w = stride if isinstance(stride, tuple) else (stride, stride) + p_h, p_w = padding if isinstance(padding, tuple) else (padding, padding) + d_h, d_w = dilation if isinstance(dilation, tuple) else (dilation, dilation) + key = (N, C, H, W, K, R, S, s_h, s_w, p_h, p_w, d_h, d_w) + if key not in seen: + seen.add(key) + unique.append(key) + if not unique: + return + + dtype_str = { + torch.float16: "torch.float16", + torch.bfloat16: "torch.bfloat16", + }.get(dtype, "torch.float16") + + lines = [ + "import os, sys", + "os.environ['MIOPEN_LOG_LEVEL']='6'", + "import torch, torch.nn.functional as F", + ] + for i, (N, C, H, W, K, R, S, s_h, s_w, p_h, p_w, d_h, d_w) in enumerate(unique): + lines.append(f"# shape {i}") + lines.append(f"x=torch.randn({N},{C},{H},{W},device='cuda',dtype={dtype_str})") + lines.append(f"w=torch.randn({K},{C},{R},{S},device='cuda',dtype={dtype_str})") + lines.append( + f"F.conv2d(x,w,None,stride=({s_h},{s_w}),padding=({p_h},{p_w}),dilation=({d_h},{d_w}))" + ) + lines.append("torch.cuda.synchronize()") + lines.append(f"sys.stderr.write('SHAPE_DONE:{i}\\n');sys.stderr.flush()") + script = "\n".join(lines) + + try: + result = subprocess.run( + [sys.executable, "-c", script], + capture_output=True, + text=True, + timeout=120, + env={**os.environ, "MIOPEN_LOG_LEVEL": "6"}, + ) + except subprocess.TimeoutExpired: + print( + f"[miopen-detect] WARNING: subprocess timed out after 120s; " + f"MIOpen solver column will be empty for {len(unique)} shape(s).", + file=sys.stderr, + ) + return + except Exception as e: + print( + f"[miopen-detect] WARNING: subprocess failed ({e!r}); " + f"MIOpen solver column will be empty.", + file=sys.stderr, + ) + return + + if result.returncode != 0: + tail = "\n".join(result.stderr.strip().split("\n")[-5:]) + print( + f"[miopen-detect] WARNING: subprocess exited with code " + f"{result.returncode}; MIOpen solver column will be empty.\n" + f" Last stderr lines:\n{tail}", + file=sys.stderr, + ) + return + + pending: Optional[str] = None + attributed: dict = {} + orphan = 0 + shape_done_re = re.compile(r"^SHAPE_DONE:(\d+)\s*$") + chosen_re = re.compile(r"Chosen Algorithm:\s*(\S+)") + for line in result.stderr.split("\n"): + m = chosen_re.search(line) + if m: + pending = m.group(1).strip(" ,") + continue + m = shape_done_re.match(line) + if m: + idx = int(m.group(1)) + if pending is not None: + attributed[idx] = pending + else: + orphan += 1 + pending = None + for idx, solver in attributed.items(): + _miopen_solver_cache[unique[idx]] = solver + + missing = len(unique) - len(attributed) + if missing > 0: + print( + f"[miopen-detect] WARNING: {missing}/{len(unique)} shape(s) have no " + f"MIOpen solver detected ({orphan} marker(s) had no preceding " + f"'Chosen Algorithm' line). 
Common causes: MIOpen log format changed, " + f"MIOPEN_LOG_LEVEL was overridden, or the shape failed in the subprocess.", + file=sys.stderr, + ) + + +def _get_miopen_solver(N, C, H, W, K, R, S, stride, padding, dilation) -> str: + s_h, s_w = stride if isinstance(stride, tuple) else (stride, stride) + p_h, p_w = padding if isinstance(padding, tuple) else (padding, padding) + d_h, d_w = dilation if isinstance(dilation, tuple) else (dilation, dilation) + return _miopen_solver_cache.get( + (N, C, H, W, K, R, S, s_h, s_w, p_h, p_w, d_h, d_w), "" + ) + + +# ---------------------------------------------------------------------------- +# Per-shape bench (returns rich dict) +# ---------------------------------------------------------------------------- + + +def bench_one_shape( + N: int, + C: int, + H: int, + W: int, + K: int, + R: int, + S: int, + stride: tuple, + padding: tuple, + dilation: tuple, + dtype: torch.dtype, + method: str, + layout: str, + bias: bool = True, + measure_repack: bool = True, +) -> dict: + """Time + correctness-check one shape. Returns a dict with full metadata. + + Keys: ms_tri, ms_torch, ms_tri_e2e (or None), tflops_tri, tflops_torch, + tflops_tri_e2e (or None), correct, kernel_name, has_repack, flops. + + measure_repack: if True (default), additionally times the kernel+input-repack + path for NCHW non-1x1 shapes by clearing _PACK_CACHE_CBLOCKED between calls. + """ + if not torch.cuda.is_available(): + raise RuntimeError("CUDA not available; conv2d bench requires a GPU.") + + P, Q = _out_hw(H, W, R, S, stride, padding, dilation) + if P < 1 or Q < 1: + raise ValueError( + f"output spatial dims < 1 for shape " + f"N={N} C={C} H={H} W={W} K={K} R={R} S={S} " + f"stride={stride} padding={padding} dilation={dilation}" + ) + + device = "cuda" + x = torch.randn((N, C, H, W), device=device, dtype=dtype) + w = torch.randn((K, C, R, S), device=device, dtype=dtype) + b = torch.randn((K,), device=device, dtype=dtype) if bias else None + + if layout == "nhwc": + x_in = x.to(memory_format=torch.channels_last) + kernel_fn = conv2d_nhwc + else: + x_in = x + if method not in METHODS: + raise ValueError(f"unknown method: {method}; choices: {list(METHODS)}") + kernel_fn = METHODS[method] + + def run_triton(): + return kernel_fn( + x_in, + w, + b, + stride, + padding, + dilation, + activation="none", + out_dtype=dtype, + ) + + def run_torch(): + return F.conv2d(x_in, w, b, stride=stride, padding=padding, dilation=dilation) + + # One run: captures the output (for correctness) AND _last_triton_kernel. + y_tri = run_triton() + torch.cuda.synchronize() + kernel_name = getattr(_ops_module, "_last_triton_kernel", "") or "" + is_winograd = "winograd" in kernel_name.lower() or "wino" in kernel_name.lower() + + y_ref = run_torch() + correct = _check_close( + y_tri, y_ref, dtype, K_red=C * R * S, is_winograd=is_winograd + ) + + ms_tri = triton.testing.do_bench(run_triton, warmup=15, rep=50) + ms_th = triton.testing.do_bench(run_torch, warmup=15, rep=50) + + # Kernel+repack timing: clear input pack cache between calls. Only NCHW + # non-1x1 — 1x1 takes raw weights (no repacking) and NHWC has its own + # path that doesn't use _PACK_CACHE_CBLOCKED. 
+ has_repack = ( + measure_repack and layout != "nhwc" and not _is_1x1_conv(R, S, dilation) + ) + if has_repack: + + def run_triton_e2e(): + _PACK_CACHE_CBLOCKED.clear() + return run_triton() + + ms_tri_e2e = triton.testing.do_bench(run_triton_e2e, warmup=15, rep=50) + else: + ms_tri_e2e = None + + flops = flops_conv(N, C, K, R, S, P, Q) + tflops_tri = flops / (ms_tri * 1e-3) / 1e12 + tflops_th = flops / (ms_th * 1e-3) / 1e12 + tflops_tri_e2e = flops / (ms_tri_e2e * 1e-3) / 1e12 if ms_tri_e2e else None + + return { + "ms_tri": ms_tri, + "ms_torch": ms_th, + "ms_tri_e2e": ms_tri_e2e, + "tflops_tri": tflops_tri, + "tflops_torch": tflops_th, + "tflops_tri_e2e": tflops_tri_e2e, + "correct": correct, + "kernel_name": kernel_name, + "has_repack": has_repack, + "flops": flops, + } + + +# ---------------------------------------------------------------------------- +# Single-shape mode (used by bench_models.py) +# ---------------------------------------------------------------------------- + + +def _format_single_shape_line(args, result: dict) -> str: + """Single-line key=value output for bench_models.py to parse. + + Last whitespace-separated token is the primary metric value. + """ + primary = result["ms_tri"] if args.metric == "time" else result["tflops_tri"] + parts = [ + f"N={args.N}", + f"C={args.C}", + f"H={args.H}", + f"W={args.W}", + f"K={args.K}", + f"R={args.R}", + f"S={args.S}", + f"method={args.method}", + f"layout={args.layout}", + f"ms_tri={result['ms_tri']:.4f}", + f"ms_torch={result['ms_torch']:.4f}", + f"tflops_tri={result['tflops_tri']:.4f}", + f"tflops_torch={result['tflops_torch']:.4f}", + f"correct={int(result['correct'])}", + ] + if args.show_kernel_name: + parts.append(f"kernel={result['kernel_name'] or 'unknown'}") + parts.append(f"{primary:.4f}") + return " ".join(parts) + + +def run_single_shape(args) -> None: + dtype = _torch_dtype(args.dtype) + stride = (args.stride_h, args.stride_w) + padding = (args.pad_h, args.pad_w) + dilation = (args.dilation_h, args.dilation_w) + # Single-shape mode (bench_models.py consumer): skip kernel+repack timing + # to keep per-call cost predictable for the framework. + result = bench_one_shape( + args.N, + args.C, + args.H, + args.W, + args.K, + args.R, + args.S, + stride, + padding, + dilation, + dtype, + args.method, + args.layout, + bias=not args.no_bias, + measure_repack=False, + ) + print(_format_single_shape_line(args, result)) + + +# ---------------------------------------------------------------------------- +# Box-drawn table printers +# ---------------------------------------------------------------------------- + + +def _box_table(headers, rows, align: Optional[list] = None) -> str: + """Render a list of header-string + row-tuples into a box-drawn table. + + align: per-column alignment, "l" (left, default) or "r" (right). 
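+
+    Example (output follows from the drawing logic below)::
+
+        >>> print(_box_table(["a", "bb"], [["1", "2"]], align=["l", "r"]))
+        ┌───┬────┐
+        │ a │ bb │
+        ├───┼────┤
+        │ 1 │  2 │
+        └───┴────┘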
+ """ + n = len(headers) + if align is None: + align = ["l"] * n + widths = [ + max(len(headers[j]), max((len(str(r[j])) for r in rows), default=0)) + for j in range(n) + ] + + def fmt_row(vals): + cells = [] + for j, v in enumerate(vals): + s = str(v) + if align[j] == "r": + cells.append(f" {s:>{widths[j]}} ") + else: + cells.append(f" {s:<{widths[j]}} ") + return "│" + "│".join(cells) + "│" + + sep_top = "┌" + "┬".join("─" * (w + 2) for w in widths) + "┐" + sep_mid = "├" + "┼".join("─" * (w + 2) for w in widths) + "┤" + sep_bot = "└" + "┴".join("─" * (w + 2) for w in widths) + "┘" + + lines = [sep_top, fmt_row(headers), sep_mid] + for i, row in enumerate(rows): + lines.append(fmt_row(row)) + if i < len(rows) - 1: + lines.append(sep_mid) + lines.append(sep_bot) + return "\n".join(lines) + + +def _print_layer_table( + layers: list, has_any_repack: bool, miopen_enabled: bool +) -> None: + print("\n" + "=" * 80) + print("LAYER-BY-LAYER BENCHMARK") + print("=" * 80) + + headers = ["#", "Layer", "Type", "Shape"] + if miopen_enabled: + headers.append("MIOpen Solver") + headers.append("Triton Kernel") + headers.append("Tri Kernel TF/s") + if has_any_repack: + headers.append("Tri Kernel+Repack TF/s") + headers.extend(["Torch TF/s", "Winner"]) + + rows = [] + for i, lr in enumerate(layers): + row = [str(i), lr["name"], lr["type"], lr["shape"]] + if miopen_enabled: + row.append(lr["miopen_solver"] or "—") + row.append(lr["kernel_name"] or "—") + row.append(f"{lr['tflops_tri']:.2f}") + if has_any_repack: + row.append( + f"{lr['tflops_tri_e2e']:.2f}" + if lr["tflops_tri_e2e"] is not None + else "—" + ) + row.append(f"{lr['tflops_torch']:.2f}") + # Winner uses kernel TF/s (not kernel+repack) for consistency with old code. + row.append("Triton" if lr["tflops_tri"] > lr["tflops_torch"] else "Torch") + rows.append(row) + + print(_box_table(headers, rows)) + + +def _print_miopen_solver_table(layers: list) -> None: + """Group layers by MIOpen solver, print one row per solver.""" + from collections import OrderedDict + + solver_layers: dict = OrderedDict() + for i, lr in enumerate(layers): + s = lr.get("miopen_solver") or "unknown" + solver_layers.setdefault(s, []).append(f"L{i}") + + if not any(s != "unknown" for s in solver_layers): + return # Nothing detected; skip the table entirely. + + print("\n" + "=" * 80) + print("MIOpen SOLVER SUMMARY") + print("=" * 80) + + rows = [] + for solver, ls in solver_layers.items(): + algo = MIOPEN_ALGO_MAP.get(solver, solver) + layer_str = ", ".join(ls) + if len(layer_str) > 80: + layer_str = ", ".join(ls[:10]) + f" ... 
({len(ls)} layers total)" + rows.append([solver, algo, layer_str]) + print(_box_table(("MIOpen Solver", "Algorithm Type", "Used For"), rows)) + + +def _print_overall_perf_table(layers: list, has_any_repack: bool) -> None: + """Mean/median/aggregate TFLOPS, total time, layer wins.""" + print("\n" + "=" * 80) + print("OVERALL PERFORMANCE") + print("=" * 80) + + tri_tf = [lr["tflops_tri"] for lr in layers] + th_tf = [lr["tflops_torch"] for lr in layers] + tri_ms = [lr["ms_tri"] for lr in layers] + th_ms = [lr["ms_torch"] for lr in layers] + + # Aggregate = sum(flops) / sum(time) + sum_flops = sum(lr["flops"] for lr in layers) + sum_time_tri = sum(lr["ms_tri"] * 1e-3 for lr in layers) + sum_time_th = sum(lr["ms_torch"] * 1e-3 for lr in layers) + agg_tri = sum_flops / sum_time_tri / 1e12 if sum_time_tri else 0.0 + agg_th = sum_flops / sum_time_th / 1e12 if sum_time_th else 0.0 + + n = len(layers) + tri_wins = sum(1 for lr in layers if lr["tflops_tri"] > lr["tflops_torch"]) + + rows = [ + [ + "Mean TFLOPS (kernel)", + f"{statistics.mean(tri_tf):.2f}", + f"{statistics.mean(th_tf):.2f}", + ], + ] + if has_any_repack: + e2e_tf = [ + ( + lr["tflops_tri_e2e"] + if lr["tflops_tri_e2e"] is not None + else lr["tflops_tri"] + ) + for lr in layers + ] + e2e_ms = [ + lr["ms_tri_e2e"] if lr["ms_tri_e2e"] is not None else lr["ms_tri"] + for lr in layers + ] + sum_time_e2e = sum(t * 1e-3 for t in e2e_ms) + agg_tri_e2e = sum_flops / sum_time_e2e / 1e12 if sum_time_e2e else 0.0 + e2e_wins = sum(1 for lr, ee in zip(layers, e2e_tf) if ee > lr["tflops_torch"]) + rows.append( + [ + "Mean TFLOPS (kernel+repack)", + f"{statistics.mean(e2e_tf):.2f}", + f"{statistics.mean(th_tf):.2f}", + ] + ) + rows.append( + [ + "Median TFLOPS (kernel)", + f"{statistics.median(tri_tf):.2f}", + f"{statistics.median(th_tf):.2f}", + ] + ) + if has_any_repack: + rows.append( + [ + "Median TFLOPS (kernel+repack)", + f"{statistics.median(e2e_tf):.2f}", + f"{statistics.median(th_tf):.2f}", + ] + ) + rows.append(["Aggregate TFLOPS (kernel)", f"{agg_tri:.2f}", f"{agg_th:.2f}"]) + if has_any_repack: + rows.append( + ["Aggregate TFLOPS (kernel+repack)", f"{agg_tri_e2e:.2f}", f"{agg_th:.2f}"] + ) + rows.append(["Total kernel time (ms)", f"{sum(tri_ms):.2f}", f"{sum(th_ms):.2f}"]) + if has_any_repack: + rows.append( + ["Total kernel+repack time (ms)", f"{sum(e2e_ms):.2f}", f"{sum(th_ms):.2f}"] + ) + rows.append(["Layer wins (kernel)", f"{tri_wins}/{n}", f"{n - tri_wins}/{n}"]) + if has_any_repack: + rows.append( + ["Layer wins (kernel+repack)", f"{e2e_wins}/{n}", f"{n - e2e_wins}/{n}"] + ) + rows.append( + ["Correctness", f"{sum(1 for lr in layers if lr['correct'])}/{n} passed", "—"] + ) + + print(_box_table(("Metric", "Triton", "PyTorch (MIOpen)"), rows)) + + +# ---------------------------------------------------------------------------- +# Sweep mode +# ---------------------------------------------------------------------------- + + +_MODEL_SHAPES_PATH = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "model_benchmarking_tool", + "model_shapes.json", +) + + +def _load_model_shapes(model_pattern: str) -> tuple[str, list]: + """Load conv2d shapes for a model from model_shapes.json. + + model_pattern: case-insensitive substring matched against model keys. + Returns (matched_model_name, list of shape tuples in the same form as + DEFAULT_SHAPES — desc is " L"). 
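+
+    Illustrative return value (hypothetical match)::
+
+        _load_model_shapes("resnet")
+        # -> ("resnet50",
+        #     [(1, 3, 224, 224, 64, 7, 7, (2, 2), (3, 3), (1, 1),
+        #       "resnet50 L0"), ...])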
+ """ + with open(_MODEL_SHAPES_PATH) as f: + data = json.load(f) + + matches = [ + m for m in data if model_pattern.lower() in m.lower() and "conv2d" in data[m] + ] + if not matches: + avail = sorted(m for m, k in data.items() if "conv2d" in k) + raise ValueError( + f"No model with 'conv2d' shapes matches {model_pattern!r}. " + f"Available: {avail}" + ) + if len(matches) > 1: + raise ValueError( + f"Pattern {model_pattern!r} matched multiple models: {matches}. " + f"Use a more specific pattern." + ) + + model = matches[0] + shapes = [] + for i, s in enumerate(data[model]["conv2d"]): + shapes.append( + ( + s["N"], + s["C"], + s["H"], + s["W"], + s["K"], + s["R"], + s["S"], + (s.get("stride_h", 1), s.get("stride_w", 1)), + (s.get("pad_h", 0), s.get("pad_w", 0)), + (s.get("dilation_h", 1), s.get("dilation_w", 1)), + f"{model} L{i}", + ) + ) + return model, shapes + + +def run_sweep(args) -> None: + """Iterate the chosen shape set, then print three summary tables.""" + dtype = _torch_dtype(args.dtype) + + if args.model: + try: + model, shapes = _load_model_shapes(args.model) + print( + f"# Sweep source: model_shapes.json :: {model} ({len(shapes)} layers)" + ) + except (FileNotFoundError, ValueError) as e: + print(f"ERROR: {e}", file=sys.stderr) + sys.exit(1) + else: + shapes = DEFAULT_SHAPES + print(f"# Sweep source: built-in DEFAULT_SHAPES ({len(shapes)} shapes)") + + print( + f"# dtype={args.dtype} method={args.method} layout={args.layout} " + f"miopen_solvers={'on' if args.miopen_solvers else 'off'}" + ) + + # Optional MIOpen solver detection (subprocess; ~60-120s startup) + if args.miopen_solvers: + print("# Detecting MIOpen solvers (subprocess; this can take a minute)...") + precompute_miopen_solvers(shapes, dtype) + print("# MIOpen solver detection complete.") + + # Bench each shape, collect rows. 
+ layers = [] + for entry in shapes: + N, C, H, W, K, R, S, stride, padding, dilation, name = entry + try: + r = bench_one_shape( + N, + C, + H, + W, + K, + R, + S, + stride, + padding, + dilation, + dtype, + args.method, + args.layout, + bias=not args.no_bias, + measure_repack=True, + ) + except Exception as e: + print(f" {name:<24} ERROR: {type(e).__name__}: {e}", file=sys.stderr) + continue + miopen = ( + _get_miopen_solver(N, C, H, W, K, R, S, stride, padding, dilation) + if args.miopen_solvers + else "" + ) + layers.append( + { + "name": name, + "type": _kernel_type_tag(R, S, dilation), + "shape": _shape_str(N, C, H, W, K, R, S), + "kernel_name": r["kernel_name"], + "miopen_solver": miopen, + "tflops_tri": r["tflops_tri"], + "tflops_tri_e2e": r["tflops_tri_e2e"], + "tflops_torch": r["tflops_torch"], + "ms_tri": r["ms_tri"], + "ms_tri_e2e": r["ms_tri_e2e"], + "ms_torch": r["ms_torch"], + "correct": r["correct"], + "flops": r["flops"], + } + ) + + if not layers: + print("No layers benched (all errored?).", file=sys.stderr) + return + + has_any_repack = any(lr["ms_tri_e2e"] is not None for lr in layers) + + _print_layer_table(layers, has_any_repack, miopen_enabled=args.miopen_solvers) + if args.miopen_solvers: + _print_miopen_solver_table(layers) + _print_overall_perf_table(layers, has_any_repack) + + +# ---------------------------------------------------------------------------- +# CLI +# ---------------------------------------------------------------------------- + + +def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace: + p = argparse.ArgumentParser( + prog="bench_conv2d", + description="Benchmark aiter.ops.triton.conv.conv2d (single shape or sweep).", + allow_abbrev=False, + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + p.add_argument( + "--dtype", + "--conv_dtype", + "--conv-dtype", + choices=["fp16", "bf16"], + default="fp16", + ) + p.add_argument( + "--method", + choices=list(METHODS.keys()), + default="auto", + help="kernel to bench. 'auto' uses the conv2d router.", + ) + p.add_argument( + "--layout", + "--conv_layout", + "--conv-layout", + choices=["nchw", "nhwc"], + default="nchw", + ) + p.add_argument("--metric", choices=["time", "throughput"], default="throughput") + p.add_argument( + "--no-bias", + "--no_bias", + action="store_true", + help="bench the bias=None code path", + ) + p.add_argument( + "--show-kernel-name", + "--show_kernel_name", + action="store_true", + help="include the routed Triton kernel name in single-shape output", + ) + p.add_argument( + "--miopen-solvers", + "--miopen_solvers", + action="store_true", + help="detect MIOpen solver names via a subprocess (sweep mode only; " + "adds ~60-120s upfront cost)", + ) + p.add_argument( + "--model", + type=str, + default=None, + help="sweep mode: load conv2d shapes for this model from " + "model_shapes.json (case-insensitive substring match). If omitted, " + "the built-in DEFAULT_SHAPES list is used.", + ) + + # Single-shape mode (used by bench_models.py and one-off measurements). 
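+    # e.g.:  python bench_conv2d.py --N 2 --C 64 --H 56 --W 56 --K 128 \
+    #        --R 3 --S 3 --stride-h 1 --stride-w 1 --pad-h 1 --pad-w 1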
+    p.add_argument("--N", type=int, default=None)
+    p.add_argument("--C", type=int, default=None)
+    p.add_argument("--H", type=int, default=None)
+    p.add_argument("--W", type=int, default=None)
+    p.add_argument("--K", type=int, default=None)
+    p.add_argument("--R", type=int, default=None)
+    p.add_argument("--S", type=int, default=None)
+    p.add_argument("--stride-h", "--stride_h", type=int, default=1)
+    p.add_argument("--stride-w", "--stride_w", type=int, default=1)
+    p.add_argument("--pad-h", "--pad_h", type=int, default=0)
+    p.add_argument("--pad-w", "--pad_w", type=int, default=0)
+    p.add_argument("--dilation-h", "--dilation_h", type=int, default=1)
+    p.add_argument("--dilation-w", "--dilation_w", type=int, default=1)
+
+    args = p.parse_args(argv)
+
+    single = [args.N, args.C, args.H, args.W, args.K, args.R, args.S]
+    if any(v is not None for v in single):
+        if any(v is None for v in single):
+            p.error("single-shape mode requires all of --N --C --H --W --K --R --S")
+        args.single_shape = True
+    else:
+        args.single_shape = False
+    return args
+
+
+def main(argv: Optional[list[str]] = None) -> None:
+    args = parse_args(argv)
+    if args.single_shape:
+        run_single_shape(args)
+    else:
+        run_sweep(args)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/op_tests/op_benchmarks/triton/model_benchmarking_tool/extract_conv_shapes.py b/op_tests/op_benchmarks/triton/model_benchmarking_tool/extract_conv_shapes.py
new file mode 100644
index 0000000000..899deefae1
--- /dev/null
+++ b/op_tests/op_benchmarks/triton/model_benchmarking_tool/extract_conv_shapes.py
@@ -0,0 +1,230 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
+"""Offline extraction of Conv2d layer shapes from real model checkpoints.
+
+Run once per model. Walks the model with forward hooks, captures the
+per-layer (N, C, H, W, K, R, S, stride, pad, dilation) tuples, dedupes,
+and prints JSON ready to paste under ``"<model name>": {"conv2d": [...]}``
+in ``model_shapes.json``.
+
+Consumed by ``bench_conv2d.py --model NAME`` (sweep mode).
+
+Usage::
+
+    python -m op_tests.op_benchmarks.triton.model_benchmarking_tool.extract_conv_shapes \\
+        --model resnet50
+
+    python -m op_tests.op_benchmarks.triton.model_benchmarking_tool.extract_conv_shapes \\
+        --model sd35_vae --model-path /app/models/stable-diffusion-3.5-medium
+
+    python -m op_tests.op_benchmarks.triton.model_benchmarking_tool.extract_conv_shapes \\
+        --model flux2_vae --model-path /app/models/FLUX.2-klein-9B
+
+Weight values do not affect the captured shapes (Conv2d dimensions are
+fixed by the layer definition, not the weight tensor contents). Concretely:
+- resnet50 uses ``weights=None`` (random init, no checkpoint download).
+- sd35_vae / flux2_vae use ``AutoencoderKL[Flux2].from_pretrained(...)``,
+  which loads the local checkpoint at ``--model-path`` (architecture +
+  weights). The weights are unused for shape extraction; they come along
+  because diffusers' easiest API loads both.
+
+Skips ``groups != 1`` layers (kernel doesn't support grouped/depthwise yet)
+and ``padding_mode != 'zeros'`` layers.
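+
+Illustrative output fragment (field values depend on the model and input
+size; this example is hypothetical)::
+
+    {"resnet50": {"conv2d": [{"N": 1, "C": 3, "H": 224, "W": 224, "K": 64,
+                              "R": 7, "S": 7, "stride_h": 2, ...}]}}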
+""" + +from __future__ import annotations + +import argparse +import json +import sys +from typing import Optional + +import torch +import torch.nn as nn + + +def _conv_shape_dict(mod: nn.Conv2d, x_shape: torch.Size) -> dict: + _, _, H, W = x_shape + R, S = ( + mod.kernel_size + if isinstance(mod.kernel_size, tuple) + else (mod.kernel_size, mod.kernel_size) + ) + sh, sw = mod.stride if isinstance(mod.stride, tuple) else (mod.stride, mod.stride) + ph, pw = ( + mod.padding if isinstance(mod.padding, tuple) else (mod.padding, mod.padding) + ) + dh, dw = ( + mod.dilation + if isinstance(mod.dilation, tuple) + else (mod.dilation, mod.dilation) + ) + return { + "N": int(x_shape[0]), + "C": int(mod.in_channels), + "H": int(H), + "W": int(W), + "K": int(mod.out_channels), + "R": int(R), + "S": int(S), + "stride_h": int(sh), + "stride_w": int(sw), + "pad_h": int(ph), + "pad_w": int(pw), + "dilation_h": int(dh), + "dilation_w": int(dw), + } + + +def _walk(module: nn.Module, run_forward, batch_size: int) -> list[dict]: + """Hook every Conv2d, run forward once, return deduped shape dicts.""" + captured: dict[str, torch.Size] = {} + + def make_hook(key): + def hook(m, inp, out): + if key not in captured and inp and isinstance(inp[0], torch.Tensor): + captured[key] = inp[0].shape + + return hook + + handles = [] + for name, mod in module.named_modules(): + if isinstance(mod, nn.Conv2d): + handles.append(mod.register_forward_hook(make_hook(name))) + try: + with torch.no_grad(): + run_forward() + finally: + for h in handles: + h.remove() + + shapes = [] + skipped_groups = 0 + skipped_padmode = 0 + skipped_no_shape = 0 + for name, mod in module.named_modules(): + if not isinstance(mod, nn.Conv2d): + continue + if mod.groups != 1: + skipped_groups += 1 + continue + if getattr(mod, "padding_mode", "zeros") != "zeros": + skipped_padmode += 1 + continue + if name not in captured: + skipped_no_shape += 1 + continue + shapes.append(_conv_shape_dict(mod, captured[name])) + + print( + f"# captured {len(shapes)} conv layers " + f"(skipped: {skipped_groups} grouped, {skipped_padmode} non-zero pad-mode, " + f"{skipped_no_shape} unreached)", + file=sys.stderr, + ) + + # Dedupe — many models have the same shape repeated across blocks. + seen = set() + deduped = [] + for s in shapes: + key = tuple(s.items()) + if key in seen: + continue + seen.add(key) + deduped.append(s) + print(f"# {len(deduped)} unique shapes after dedupe", file=sys.stderr) + return deduped + + +# Hardcoded internally — these don't affect captured shape values, only the +# forward-pass runtime. fp16 + cuda is fine for every model we extract from. 
+_DTYPE = torch.float16 +_DEVICE = "cuda" + + +def extract_resnet50(N: int, H: int, W: int) -> list[dict]: + from torchvision.models import resnet50 + + model = resnet50(weights=None).to(device=_DEVICE, dtype=_DTYPE).eval() + x = torch.randn(N, 3, H, W, device=_DEVICE, dtype=_DTYPE) + + def fwd(): + model(x) + + return _walk(model, fwd, N) + + +def extract_sd35_vae(model_path: str, N: int, H: int, W: int) -> list[dict]: + from diffusers import AutoencoderKL + + sub = model_path if model_path.rstrip("/").endswith("/vae") else model_path + "/vae" + vae = AutoencoderKL.from_pretrained(sub, torch_dtype=_DTYPE).to(_DEVICE).eval() + img = torch.randn(N, 3, H, W, device=_DEVICE, dtype=_DTYPE) + + def fwd(): + latent = vae.encode(img).latent_dist.sample() + vae.decode(latent) + + return _walk(vae, fwd, N) + + +def extract_flux2_vae(model_path: str, N: int, H: int, W: int) -> list[dict]: + from diffusers import AutoencoderKLFlux2 + + sub = model_path if model_path.rstrip("/").endswith("/vae") else model_path + "/vae" + vae = AutoencoderKLFlux2.from_pretrained(sub, torch_dtype=_DTYPE).to(_DEVICE).eval() + img = torch.randn(N, 3, H, W, device=_DEVICE, dtype=_DTYPE) + + def fwd(): + latent = vae.encode(img).latent_dist.sample() + vae.decode(latent) + + return _walk(vae, fwd, N) + + +def main(argv: Optional[list[str]] = None) -> None: + p = argparse.ArgumentParser(prog="extract_conv_shapes", description=__doc__) + p.add_argument( + "--model", required=True, choices=["resnet50", "sd35_vae", "flux2_vae"] + ) + p.add_argument( + "--model-path", + default=None, + help="local checkpoint dir (required for sd35_vae / flux2_vae)", + ) + p.add_argument("--N", type=int, default=1, help="batch size") + p.add_argument( + "--H", type=int, default=None, help="input H (default: model-typical)" + ) + p.add_argument( + "--W", type=int, default=None, help="input W (default: model-typical)" + ) + args = p.parse_args(argv) + + if args.model == "resnet50": + H = args.H if args.H is not None else 224 + W = args.W if args.W is not None else 224 + shapes = extract_resnet50(args.N, H, W) + model_key = "resnet50" + elif args.model == "sd35_vae": + if not args.model_path: + p.error("--model-path required for sd35_vae") + H = args.H if args.H is not None else 512 + W = args.W if args.W is not None else 512 + shapes = extract_sd35_vae(args.model_path, args.N, H, W) + model_key = "stable-diffusion-3.5-medium" + elif args.model == "flux2_vae": + if not args.model_path: + p.error("--model-path required for flux2_vae") + H = args.H if args.H is not None else 512 + W = args.W if args.W is not None else 512 + shapes = extract_flux2_vae(args.model_path, args.N, H, W) + model_key = "FLUX.2-klein-9B" + else: + raise AssertionError("unreachable") + + print(json.dumps({model_key: {"conv2d": shapes}}, indent=4)) + + +if __name__ == "__main__": + main() diff --git a/op_tests/op_benchmarks/triton/model_benchmarking_tool/model_shapes.json b/op_tests/op_benchmarks/triton/model_benchmarking_tool/model_shapes.json index 0fd7d93f6a..afa49c92ef 100644 --- a/op_tests/op_benchmarks/triton/model_benchmarking_tool/model_shapes.json +++ b/op_tests/op_benchmarks/triton/model_benchmarking_tool/model_shapes.json @@ -44,12 +44,12 @@ "TP_dim": "K" } ], - "rmsnorm": [ + "rmsnorm": [ { "N": 16384 } ], - "rope": [ + "rope": [ { "num_heads": 128, "num_kv_heads": 8, @@ -121,12 +121,12 @@ "TP_dim": "K" } ], - "rmsnorm": [ + "rmsnorm": [ { "N": 8192 } ], - "rope": [ + "rope": [ { "num_heads": 64, "num_kv_heads": 8, @@ -198,12 +198,12 @@ "TP_dim": "K" } ], - 
"rmsnorm": [ + "rmsnorm": [ { "N": 4096 } ], - "rope": [ + "rope": [ { "num_heads": 32, "num_kv_heads": 8, @@ -256,12 +256,12 @@ "TopK": 4 } ], - "rmsnorm": [ + "rmsnorm": [ { "N": 2880 } ], - "rope": [ + "rope": [ { "num_heads": 64, "num_kv_heads": 8, @@ -441,7 +441,7 @@ "TopK": 8 } ], - "rmsnorm": [ + "rmsnorm": [ { "N": 7168 }, @@ -452,7 +452,7 @@ "N": 512 } ], - "rope": [ + "rope": [ { "num_heads": 128, "num_kv_heads": 128, @@ -537,12 +537,12 @@ "TopK": 1 } ], - "rmsnorm": [ + "rmsnorm": [ { "N": 5120 } ], - "rope": [ + "rope": [ { "num_heads": 40, "num_kv_heads": 8, @@ -605,7 +605,7 @@ "TopK": 8 } ], - "rmsnorm": [ + "rmsnorm": [ { "N": 4096 }, @@ -613,7 +613,7 @@ "N": 128 } ], - "rope": [ + "rope": [ { "num_heads": 64, "num_kv_heads": 4, @@ -639,5 +639,1022 @@ "dv": 128 } ] + }, + "resnet50": { + "conv2d": [ + { + "N": 1, + "C": 3, + "H": 224, + "W": 224, + "K": 64, + "R": 7, + "S": 7, + "stride_h": 2, + "stride_w": 2, + "pad_h": 3, + "pad_w": 3, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 64, + "H": 56, + "W": 56, + "K": 64, + "R": 1, + "S": 1, + "stride_h": 1, + "stride_w": 1, + "pad_h": 0, + "pad_w": 0, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 64, + "H": 56, + "W": 56, + "K": 64, + "R": 3, + "S": 3, + "stride_h": 1, + "stride_w": 1, + "pad_h": 1, + "pad_w": 1, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 64, + "H": 56, + "W": 56, + "K": 256, + "R": 1, + "S": 1, + "stride_h": 1, + "stride_w": 1, + "pad_h": 0, + "pad_w": 0, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 256, + "H": 56, + "W": 56, + "K": 64, + "R": 1, + "S": 1, + "stride_h": 1, + "stride_w": 1, + "pad_h": 0, + "pad_w": 0, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 256, + "H": 56, + "W": 56, + "K": 128, + "R": 1, + "S": 1, + "stride_h": 1, + "stride_w": 1, + "pad_h": 0, + "pad_w": 0, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 128, + "H": 56, + "W": 56, + "K": 128, + "R": 3, + "S": 3, + "stride_h": 2, + "stride_w": 2, + "pad_h": 1, + "pad_w": 1, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 128, + "H": 28, + "W": 28, + "K": 512, + "R": 1, + "S": 1, + "stride_h": 1, + "stride_w": 1, + "pad_h": 0, + "pad_w": 0, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 256, + "H": 56, + "W": 56, + "K": 512, + "R": 1, + "S": 1, + "stride_h": 2, + "stride_w": 2, + "pad_h": 0, + "pad_w": 0, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 512, + "H": 28, + "W": 28, + "K": 128, + "R": 1, + "S": 1, + "stride_h": 1, + "stride_w": 1, + "pad_h": 0, + "pad_w": 0, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 128, + "H": 28, + "W": 28, + "K": 128, + "R": 3, + "S": 3, + "stride_h": 1, + "stride_w": 1, + "pad_h": 1, + "pad_w": 1, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 512, + "H": 28, + "W": 28, + "K": 256, + "R": 1, + "S": 1, + "stride_h": 1, + "stride_w": 1, + "pad_h": 0, + "pad_w": 0, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 256, + "H": 28, + "W": 28, + "K": 256, + "R": 3, + "S": 3, + "stride_h": 2, + "stride_w": 2, + "pad_h": 1, + "pad_w": 1, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 256, + "H": 14, + "W": 14, + "K": 1024, + "R": 1, + "S": 1, + "stride_h": 1, + "stride_w": 1, + "pad_h": 0, + "pad_w": 0, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 512, + "H": 28, + "W": 28, + "K": 1024, + "R": 1, + "S": 1, + "stride_h": 2, + "stride_w": 2, + "pad_h": 0, + "pad_w": 0, + "dilation_h": 
1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 1024, + "H": 14, + "W": 14, + "K": 256, + "R": 1, + "S": 1, + "stride_h": 1, + "stride_w": 1, + "pad_h": 0, + "pad_w": 0, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 256, + "H": 14, + "W": 14, + "K": 256, + "R": 3, + "S": 3, + "stride_h": 1, + "stride_w": 1, + "pad_h": 1, + "pad_w": 1, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 1024, + "H": 14, + "W": 14, + "K": 512, + "R": 1, + "S": 1, + "stride_h": 1, + "stride_w": 1, + "pad_h": 0, + "pad_w": 0, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 512, + "H": 14, + "W": 14, + "K": 512, + "R": 3, + "S": 3, + "stride_h": 2, + "stride_w": 2, + "pad_h": 1, + "pad_w": 1, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 512, + "H": 7, + "W": 7, + "K": 2048, + "R": 1, + "S": 1, + "stride_h": 1, + "stride_w": 1, + "pad_h": 0, + "pad_w": 0, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 1024, + "H": 14, + "W": 14, + "K": 2048, + "R": 1, + "S": 1, + "stride_h": 2, + "stride_w": 2, + "pad_h": 0, + "pad_w": 0, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 2048, + "H": 7, + "W": 7, + "K": 512, + "R": 1, + "S": 1, + "stride_h": 1, + "stride_w": 1, + "pad_h": 0, + "pad_w": 0, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 512, + "H": 7, + "W": 7, + "K": 512, + "R": 3, + "S": 3, + "stride_h": 1, + "stride_w": 1, + "pad_h": 1, + "pad_w": 1, + "dilation_h": 1, + "dilation_w": 1 + } + ] + }, + "stable-diffusion-3.5-medium": { + "conv2d": [ + { + "N": 1, + "C": 3, + "H": 512, + "W": 512, + "K": 128, + "R": 3, + "S": 3, + "stride_h": 1, + "stride_w": 1, + "pad_h": 1, + "pad_w": 1, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 128, + "H": 512, + "W": 512, + "K": 128, + "R": 3, + "S": 3, + "stride_h": 1, + "stride_w": 1, + "pad_h": 1, + "pad_w": 1, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 128, + "H": 513, + "W": 513, + "K": 128, + "R": 3, + "S": 3, + "stride_h": 2, + "stride_w": 2, + "pad_h": 0, + "pad_w": 0, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 128, + "H": 256, + "W": 256, + "K": 256, + "R": 3, + "S": 3, + "stride_h": 1, + "stride_w": 1, + "pad_h": 1, + "pad_w": 1, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 256, + "H": 256, + "W": 256, + "K": 256, + "R": 3, + "S": 3, + "stride_h": 1, + "stride_w": 1, + "pad_h": 1, + "pad_w": 1, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 128, + "H": 256, + "W": 256, + "K": 256, + "R": 1, + "S": 1, + "stride_h": 1, + "stride_w": 1, + "pad_h": 0, + "pad_w": 0, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 256, + "H": 257, + "W": 257, + "K": 256, + "R": 3, + "S": 3, + "stride_h": 2, + "stride_w": 2, + "pad_h": 0, + "pad_w": 0, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 256, + "H": 128, + "W": 128, + "K": 512, + "R": 3, + "S": 3, + "stride_h": 1, + "stride_w": 1, + "pad_h": 1, + "pad_w": 1, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 512, + "H": 128, + "W": 128, + "K": 512, + "R": 3, + "S": 3, + "stride_h": 1, + "stride_w": 1, + "pad_h": 1, + "pad_w": 1, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 256, + "H": 128, + "W": 128, + "K": 512, + "R": 1, + "S": 1, + "stride_h": 1, + "stride_w": 1, + "pad_h": 0, + "pad_w": 0, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 512, + "H": 129, + "W": 129, + "K": 512, + "R": 3, + "S": 3, + "stride_h": 2, + "stride_w": 2, + "pad_h": 0, + "pad_w": 
0, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 512, + "H": 64, + "W": 64, + "K": 512, + "R": 3, + "S": 3, + "stride_h": 1, + "stride_w": 1, + "pad_h": 1, + "pad_w": 1, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 512, + "H": 64, + "W": 64, + "K": 32, + "R": 3, + "S": 3, + "stride_h": 1, + "stride_w": 1, + "pad_h": 1, + "pad_w": 1, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 16, + "H": 64, + "W": 64, + "K": 512, + "R": 3, + "S": 3, + "stride_h": 1, + "stride_w": 1, + "pad_h": 1, + "pad_w": 1, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 512, + "H": 256, + "W": 256, + "K": 512, + "R": 3, + "S": 3, + "stride_h": 1, + "stride_w": 1, + "pad_h": 1, + "pad_w": 1, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 512, + "H": 256, + "W": 256, + "K": 256, + "R": 3, + "S": 3, + "stride_h": 1, + "stride_w": 1, + "pad_h": 1, + "pad_w": 1, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 512, + "H": 256, + "W": 256, + "K": 256, + "R": 1, + "S": 1, + "stride_h": 1, + "stride_w": 1, + "pad_h": 0, + "pad_w": 0, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 256, + "H": 512, + "W": 512, + "K": 256, + "R": 3, + "S": 3, + "stride_h": 1, + "stride_w": 1, + "pad_h": 1, + "pad_w": 1, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 256, + "H": 512, + "W": 512, + "K": 128, + "R": 3, + "S": 3, + "stride_h": 1, + "stride_w": 1, + "pad_h": 1, + "pad_w": 1, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 256, + "H": 512, + "W": 512, + "K": 128, + "R": 1, + "S": 1, + "stride_h": 1, + "stride_w": 1, + "pad_h": 0, + "pad_w": 0, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 128, + "H": 512, + "W": 512, + "K": 3, + "R": 3, + "S": 3, + "stride_h": 1, + "stride_w": 1, + "pad_h": 1, + "pad_w": 1, + "dilation_h": 1, + "dilation_w": 1 + } + ] + }, + "FLUX.2-klein-9B": { + "conv2d": [ + { + "N": 1, + "C": 3, + "H": 512, + "W": 512, + "K": 128, + "R": 3, + "S": 3, + "stride_h": 1, + "stride_w": 1, + "pad_h": 1, + "pad_w": 1, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 128, + "H": 512, + "W": 512, + "K": 128, + "R": 3, + "S": 3, + "stride_h": 1, + "stride_w": 1, + "pad_h": 1, + "pad_w": 1, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 128, + "H": 513, + "W": 513, + "K": 128, + "R": 3, + "S": 3, + "stride_h": 2, + "stride_w": 2, + "pad_h": 0, + "pad_w": 0, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 128, + "H": 256, + "W": 256, + "K": 256, + "R": 3, + "S": 3, + "stride_h": 1, + "stride_w": 1, + "pad_h": 1, + "pad_w": 1, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 256, + "H": 256, + "W": 256, + "K": 256, + "R": 3, + "S": 3, + "stride_h": 1, + "stride_w": 1, + "pad_h": 1, + "pad_w": 1, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 128, + "H": 256, + "W": 256, + "K": 256, + "R": 1, + "S": 1, + "stride_h": 1, + "stride_w": 1, + "pad_h": 0, + "pad_w": 0, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 256, + "H": 257, + "W": 257, + "K": 256, + "R": 3, + "S": 3, + "stride_h": 2, + "stride_w": 2, + "pad_h": 0, + "pad_w": 0, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 256, + "H": 128, + "W": 128, + "K": 512, + "R": 3, + "S": 3, + "stride_h": 1, + "stride_w": 1, + "pad_h": 1, + "pad_w": 1, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 512, + "H": 128, + "W": 128, + "K": 512, + "R": 3, + "S": 3, + "stride_h": 1, + "stride_w": 1, + "pad_h": 
1, + "pad_w": 1, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 256, + "H": 128, + "W": 128, + "K": 512, + "R": 1, + "S": 1, + "stride_h": 1, + "stride_w": 1, + "pad_h": 0, + "pad_w": 0, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 512, + "H": 129, + "W": 129, + "K": 512, + "R": 3, + "S": 3, + "stride_h": 2, + "stride_w": 2, + "pad_h": 0, + "pad_w": 0, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 512, + "H": 64, + "W": 64, + "K": 512, + "R": 3, + "S": 3, + "stride_h": 1, + "stride_w": 1, + "pad_h": 1, + "pad_w": 1, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 512, + "H": 64, + "W": 64, + "K": 64, + "R": 3, + "S": 3, + "stride_h": 1, + "stride_w": 1, + "pad_h": 1, + "pad_w": 1, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 32, + "H": 64, + "W": 64, + "K": 512, + "R": 3, + "S": 3, + "stride_h": 1, + "stride_w": 1, + "pad_h": 1, + "pad_w": 1, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 512, + "H": 256, + "W": 256, + "K": 512, + "R": 3, + "S": 3, + "stride_h": 1, + "stride_w": 1, + "pad_h": 1, + "pad_w": 1, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 512, + "H": 256, + "W": 256, + "K": 256, + "R": 3, + "S": 3, + "stride_h": 1, + "stride_w": 1, + "pad_h": 1, + "pad_w": 1, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 512, + "H": 256, + "W": 256, + "K": 256, + "R": 1, + "S": 1, + "stride_h": 1, + "stride_w": 1, + "pad_h": 0, + "pad_w": 0, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 256, + "H": 512, + "W": 512, + "K": 256, + "R": 3, + "S": 3, + "stride_h": 1, + "stride_w": 1, + "pad_h": 1, + "pad_w": 1, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 256, + "H": 512, + "W": 512, + "K": 128, + "R": 3, + "S": 3, + "stride_h": 1, + "stride_w": 1, + "pad_h": 1, + "pad_w": 1, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 256, + "H": 512, + "W": 512, + "K": 128, + "R": 1, + "S": 1, + "stride_h": 1, + "stride_w": 1, + "pad_h": 0, + "pad_w": 0, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 128, + "H": 512, + "W": 512, + "K": 3, + "R": 3, + "S": 3, + "stride_h": 1, + "stride_w": 1, + "pad_h": 1, + "pad_w": 1, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 64, + "H": 64, + "W": 64, + "K": 64, + "R": 1, + "S": 1, + "stride_h": 1, + "stride_w": 1, + "pad_h": 0, + "pad_w": 0, + "dilation_h": 1, + "dilation_w": 1 + }, + { + "N": 1, + "C": 32, + "H": 64, + "W": 64, + "K": 32, + "R": 1, + "S": 1, + "stride_h": 1, + "stride_w": 1, + "pad_h": 0, + "pad_w": 0, + "dilation_h": 1, + "dilation_w": 1 + } + ] } -} \ No newline at end of file +} diff --git a/op_tests/triton_tests/conv/__init__.py b/op_tests/triton_tests/conv/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/op_tests/triton_tests/conv/_helpers.py b/op_tests/triton_tests/conv/_helpers.py new file mode 100644 index 0000000000..2b1bfda14d --- /dev/null +++ b/op_tests/triton_tests/conv/_helpers.py @@ -0,0 +1,483 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +"""Test-side library: TestSuite, method registry, and runners. + +Library code (no ``test_`` prefix) — pytest does not collect this file. +``test_conv2d.py`` imports the runners and registry from here. + +Public surface: + +- ``TestSuite`` / ``TestResult`` + Correctness collector. ``check_close`` records pass/fail per case; + ``failed_results`` returns the list at end of run. 
+ +- ``METHOD_REGISTRY`` (+ ``ORDERED_METHODS`` / ``ALL_METHODS``) + Kernel dispatch table. Each entry maps a method name to its + public ``conv2d_*`` callable, applicability guard, winograd flag, + and bench tag. + +- ``run_all_methods(...)`` + Main dispatch. For a given (x, w, b, stride, padding, dilation): + runs the selected NCHW kernel (or all of them, if method="all"), + runs ``conv2d_nhwc`` if layout_mode includes nhwc, and checks every + output against ``F.conv2d`` within method-appropriate tolerance. + +- ``run_edge_cases``, ``run_random_fuzzing``, ``run_no_bias``, + ``run_activations``, ``run_cross_method`` + Test runners called by ``test_conv2d.py``. Each iterates a shape + set and delegates to ``run_all_methods``. + +- ``COMMON_SHAPES``, ``get_edge_case_shapes()`` + Shared shape data — 3x3 stride-1 shapes routable by every kernel, + and the 12-shape edge-case list respectively. +""" + +from __future__ import annotations +import random +import traceback +from collections import namedtuple +from dataclasses import dataclass +from typing import List, Optional + +import torch +import torch.nn.functional as F + +from aiter.ops.triton.conv._utils import ( + dynamic_conv_tolerances, + _out_hw, + _is_1x1_conv, + _is_3x3_conv, + _winograd_tolerances, + apply_activation, +) +from aiter.ops.triton.conv._launch import _select_3x3_method +from aiter.ops.triton.conv.conv2d import ( + conv2d_nchw, + conv2d_nchw_cblocked, + conv2d_nhwc, + conv2d_winograd_f4x3, + conv2d_winograd_f4x3_fused, + conv2d_winograd_f4x3_cblocked, +) + +# -- Architecture gating ------------------------------------------------------ +SUPPORTED_ARCHS = { + "RDNA": {"gfx1200", "gfx1201"}, + "CDNA": set(), +} + +# Flat union for arch-check use sites. +ALL_SUPPORTED_ARCHS = set().union(*SUPPORTED_ARCHS.values()) + + +# -- Method registry ---------------------------------------------------------- + +MethodEntry = namedtuple( + "MethodEntry", ["kernel_fn", "guard_fn", "is_winograd", "bench_tag", "short_name"] +) + + +def _3x3_guard(R, S, stride, dilation, C): + return _is_3x3_conv(R, S) + + +def _wino_guard(R, S, stride, dilation, C): + # _is_winograd_eligible signature varies by upstream — keep the flag tight + from aiter.ops.triton.conv._utils import _is_winograd_eligible + + return _is_winograd_eligible(R, S, stride, dilation, C) + + +METHOD_REGISTRY = { + "default": MethodEntry(conv2d_nchw, None, False, "", "default"), + "cblocked": MethodEntry( + conv2d_nchw_cblocked, _3x3_guard, False, "[cblocked]", "cblocked" + ), + "winograd_f4x3": MethodEntry( + conv2d_winograd_f4x3, _wino_guard, True, "[winograd_f4x3]", "WF(4,3)" + ), + "winograd_f4x3_fused": MethodEntry( + conv2d_winograd_f4x3_fused, _wino_guard, True, "[wino_f4x3_fused]", "WF4fused" + ), + "winograd_f4x3_cblocked": MethodEntry( + conv2d_winograd_f4x3_cblocked, + _wino_guard, + True, + "[winograd_f4x3_cblocked]", + "WF4cb", + ), +} + +ORDERED_METHODS = list(METHOD_REGISTRY.keys()) +ALL_METHODS = ORDERED_METHODS + ["all"] + + +# -- Result + suite ----------------------------------------------------------- + + +@dataclass +class TestResult: + name: str + passed: bool + max_abs_error: float + rel_error: float + message: str = "" + + +class TestSuite: + """Correctness-only test runner. 
No bench records, no MIOpen tables.""" + + __test__ = False # not a pytest TestCase — leading "Test" is incidental + + def __init__( + self, + device: str, + dtype: torch.dtype, + verbose: bool = False, + print_shapes: bool = False, + layout_mode: str = "both", + ): + self.device = torch.device(device) + self.dtype = dtype + self.verbose = verbose + self.print_shapes = print_shapes + self.layout_mode = layout_mode + self.results: List[TestResult] = [] + + def check_close( + self, + name: str, + got: torch.Tensor, + ref: torch.Tensor, + K_red: Optional[int] = None, + rtol: Optional[float] = None, + atol: Optional[float] = None, + ) -> TestResult: + got32 = got.float() + ref32 = ref.float() + diff = (got32 - ref32).abs() + max_abs = float(diff.max().item()) if diff.numel() else 0.0 + rel = max_abs / (float(ref32.abs().max().item()) + 1e-6) + if rtol is None or atol is None: + K_est = int(K_red) if K_red is not None else 1024 + rtol_calc, atol_calc = dynamic_conv_tolerances(self.dtype, K_est, ref32) + rtol = rtol if rtol is not None else rtol_calc + atol = atol if atol is not None else atol_calc + try: + torch.testing.assert_close(got32, ref32, rtol=rtol, atol=atol) + passed = True + msg = "OK" + except AssertionError as e: + passed = False + msg = str(e).split("\n")[0] + res = TestResult(name, passed, max_abs, rel, msg) + self.results.append(res) + if self.verbose: + mark = "✓" if passed else "✗" + print(f" {mark} {name:<40} | max_abs={max_abs:.3e} rel={rel:.3e}") + return res + + def all_passed(self) -> bool: + return all(r.passed for r in self.results) + + def failed_results(self) -> List[TestResult]: + return [r for r in self.results if not r.passed] + + +# -- Tolerance + dispatch ----------------------------------------------------- + + +def _get_tolerances( + method_name, entry, suite, y_ref, N, C, H, W, K_out, R, S, stride, dilation +): + if entry.is_winograd: + return _winograd_tolerances(suite.dtype, C * R * S, y_ref, "f4x3") + if method_name == "default" and _is_3x3_conv(R, S): + routed = _select_3x3_method(N, C, H, W, K_out, stride, dilation) + if routed and "winograd" in routed: + return _winograd_tolerances(suite.dtype, C * R * S, y_ref, "f4x3") + return dynamic_conv_tolerances(suite.dtype, C * R * S, y_ref) + + +def run_all_methods( + suite: TestSuite, + x: torch.Tensor, + w: torch.Tensor, + b: Optional[torch.Tensor], + stride, + padding, + dilation, + name: str, + method: str = "default", + activation: str = "none", +): + """Correctness-only dispatch: run selected method(s) and check vs F.conv2d.""" + N, C, H, W_in = x.shape + K_out, _, R, S = w.shape + + y_ref = F.conv2d( + x, + w, + b.to(dtype=suite.dtype) if b is not None else None, + stride=stride, + padding=padding, + dilation=dilation, + ) + y_ref = apply_activation(y_ref, activation) + + if suite.print_shapes: + if _is_1x1_conv(R, S, dilation): + kernel_type = "[1x1]" + elif _is_3x3_conv(R, S): + kernel_type = "[3x3]" + else: + kernel_type = "[general]" + print( + f" {name} {kernel_type}: X{tuple(x.shape)} W{tuple(w.shape)} -> Y{tuple(y_ref.shape)}" + ) + + if suite.layout_mode in ("nchw", "both"): + methods_to_run = ORDERED_METHODS if method == "all" else [method] + for m in methods_to_run: + entry = METHOD_REGISTRY[m] + if entry.guard_fn and not entry.guard_fn(R, S, stride, dilation, C): + continue + y_tri = entry.kernel_fn( + x, + w, + b, + stride, + padding, + dilation, + activation=activation, + out_dtype=suite.dtype, + ) + rtol, atol = _get_tolerances( + m, entry, suite, y_ref, N, C, H, W_in, K_out, R, S, stride, 
dilation + ) + suite.check_close( + f"{name} {entry.bench_tag or '[NCHW]'}", + y_tri, + y_ref, + rtol=rtol, + atol=atol, + ) + + if suite.layout_mode in ("nhwc", "both"): + y_nhwc = conv2d_nhwc( + x, + w, + b, + stride, + padding, + dilation, + activation=activation, + out_dtype=suite.dtype, + ) + if _is_3x3_conv(R, S): + nhwc_method = _select_3x3_method(N, C, H, W_in, K_out, stride, dilation) + if nhwc_method in ("winograd_f4x3", "winograd_f4x3_cblocked"): + _r, _a = _winograd_tolerances(suite.dtype, C * R * S, y_ref, "f4x3") + suite.check_close(f"{name} [NHWC]", y_nhwc, y_ref, rtol=_r, atol=_a) + else: + suite.check_close(f"{name} [NHWC]", y_nhwc, y_ref, K_red=C * R * S) + else: + suite.check_close(f"{name} [NHWC]", y_nhwc, y_ref, K_red=C * R * S) + + +# -- Shape sets --------------------------------------------------------------- + +# Shapes routable by ALL 5 NCHW kernels (3x3, stride=1, padding=1, dilation=1, +# C >= 4). Used by test_cross_method to verify every kernel produces the same +# result (within tolerance) on the same input. +COMMON_SHAPES = [ + (1, 64, 56, 56, 64, 3, 3, (1, 1), (1, 1), (1, 1), "common 64ch/56sp"), + (1, 128, 28, 28, 128, 3, 3, (1, 1), (1, 1), (1, 1), "common 128ch/28sp"), + (1, 256, 14, 14, 256, 3, 3, (1, 1), (1, 1), (1, 1), "common 256ch/14sp"), +] + + +# -- Edge case shapes --------------------------------------------------------- + + +def get_edge_case_shapes(): + return [ + (1, 3, 7, 7, 8, 3, 3, (1, 1), (1, 1), (1, 1), "3x3 same padding"), + (1, 3, 8, 8, 16, 1, 1, (1, 1), (0, 0), (1, 1), "1x1 stride1"), + (2, 16, 32, 32, 32, 3, 3, (2, 2), (1, 1), (1, 1), "stride2"), + (2, 32, 17, 23, 64, 5, 5, (2, 2), (2, 2), (1, 1), "odd dims + pad"), + (4, 64, 28, 28, 128, 3, 3, (1, 1), (0, 0), (2, 2), "dilation2"), + (2, 512, 7, 7, 1024, 1, 1, (1, 1), (0, 0), (1, 1), "1x1 large channels"), + (1, 3, 112, 112, 64, 7, 7, (2, 2), (3, 3), (1, 1), "7x7 large spatial"), + (1, 1, 16, 16, 16, 3, 3, (1, 1), (1, 1), (1, 1), "single input channel"), + (2, 64, 8, 8, 64, 3, 3, (1, 1), (1, 1), (1, 1), "small spatial 3x3"), + (1, 128, 4, 4, 256, 1, 1, (1, 1), (0, 0), (1, 1), "1x1 tiny spatial"), + (2, 32, 32, 32, 32, 3, 3, (1, 1), (0, 0), (1, 1), "3x3 no padding"), + (2, 64, 28, 28, 128, 3, 3, (2, 2), (1, 1), (1, 1), "3x3 stride2 standard"), + ] + + +# -- Test runners (no test_ prefix; pytest will not collect this file) -------- + + +def run_edge_cases(suite: TestSuite, activation: str = "none", method: str = "default"): + for ( + N, + C, + H, + W, + K_out, + R, + S, + stride, + padding, + dilation, + desc, + ) in get_edge_case_shapes(): + P, Q = _out_hw(H, W, R, S, stride, padding, dilation) + if P < 1 or Q < 1: + continue + x = torch.randn((N, C, H, W), device=suite.device, dtype=suite.dtype) + w = torch.randn((K_out, C, R, S), device=suite.device, dtype=suite.dtype) + b = torch.randn((K_out,), device=suite.device, dtype=suite.dtype) + run_all_methods( + suite, + x, + w, + b, + stride, + padding, + dilation, + name=desc, + method=method, + activation=activation, + ) + + +def run_activations( + suite: TestSuite, method: str = "default", activation: str = "relu" +): + N, C, H, W, K_out = 2, 32, 16, 16, 64 + R, S = 3, 3 + stride, padding, dilation = (1, 1), (1, 1), (1, 1) + x = torch.randn((N, C, H, W), device=suite.device, dtype=suite.dtype) + w = torch.randn((K_out, C, R, S), device=suite.device, dtype=suite.dtype) + b = torch.randn((K_out,), device=suite.device, dtype=suite.dtype) + run_all_methods( + suite, + x, + w, + b, + stride, + padding, + dilation, + 
name=f"activation_{activation}_{method}", + method=method, + activation=activation, + ) + + +def run_no_bias(suite: TestSuite, method: str = "default"): + shapes = [ + (1, 64, 8, 8, 128, 1, 1, (1, 1), (0, 0), (1, 1), "1x1 no bias"), + (2, 32, 16, 16, 64, 3, 3, (1, 1), (1, 1), (1, 1), "3x3 no bias"), + (1, 16, 8, 8, 32, 5, 5, (1, 1), (2, 2), (1, 1), "5x5 no bias"), + ] + for N, C, H, W, K_out, R, S, stride, padding, dilation, desc in shapes: + x = torch.randn((N, C, H, W), device=suite.device, dtype=suite.dtype) + w = torch.randn((K_out, C, R, S), device=suite.device, dtype=suite.dtype) + run_all_methods( + suite, + x, + w, + None, + stride, + padding, + dilation, + name=desc, + method=method, + ) + + +def run_cross_method(suite: TestSuite): + """Run every NCHW-applicable kernel on shapes that all 5 can handle. + + Each kernel is checked against F.conv2d. Transitivity gives us + cross-kernel equivalence: if kernel A and B both match the same + F.conv2d output within tolerance, they match each other within ~2x. + """ + for N, C, H, W, K_out, R, S, stride, padding, dilation, desc in COMMON_SHAPES: + x = torch.randn((N, C, H, W), device=suite.device, dtype=suite.dtype) + w = torch.randn((K_out, C, R, S), device=suite.device, dtype=suite.dtype) + b = torch.randn((K_out,), device=suite.device, dtype=suite.dtype) + run_all_methods( + suite, + x, + w, + b, + stride, + padding, + dilation, + name=desc, + method="all", # iterates every kernel in ORDERED_METHODS + ) + + +def run_random_fuzzing( + suite: TestSuite, + num_tests: int = 10, + activation: str = "none", + method: str = "default", + seed: int = 42, +): + """Bounded random shape sweep, seeded for reproducibility. + + Default num_tests=10 keeps CI cheap; callers can pass a larger value + for ad-hoc development sweeps. + """ + random.seed(seed) + for i in range(num_tests): + N = random.randint(1, 8) + C = random.choice([1, 3, 16, 32, 64, 128, 256]) + H = random.randint(4, 64) + W = random.randint(4, 64) + K_out = random.choice([16, 32, 64, 128, 256]) + R = random.randint(1, min(7, H)) + S = random.randint(1, min(7, W)) + sh = random.randint(1, 3) + sw = random.randint(1, 3) + ph = random.randint(0, R // 2) + pw = random.randint(0, S // 2) + dh = random.randint(1, 2) + dw = random.randint(1, 2) + P, Q = _out_hw(H, W, R, S, (sh, sw), (ph, pw), (dh, dw)) + if P < 1 or Q < 1: + continue + try: + x = torch.randn((N, C, H, W), device=suite.device, dtype=suite.dtype) + w = torch.randn((K_out, C, R, S), device=suite.device, dtype=suite.dtype) + b = torch.randn((K_out,), device=suite.device, dtype=suite.dtype) + tag = f"Random[{i}] ({N},{C},{H},{W})->({N},{K_out},{P},{Q})" + run_all_methods( + suite, + x, + w, + b, + (sh, sw), + (ph, pw), + (dh, dw), + name=tag, + method=method, + activation=activation, + ) + except Exception as e: + tb = traceback.format_exc() + suite.results.append( + TestResult( + f"Random[{i}]", + False, + float("inf"), + float("inf"), + f"{type(e).__name__}: {e}\n{tb}", + ) + ) diff --git a/op_tests/triton_tests/conv/test_conv2d.py b/op_tests/triton_tests/conv/test_conv2d.py new file mode 100644 index 0000000000..c129e5f39b --- /dev/null +++ b/op_tests/triton_tests/conv/test_conv2d.py @@ -0,0 +1,130 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +"""Pytest unit tests for aiter.ops.triton.conv.conv2d. + +Correctness only. All tests compare Triton kernels against +torch.nn.functional.conv2d on synthetic tensors. No model loading, +no network, no torchvision. 
+ +Test matrix (uniform across the four primary test families): + + NCHW × {fp16, bf16} × every kernel (5 kernels) = 10 + NHWC × {fp16, bf16} (single dispatch) = 2 + --- + base cases per test family 12 + +test_edge, test_fuzz, test_no_bias use the base matrix as-is. +test_activations multiplies the base matrix by 3 (relu/relu6/gelu) -> 36. + +Plus test_cross_method (differential correctness) that runs every NCHW +kernel on shapes routable by all of them and verifies they all match +F.conv2d. NCHW-only by design; 2 cases (one per dtype). + +Total: 12 + 12 + 12 + 36 + 2 = 74 cases. + +Where a kernel's guard rejects a shape (e.g. winograd on a 5x5), the +shape is silently skipped inside run_all_methods. + +Performance benchmarking lives in +op_tests/op_benchmarks/triton/bench_conv2d.py (and, for real-model +shapes, in op_benchmarks/triton/model_benchmarking_tool/bench_models.py). +""" + +from __future__ import annotations + +import pytest +import torch + +from aiter.ops.triton.utils._triton.arch_info import get_arch + +from ._helpers import ( + ALL_SUPPORTED_ARCHS, + TestSuite, + ORDERED_METHODS, + run_edge_cases, + run_activations, + run_no_bias, + run_random_fuzzing, + run_cross_method, +) + +# Module-level arch gate. Skip the whole test module on unsupported archs +# rather than fail per-test. Extend SUPPORTED_ARCHS in _helpers.py when +# adding CDNA (or other RDNA) support. +_current_arch = get_arch() +if _current_arch not in ALL_SUPPORTED_ARCHS: + pytest.skip( + f"aiter.ops.triton.conv tests run on {sorted(ALL_SUPPORTED_ARCHS)}; " + f"current arch {_current_arch!r} not supported", + allow_module_level=True, + ) + + +# Build the (dtype, layout, method) matrix once. NHWC entries only pair with +# method="default" because conv2d_nhwc is single-dispatch — the method param +# is a no-op there, so re-running for every method id would just duplicate work. 
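+# For reference, the matrix expands to 12 ids, six per dtype:
+#   fp16_nchw_default, fp16_nchw_cblocked, fp16_nchw_winograd_f4x3,
+#   fp16_nchw_winograd_f4x3_fused, fp16_nchw_winograd_f4x3_cblocked, fp16_nhwc
+# and the same six with a bf16 prefix.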
+def _build_matrix(): + matrix = [] + for dtype, dtype_id in [(torch.float16, "fp16"), (torch.bfloat16, "bf16")]: + for method in ORDERED_METHODS: + matrix.append(((dtype, "nchw", method), f"{dtype_id}_nchw_{method}")) + matrix.append(((dtype, "nhwc", "default"), f"{dtype_id}_nhwc")) + return matrix + + +_MATRIX = _build_matrix() +PARAMS = [params for params, _ in _MATRIX] +IDS = [tid for _, tid in _MATRIX] + + +def _make_suite(dtype, layout): + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + return TestSuite(device="cuda", dtype=dtype, layout_mode=layout) + + +def _assert_suite(suite: TestSuite): + failed = suite.failed_results() + assert not failed, f"{len(failed)} tests failed: {[r.name for r in failed]}" + + +# -- The four primary test families, all on the same matrix ------------------ + + +@pytest.mark.parametrize("dtype,layout,method", PARAMS, ids=IDS) +def test_edge(dtype, layout, method): + suite = _make_suite(dtype, layout) + run_edge_cases(suite, method=method) + _assert_suite(suite) + + +@pytest.mark.parametrize("dtype,layout,method", PARAMS, ids=IDS) +def test_fuzz(dtype, layout, method): + suite = _make_suite(dtype, layout) + run_random_fuzzing(suite, num_tests=10, method=method) + _assert_suite(suite) + + +@pytest.mark.parametrize("dtype,layout,method", PARAMS, ids=IDS) +def test_no_bias(dtype, layout, method): + suite = _make_suite(dtype, layout) + run_no_bias(suite, method=method) + _assert_suite(suite) + + +@pytest.mark.parametrize("activation", ["relu", "relu6", "gelu"]) +@pytest.mark.parametrize("dtype,layout,method", PARAMS, ids=IDS) +def test_activations(dtype, layout, method, activation): + suite = _make_suite(dtype, layout) + run_activations(suite, method=method, activation=activation) + _assert_suite(suite) + + +# -- Differential correctness across all 5 NCHW kernels (NCHW-only) ---------- + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) +def test_cross_method(dtype): + suite = _make_suite(dtype, "nchw") + run_cross_method(suite) + _assert_suite(suite)
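+
+# Illustrative local runs (paths per this diff's layout; -k matches the
+# parametrize ids above):
+#   pytest op_tests/triton_tests/conv/test_conv2d.py -v
+#   pytest op_tests/triton_tests/conv/test_conv2d.py -k "fp16_nhwc"
+#   pytest op_tests/triton_tests/conv/test_conv2d.py -k "test_cross_method"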