Patch notes (free-form preamble; patch tools skip text that precedes the first "diff --git" header).

Files touched:
  * bench/bench_reduce.py (new): CUDA-event timing sweep of reduce(x, dim=0, y_dtype=bfloat16)
    over K / S0 / S1 grids; prints one CSV row per shape together with the block sizes returned
    by _select_reduce_forward_config.
  * tests/test_matmul_details/test_opt_flags_split_k.py: tests call opt_flags.reset_opt_flags()
    instead of assigning opt_flags._opt_flags = None.
  * triton_kernels/matmul.py: imports OptFlags / scoped_opt_flags / scoped_opt_flags_constraints
    under explicit aliases, delegates get_swap_xw to opt_flags_nvidia.compute_swap_xw, and changes
    the can_use_split_k condition.
  * matmul_details/opt_flags.py: the module-level _opt_flags / _opt_flags_constraints become
    ContextVars; adds scoped_opt_flags.
  * opt_flags_details/opt_flags_nvidia.py: adds compute_swap_xw; compute_num_stages reserves a
    larger epilogue smem budget when compute_swap_xw(...) is true and allows 5 stages for FP4 rhs.
  * triton_kernels/reduce.py: adds a reduce-specific OptFlags, scoped_opt_flags,
    _select_reduce_forward_config, and a USE_STATIC_LOOP kernel parameter.

Review change relative to the submitted patch:
  * reduce.py, _reduce_forward: y is now zero-initialised when K == 0 as well.
    _select_reduce_forward_config picks the static loop for every K <= 8, which includes K == 0;
    a static loop with no iterations never executes the "k == 0" assignment, so y would be unbound
    when the post-processing code reads it. The pre-patch kernel zero-initialised y unconditionally.
    (Whether reduce_forward is reachable with an empty reduction dim is not visible here.)

Indentation and blank context lines were reconstructed from a whitespace-collapsed copy of this
patch; re-check context lines against the target tree before applying.

diff --git a/python/triton_kernels/bench/bench_reduce.py b/python/triton_kernels/bench/bench_reduce.py
new file mode 100644
index 000000000000..bb6a2eb20689
--- /dev/null
+++ b/python/triton_kernels/bench/bench_reduce.py
@@ -0,0 +1,71 @@
+import argparse
+import statistics
+
+import torch
+
+from triton_kernels.reduce import reduce, _select_reduce_forward_config
+
+
+def _csv_ints(s):
+    return [int(x) for x in s.split(",") if x]
+
+
+def _flush_cache(cache_killer):
+    if cache_killer is not None:
+        cache_killer.add_(1.0)
+
+
+def bench_reduce(k, s0, s1, iters, cache_killer):
+    x = torch.randn((k, s0, s1), device="cuda", dtype=torch.float32)
+    for _ in range(10):
+        _flush_cache(cache_killer)
+        reduce(x, dim=0, y_dtype=torch.bfloat16)
+    torch.cuda.synchronize()
+
+    times_ms = []
+    start = torch.cuda.Event(enable_timing=True)
+    end = torch.cuda.Event(enable_timing=True)
+    for _ in range(iters):
+        _flush_cache(cache_killer)
+        start.record()
+        reduce(x, dim=0, y_dtype=torch.bfloat16)
+        end.record()
+        torch.cuda.synchronize()
+        times_ms.append(start.elapsed_time(end))
+    times_ms.sort()
+    return statistics.median(times_ms), statistics.mean(times_ms), times_ms[int(0.9 * (iters - 1))]
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Benchmark wide-S1 reduce_forward shapes.")
+    parser.add_argument("--ks", default="1,2,3,4,5,6,7,8")
+    parser.add_argument("--s0s", default="1,2,4,8,16,32,64,128,256")
+    parser.add_argument("--s1s", default="1024,2048,4096,8192,16384,32768")
+    parser.add_argument("--iters", type=int, default=80)
+    parser.add_argument(
+        "--flush-mb",
+        type=int,
+        default=512,
+        help="Touch this many MiB before each measured reduce. Set to 0 to benchmark hot-cache repeats.",
+    )
+    args = parser.parse_args()
+
+    cache_killer = None
+    if args.flush_mb > 0:
+        n_elements = args.flush_mb * 1024 * 1024 // torch.empty((), dtype=torch.float32).element_size()
+        cache_killer = torch.empty(n_elements, device="cuda", dtype=torch.float32)
+        cache_killer.zero_()
+
+    print("K,S0,Y_S1,BLOCK_S0,BLOCK_S1,median_ms,mean_ms,p90_ms", flush=True)
+    for s1 in _csv_ints(args.s1s):
+        for k in _csv_ints(args.ks):
+            for s0 in _csv_ints(args.s0s):
+                opt_flags = _select_reduce_forward_config(s0, s1, 1, k, False)
+                median_ms, mean_ms, p90_ms = bench_reduce(k, s0, s1, args.iters, cache_killer)
+                print(
+                    f"{k},{s0},{s1},{opt_flags.block_s0},{opt_flags.block_x_s1},{median_ms:.6f},{mean_ms:.6f},{p90_ms:.6f}",
+                    flush=True)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/python/triton_kernels/tests/test_matmul_details/test_opt_flags_split_k.py b/python/triton_kernels/tests/test_matmul_details/test_opt_flags_split_k.py
index 3a2fc7e92755..5d65cd6225ce 100644
--- a/python/triton_kernels/tests/test_matmul_details/test_opt_flags_split_k.py
+++ b/python/triton_kernels/tests/test_matmul_details/test_opt_flags_split_k.py
@@ -133,7 +133,7 @@ def test_make_default_opt_flags_nvidia_split_k_constraint(monkeypatch):
 
 def test_max_allowable_mn_and_split_k_constraints(monkeypatch):
     setup_nvidia(monkeypatch)
-    opt_flags._opt_flags = None
+    opt_flags.reset_opt_flags()
     opt_flags.reset_opt_flags_constraints()
     opt_flags.update_opt_flags_constraints(
         {
@@ -167,7 +167,7 @@ def test_max_allowable_mn(monkeypatch):
     batch_size, m, n, k = 1, 256, 256, 256
 
     def get_flags(split_k, max_mn):
-        opt_flags._opt_flags = None
+        opt_flags.reset_opt_flags()
         opt_flags.reset_opt_flags_constraints()
         opt_flags.update_opt_flags_constraints(
             {
diff --git a/python/triton_kernels/triton_kernels/matmul.py b/python/triton_kernels/triton_kernels/matmul.py
index 49c5977ffc1e..9de9a0e0bac2 100644
--- a/python/triton_kernels/triton_kernels/matmul.py
+++ b/python/triton_kernels/triton_kernels/matmul.py
@@ -21,7 +21,15 @@
 from .tensor_details.layout_details.strided import StridedLayout
 from .tensor_details.layout_details.blackwell_scale import BlackwellActMXScaleLayout
 from .tensor_details.layout_details.blackwell_value_shuffled import BlackwellMX4ValueShuffledLayout
-from .matmul_details.opt_flags import InapplicableConstraint, make_opt_flags, update_opt_flags_constraints
+from .matmul_details.opt_flags import (
+    InapplicableConstraint,
+    OptFlags as OptFlags,
+    make_opt_flags,
+    scoped_opt_flags as scoped_opt_flags,
+    scoped_opt_flags_constraints as scoped_opt_flags_constraints,
+    update_opt_flags_constraints,
+)
+from .matmul_details.opt_flags_details import opt_flags_nvidia
 from .specialize import FnSpecs, SpecializationModule, ClosureArg
 from .tensor import Storage, Tensor, FP4, wrap_torch_tensor, RaggedTensorMetadata, is_tma_compliant, make_tma, convert_layout
 from .tensor import dtype_to_torch_dtype, torch_dtype_to_dtype
@@ -131,16 +139,9 @@ class PrecisionConfig:
 
 # TODO: merge in opt_flags
 def get_swap_xw(precision_config, opt_flags):
-    if target_info.cuda_capability_geq(10, 0):
-        if precision_config.b_mx_scale is not None:
-            return opt_flags.block_m <= 64 and opt_flags.is_persistent
-        else:
-            return opt_flags.block_m < 64 and opt_flags.is_persistent
-    elif target_info.cuda_capability_geq(9, 0):
-        b_scale_layout = None if not isinstance(precision_config.b_mx_scale, Tensor) else precision_config.b_mx_scale.storage.layout
-        return isinstance(b_scale_layout, HopperMXScaleLayout)
-
-    return False
+    if triton.runtime.driver.active.get_current_target().backend != "cuda":
+        return False
+    return opt_flags_nvidia.compute_swap_xw(precision_config, opt_flags.block_m, opt_flags.is_persistent)
 
 # ---------------------
 # Allocation
@@ -385,7 +386,8 @@ def matmul(a, b, bias,
         # which is too big.
         can_use_tma = False
     has_gather_tma = has_gather and target_info.has_tma_gather()
-    can_use_split_k = scatter_indx is None and not a_has_mx and not b_has_mx and ragged_dimension != "K"
+    is_ragged_mx = (a_has_mx or b_has_mx) and (is_a_ragged or is_b_ragged)
+    can_use_split_k = scatter_indx is None and not is_ragged_mx and ragged_dimension != "K" and c_acc_in is None and precision_config.c_mx_scale is None
     block_k = None
     if ragged_dimension == "K":
         block_k = a_ragged_metadata.slice_sizes_divisibility or b_ragged_metadata.slice_sizes_divisibility
diff --git a/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py b/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py
index 00fe57f681fe..f309a1246592 100644
--- a/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py
+++ b/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py
@@ -1,6 +1,7 @@
 # isort: off
 # fmt: off
 from contextlib import contextmanager
+from contextvars import ContextVar
 from dataclasses import dataclass
 
 import triton
@@ -373,36 +374,46 @@ def _is_layout_strided(layout: Layout | None) -> bool:
 # User Interface
 # --------------
 
-_opt_flags_constraints: dict = dict()
-_opt_flags: OptFlags | None = None
+_opt_flags_constraints: ContextVar[dict | None] = ContextVar("opt_flags_constraints", default=None)
+_opt_flags: ContextVar[OptFlags | None] = ContextVar("opt_flags", default=None)
+
+def _get_opt_flags_constraints() -> dict:
+    constraints = _opt_flags_constraints.get()
+    return {} if constraints is None else constraints
 
 def update_opt_flags_constraints(constraints: dict[str, int]):
-    global _opt_flags_constraints
-    _opt_flags_constraints.update(constraints)
+    updated = _get_opt_flags_constraints().copy()
+    updated.update(constraints)
+    _opt_flags_constraints.set(updated)
 
 def reset_opt_flags_constraints():
-    global _opt_flags_constraints
-    _opt_flags_constraints = dict()
+    _opt_flags_constraints.set(None)
 
 @contextmanager
 def scoped_opt_flags_constraints(constraints):
-    saved = dict(_opt_flags_constraints)
-    _opt_flags_constraints.update(constraints)
+    updated = _get_opt_flags_constraints().copy()
+    updated.update(constraints)
+    token = _opt_flags_constraints.set(updated)
     try:
         yield
     finally:
-        _opt_flags_constraints.clear()
-        _opt_flags_constraints.update(saved)
+        _opt_flags_constraints.reset(token)
 
 def reset_opt_flags():
-    global _opt_flags
-    _opt_flags = None
+    _opt_flags.set(None)
 
 def set_opt_flags(opt_flags: OptFlags):
-    global _opt_flags
-    assert not _opt_flags_constraints, "setting constraints is incompatible with manual flags override"
-    assert not _opt_flags, "opt_flags already set; please reset to None first"
-    _opt_flags = opt_flags
+    assert not _get_opt_flags_constraints(), "setting constraints is incompatible with manual flags override"
+    assert _opt_flags.get() is None, "opt_flags already set; please reset to None first"
+    _opt_flags.set(opt_flags)
+
+@contextmanager
+def scoped_opt_flags(opt_flags: OptFlags):
+    token = _opt_flags.set(opt_flags)
+    try:
+        yield
+    finally:
+        _opt_flags.reset(token)
 
 class InapplicableConstraint(Exception):
     pass
@@ -426,19 +437,20 @@ def make_opt_flags(
     mx_block_size=None,
     x_uses_tma_when_persistent=True,
 ):
-    if _opt_flags_constraints.get("is_persistent", False) and not can_use_persistent_tma:
+    opt_flags_constraints = _get_opt_flags_constraints()
+    if opt_flags_constraints.get("is_persistent", False) and not can_use_persistent_tma:
         raise InapplicableConstraint("cannot enforce `is_persistent=True` constraint")
-    if _opt_flags_constraints.get("split_k") is not None and _opt_flags_constraints.get("split_k") > 1 and not can_use_split_k:
+    if opt_flags_constraints.get("split_k") is not None and opt_flags_constraints.get("split_k") > 1 and not can_use_split_k:
         raise InapplicableConstraint("cannot enforce `split_k=True` constraint")
-    if _opt_flags_constraints.get("max_allowable_mn"):
-        if not _opt_flags_constraints.get("split_k"):
+    if opt_flags_constraints.get("max_allowable_mn"):
+        if not opt_flags_constraints.get("split_k"):
             raise InapplicableConstraint("split_k also needs to be provided with max_allowable_mn")
     enforce_bitwise_invariance = precision_config.enforce_bitwise_invariance
-    if _opt_flags is not None:
-        assert not _opt_flags_constraints
+    opt_flags = _opt_flags.get()
+    if opt_flags is not None:
+        assert not opt_flags_constraints
         assert block_k is None
-        return _opt_flags
-    opt_flags_constraints = _opt_flags_constraints
+        return opt_flags
     if block_k is not None:
         opt_flags_constraints = opt_flags_constraints.copy()
         opt_flags_constraints.update(block_k=block_k, split_k=1)
diff --git a/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_nvidia.py b/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_nvidia.py
index 8ce440a1691a..c200afb0bde3 100644
--- a/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_nvidia.py
+++ b/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_nvidia.py
@@ -13,6 +13,20 @@ def is_x_scale_swizzled(precision_config):
             and isinstance(precision_config.a_mx_scale.storage.layout, BlackwellActMXScaleLayout))
 
 
+def compute_swap_xw(precision_config, block_m, is_persistent):
+    if target_info.cuda_capability_geq(10, 0):
+        if precision_config.b_mx_scale is not None:
+            return block_m <= 64 and is_persistent
+        else:
+            return block_m < 64 and is_persistent
+    elif target_info.cuda_capability_geq(9, 0):
+        layout = None if not isinstance(precision_config.b_mx_scale,
+                                        Tensor) else precision_config.b_mx_scale.storage.layout
+        return isinstance(layout, HopperMXScaleLayout)
+
+    return False
+
+
 def compute_grid_size(routing_data, batch_size, m, n, block_m, block_n):
     if routing_data is not None and batch_size == 1:
         grid_m = routing_data.n_blocks(routing_data.n_slices, m, block_m)
@@ -146,7 +160,17 @@ def compute_num_stages(
         # pipelined TMA store local to global, or
         # pipelined layout conversion before store of the accumulator
         # note: layout conversion has some padding
-        smem_capacity -= int((block_m + 4) * acc_block_n * acc_size)
+        epilogue_smem = int((block_m + 4) * acc_block_n * acc_size)
+        if compute_swap_xw(precision_config, block_m, is_persistent):
+            # SWAP_XW Blackwell kernels stage the full transposed TMEM
+            # accumulator tile through fp32 smem before converting/storing it.
+            # If the output is narrower, the final TMA-store tile is a separate
+            # smem allocation.
+            acc_smem = block_m * block_n * (FP32.bitwidth // 8)
+            if out_itemsize < (FP32.bitwidth // 8):
+                acc_smem += int(block_m * acc_block_n * out_itemsize)
+            epilogue_smem = max(epilogue_smem, acc_smem)
+        smem_capacity -= epilogue_smem
         if x_transpose:
             smem_capacity -= block_m * block_k * (max(8, lhs_dtype.bitwidth) // 8)
 
@@ -155,7 +179,8 @@ def compute_num_stages(
     if is_persistent and (lhs_dtype == FP32 or rhs_dtype == FP32):
         smem_capacity -= 32 * 1024
     smem_capacity = max(smem_capacity, 0)
-    num_stages = min(smem_capacity // int(stage_size), 4)
+    max_stages = 5 if rhs_dtype == FP4 else 4  # maybe 5 everywhere; just haven't tested
+    num_stages = min(smem_capacity // int(stage_size), max_stages)
     # Keep one stage of headroom for persistent fp32 to avoid launch-time OOR.
     if is_persistent and (lhs_dtype == FP32 or rhs_dtype == FP32):
         num_stages = min(num_stages, 3)
diff --git a/python/triton_kernels/triton_kernels/reduce.py b/python/triton_kernels/triton_kernels/reduce.py
index 0e6355bca57f..63b01aa445ac 100644
--- a/python/triton_kernels/triton_kernels/reduce.py
+++ b/python/triton_kernels/triton_kernels/reduce.py
@@ -1,10 +1,13 @@
 from dataclasses import dataclass
+from contextlib import contextmanager
+from contextvars import ContextVar
 import torch
 import triton
 import triton.language as tl
 from triton_kernels.numerics_details.mxfp import MXFP_BLOCK_SIZE, quantize_mxfp4_fn, quantize_mxfp8_fn, quantize_nvfp4_fn
 from triton_kernels.numerics_details.flexpoint import float_to_flex, load_scale
 from triton_kernels.numerics import InFlexData, OutFlexData, MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5
+from triton_kernels import target_info
 from typing import Optional
 from .specialize import SpecializationModule, ClosureArg, FnSpecs
 
@@ -15,6 +18,27 @@ class PostprocessFn:
     fn_args: tuple[object] = tuple()
 
 
+@dataclass(frozen=True)
+class OptFlags:
+    block_s0: int
+    block_x_s1: int
+    block_y_s1: int
+    num_warps: int
+    use_static_loop: bool
+
+
+_opt_flags: ContextVar[OptFlags | None] = ContextVar("reduce_opt_flags", default=None)
+
+
+@contextmanager
+def scoped_opt_flags(opt_flags: OptFlags):
+    token = _opt_flags.set(opt_flags)
+    try:
+        yield
+    finally:
+        _opt_flags.reset(token)
+
+
 # Return strides in this order: (reduction dim, non-reduction dim #0, non-reduction dim #1).
 def _get_strides(t, dim, strides=None):
     if t is None:
@@ -51,6 +75,46 @@ def reduce_launch_metadata(grid, kernel, args):
     return ret
 
 
+def _select_reduce_forward_config(
+    S0: int,
+    Y_S1: int,
+    reduction_n: int,
+    K: int,
+    has_mx: bool,
+) -> OptFlags:
+    use_static_loop = K <= 8
+    if K in (2, 3, 4) and S0 <= 256 and Y_S1 >= 4096 and reduction_n == 1 and not has_mx:
+        if K >= 3:
+            # K>=3 does more loads than K=2, so keep its tile area
+            # somewhat smaller.
+            target_elems = 512
+        else:
+            if S0 <= 8:
+                target_elems = 512
+            elif S0 <= 64:
+                target_elems = 1024
+            elif S0 <= 128:
+                target_elems = 512
+            else:
+                target_elems = 2048
+
+        # A full-wave floor can force very narrow S1 tiles for tiny S0.
+        # Keep this slightly below num_sms so we prefer contiguous S1 work
+        # when a tile is already close to filling the device.
+        min_programs = max(1, (3 * target_info.num_sms()) // 4)
+        block_s1 = min(512, target_elems)
+        while block_s1 >= 16:
+            max_block_s0 = min(S0, max(1, target_elems // block_s1))
+            min_s0_programs = triton.cdiv(min_programs, triton.cdiv(Y_S1, block_s1))
+            if min_s0_programs <= S0:
+                max_occupancy_block_s0 = max(1, S0 // min_s0_programs)
+                block_s0 = min(max_block_s0, max_occupancy_block_s0)
+                return OptFlags(block_s0, block_s1, block_s1, 4, use_static_loop)
+            block_s1 //= 2
+
+    return OptFlags(32, 128, 128 // reduction_n, 4, use_static_loop)
+
+
 @triton.jit(launch_metadata=reduce_launch_metadata)
 def _reduce_forward(X, stride_xr: tl.int64, stride_x0: tl.int64, stride_x1,  # x tensor (input)
                     XMx, stride_xmxr, stride_xmx0, stride_xmx1,  # x mx scale
@@ -76,6 +140,7 @@ def _reduce_forward(X, stride_xr: tl.int64, stride_x0: tl.int64, stride_x1,  # x
                     SCALE_BROADCAST_R: tl.constexpr,  #
                     SCALE_BROADCAST_S0: tl.constexpr,  #
                     SCALE_BROADCAST_S1: tl.constexpr,  #
+                    USE_STATIC_LOOP: tl.constexpr,  #
                     BLOCK_S0: tl.constexpr,  #
                     BLOCK_X_S1: tl.constexpr,  #
                     BLOCK_Y_S1: tl.constexpr,  #
@@ -85,9 +150,9 @@ def _reduce_forward(X, stride_xr: tl.int64, stride_x0: tl.int64, stride_x1,  # x
 ):
     pid_s0 = tl.program_id(0)
     pid_s1 = tl.program_id(1)
-    tl.static_assert(BLOCK_X_S1 % 32 == 0)
-    BLOCK_X_SMX1: tl.constexpr = BLOCK_X_S1 // 32
-    BLOCK_Y_SMX1: tl.constexpr = BLOCK_Y_S1 // (1 if Y_MX_BLOCK_SIZE is None else Y_MX_BLOCK_SIZE)
+    tl.static_assert(XMx is None or BLOCK_X_S1 % 32 == 0)
+    BLOCK_X_SMX1: tl.constexpr = tl.cdiv(BLOCK_X_S1, 32)
+    BLOCK_Y_SMX1: tl.constexpr = tl.cdiv(BLOCK_Y_S1, 1 if Y_MX_BLOCK_SIZE is None else Y_MX_BLOCK_SIZE)
     offs_s0 = pid_s0 * BLOCK_S0 + tl.arange(0, BLOCK_S0)
     offs_x_s1 = pid_s1 * BLOCK_X_S1 + tl.arange(0, BLOCK_X_S1)
     offs_x_smx1 = pid_s1 * BLOCK_X_SMX1 + tl.arange(0, BLOCK_X_SMX1)
@@ -100,10 +165,11 @@ def _reduce_forward(X, stride_xr: tl.int64, stride_x0: tl.int64, stride_x1,  # x
     valid_s0 = offs_s0 < S0
     valid_x_s1 = offs_x_s1 < X_S1
     valid_in_smx1 = offs_x_smx1 < tl.cdiv(X_S1, 32)
-    y = tl.zeros((BLOCK_S0, BLOCK_X_S1), dtype=tl.float32)
     x_flex_scale = load_scale(XFlex)
-    for k in (tl.static_range if K <= 8 else tl.range)(0, K):
-        x_ptrs = X + k * stride_xr + offs_s0[:, None] * stride_x0 + offs_x_s1[None, :] * stride_x1
+    if not USE_STATIC_LOOP or K == 0:  # K == 0 runs no iterations, so y still needs a zero init
+        y = tl.zeros((BLOCK_S0, BLOCK_X_S1), dtype=tl.float32)
+
+    for k in (tl.static_range if USE_STATIC_LOOP else tl.range)(0, K):
         mask = valid_s0[:, None] & valid_x_s1[None, :]
         if not IS_MASK_NONE:
             k_term = 0 if BROADCAST_R else (k * stride_mr)
@@ -112,8 +178,8 @@ def _reduce_forward(X, stride_xr: tl.int64, stride_x0: tl.int64, stride_x1,  # x
             m_ptrs = Mask + k_term + s0_term + s1_term
             m = tl.load(m_ptrs, mask=mask, other=1).to(tl.int1)
             mask &= m
-        x = tl.load(x_ptrs, mask=mask, other=0.0)
-        x = x.to(tl.float32)
+        x_ptrs = X + k * stride_xr + offs_s0[:, None] * stride_x0 + offs_x_s1[None, :] * stride_x1
+        x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32)
         if XMx is not None:
             xmx_ptrs = XMx + k * stride_xmxr + offs_s0[:, None] * stride_xmx0 + offs_x_smx1[None, :] * stride_xmx1
             xmx = tl.load(xmx_ptrs, mask=valid_s0[:, None] & valid_in_smx1[None, :], other=0.0)
@@ -127,7 +193,10 @@ def _reduce_forward(X, stride_xr: tl.int64, stride_x0: tl.int64, stride_x1,  # x
             s_ptrs = Scale + k_term_s + s0_term_s + s1_term_s
             s = tl.load(s_ptrs, mask=mask, other=1)
             x = tl.fma(x, s, 0.0)
-        y += x
+        if USE_STATIC_LOOP and k == 0:
+            y = x
+        else:
+            y += x
     if POSTPROCESS_FN1 is not None:
         y = POSTPROCESS_FN1(y, *postprocess_fn1_args)
     if XGlobalScale is not None:
@@ -264,10 +333,11 @@ def reduce_forward(
     stride_sr, stride_s0, stride_s1 = _get_strides(scale, dim)
     K = x.shape[dim]
     # Always use the 2D tiled kernel with constexpr metaprogramming for mask broadcasting
-    BLOCK_S0 = 32
-    BLOCK_X_S1 = 128
-    BLOCK_Y_S1 = 128 // postprocess_fn1.specs.reduction_n
-    grid = (triton.cdiv(S0, BLOCK_S0), triton.cdiv(Y_S1, BLOCK_Y_S1))
+    opt_flags = _opt_flags.get()
+    if opt_flags is None:
+        opt_flags = _select_reduce_forward_config(S0, Y_S1, postprocess_fn1.specs.reduction_n, K, x_mxscale is not None
+                                                  or y_has_mx)
+    grid = (triton.cdiv(S0, opt_flags.block_s0), triton.cdiv(Y_S1, opt_flags.block_y_s1))
     if y_has_mx:
         if y_dtype == torch.float8_e4m3fn:
             postprocess_mx_fn = FnSpecs("quantize_mxfp8", quantize_mxfp8_fn, tuple(), tuple())
@@ -300,13 +370,14 @@ def reduce_forward(
         SCALE_BROADCAST_R=(stride_sr == 0),  #
         SCALE_BROADCAST_S0=(stride_s0 == 0),  #
         SCALE_BROADCAST_S1=(stride_s1 == 0),  #
-        BLOCK_S0=BLOCK_S0,  #
-        BLOCK_X_S1=BLOCK_X_S1,  #
-        BLOCK_Y_S1=BLOCK_Y_S1,  #
+        USE_STATIC_LOOP=opt_flags.use_static_loop,  #
+        BLOCK_S0=opt_flags.block_s0,  #
+        BLOCK_X_S1=opt_flags.block_x_s1,  #
+        BLOCK_Y_S1=opt_flags.block_y_s1,  #
         Y_MX_BLOCK_SIZE=y_microblock_size,  #
         Y_VALUE_PACK_FACTOR=y_value_pack_factor,  #
         DIM=dim,  #
-        num_warps=4  #
+        num_warps=opt_flags.num_warps  #
     )
     return y, y_mxscale
 