Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions python/triton_kernels/bench/bench_reduce.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import argparse
import statistics

import torch

from triton_kernels.reduce import reduce, _select_reduce_forward_config


def _csv_ints(s):
return [int(x) for x in s.split(",") if x]


def _flush_cache(cache_killer):
if cache_killer is not None:
cache_killer.add_(1.0)


def bench_reduce(k, s0, s1, iters, cache_killer, warmup=10):
    """Time ``reduce(x, dim=0, y_dtype=torch.bfloat16)`` for one input shape.

    The input is a freshly sampled float32 CUDA tensor of shape
    ``(k, s0, s1)``. Each launch, warm-up or timed, is preceded by
    ``_flush_cache(cache_killer)``.

    Args:
        k: size of dimension 0, the dimension passed as ``dim`` to ``reduce``.
        s0: size of dimension 1.
        s1: size of dimension 2.
        iters: number of timed launches; must be at least 1.
        cache_killer: buffer handed to ``_flush_cache`` before every launch,
            or ``None`` to skip flushing.
        warmup: number of untimed launches issued before measuring.

    Returns:
        A ``(median_ms, mean_ms, p90_ms)`` tuple of CUDA-event timings in
        milliseconds. ``p90_ms`` is the sorted sample at index
        ``int(0.9 * (iters - 1))``.

    Raises:
        ValueError: if ``iters`` is smaller than 1.
    """
    # Fail before allocating or launching anything: with no samples the
    # statistics calls below would raise a far less helpful error, and only
    # after the full warm-up had already run.
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")

    x = torch.randn((k, s0, s1), device="cuda", dtype=torch.float32)

    # Untimed launches before any measurement is taken.
    for _ in range(warmup):
        _flush_cache(cache_killer)
        reduce(x, dim=0, y_dtype=torch.bfloat16)
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    times_ms = []
    for _ in range(iters):
        # The flush is issued before `start` is recorded, so it is not part
        # of the measured interval.
        _flush_cache(cache_killer)
        start.record()
        reduce(x, dim=0, y_dtype=torch.bfloat16)
        end.record()
        # Both events must have completed before elapsed_time() is queried.
        torch.cuda.synchronize()
        times_ms.append(start.elapsed_time(end))

    times_ms.sort()
    p90_index = int(0.9 * (len(times_ms) - 1))
    return statistics.median(times_ms), statistics.mean(times_ms), times_ms[p90_index]


def main():
    """Command-line entry point: sweep reduce shapes and print one CSV row each.

    Flags:
        --ks, --s0s, --s1s: comma-separated integer lists defining the sweep.
        --iters: timed launches per shape, forwarded to ``bench_reduce``.
        --flush-mb: size in MiB of the float32 CUDA buffer that is touched
            before every launch; 0 (or negative) disables it.

    Output (stdout, flushed after every line): a header followed by one row
    per ``(s1, k, s0)`` combination, with s1 as the outermost loop and s0 as
    the innermost. The two block-size columns are read from the object
    returned by ``_select_reduce_forward_config``.
    """
    cli = argparse.ArgumentParser(description="Benchmark wide-S1 reduce_forward shapes.")
    sweep_defaults = (
        ("--ks", "1,2,3,4,5,6,7,8"),
        ("--s0s", "1,2,4,8,16,32,64,128,256"),
        ("--s1s", "1024,2048,4096,8192,16384,32768"),
    )
    for flag, default in sweep_defaults:
        cli.add_argument(flag, default=default)
    cli.add_argument("--iters", type=int, default=80)
    cli.add_argument(
        "--flush-mb",
        type=int,
        default=512,
        help="Touch this many MiB before each measured reduce. Set to 0 to benchmark hot-cache repeats.",
    )
    opts = cli.parse_args()

    if opts.flush_mb > 0:
        # Convert the requested MiB into a float32 element count.
        bytes_per_element = torch.empty((), dtype=torch.float32).element_size()
        n_elements = opts.flush_mb * 1024 * 1024 // bytes_per_element
        cache_killer = torch.empty(n_elements, device="cuda", dtype=torch.float32)
        cache_killer.zero_()
    else:
        cache_killer = None

    print("K,S0,Y_S1,BLOCK_S0,BLOCK_S1,median_ms,mean_ms,p90_ms", flush=True)
    for s1 in _csv_ints(opts.s1s):
        for k in _csv_ints(opts.ks):
            for s0 in _csv_ints(opts.s0s):
                # NOTE(review): the meaning of the positional `1` and `False`
                # arguments is not visible from this file - confirm against
                # _select_reduce_forward_config's signature.
                config = _select_reduce_forward_config(s0, s1, 1, k, False)
                median_ms, mean_ms, p90_ms = bench_reduce(k, s0, s1, opts.iters, cache_killer)
                print(
                    f"{k},{s0},{s1},{config.block_s0},{config.block_x_s1},{median_ms:.6f},{mean_ms:.6f},{p90_ms:.6f}",
                    flush=True)


# Run the benchmark sweep only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def test_make_default_opt_flags_nvidia_split_k_constraint(monkeypatch):
def test_max_allowable_mn_and_split_k_constraints(monkeypatch):
setup_nvidia(monkeypatch)

opt_flags._opt_flags = None
opt_flags.reset_opt_flags()
opt_flags.reset_opt_flags_constraints()
opt_flags.update_opt_flags_constraints(
{
Expand Down Expand Up @@ -167,7 +167,7 @@ def test_max_allowable_mn(monkeypatch):
batch_size, m, n, k = 1, 256, 256, 256

def get_flags(split_k, max_mn):
opt_flags._opt_flags = None
opt_flags.reset_opt_flags()
opt_flags.reset_opt_flags_constraints()
opt_flags.update_opt_flags_constraints(
{
Expand Down
26 changes: 14 additions & 12 deletions python/triton_kernels/triton_kernels/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,15 @@
from .tensor_details.layout_details.strided import StridedLayout
from .tensor_details.layout_details.blackwell_scale import BlackwellActMXScaleLayout
from .tensor_details.layout_details.blackwell_value_shuffled import BlackwellMX4ValueShuffledLayout
from .matmul_details.opt_flags import InapplicableConstraint, make_opt_flags, update_opt_flags_constraints
from .matmul_details.opt_flags import (
InapplicableConstraint,
OptFlags as OptFlags,
make_opt_flags,
scoped_opt_flags as scoped_opt_flags,
scoped_opt_flags_constraints as scoped_opt_flags_constraints,
update_opt_flags_constraints,
)
from .matmul_details.opt_flags_details import opt_flags_nvidia
from .specialize import FnSpecs, SpecializationModule, ClosureArg
from .tensor import Storage, Tensor, FP4, wrap_torch_tensor, RaggedTensorMetadata, is_tma_compliant, make_tma, convert_layout
from .tensor import dtype_to_torch_dtype, torch_dtype_to_dtype
Expand Down Expand Up @@ -131,16 +139,9 @@ class PrecisionConfig:

# TODO: merge in opt_flags
def get_swap_xw(precision_config, opt_flags):
if target_info.cuda_capability_geq(10, 0):
if precision_config.b_mx_scale is not None:
return opt_flags.block_m <= 64 and opt_flags.is_persistent
else:
return opt_flags.block_m < 64 and opt_flags.is_persistent
elif target_info.cuda_capability_geq(9, 0):
b_scale_layout = None if not isinstance(precision_config.b_mx_scale, Tensor) else precision_config.b_mx_scale.storage.layout
return isinstance(b_scale_layout, HopperMXScaleLayout)

return False
if triton.runtime.driver.active.get_current_target().backend != "cuda":
return False
return opt_flags_nvidia.compute_swap_xw(precision_config, opt_flags.block_m, opt_flags.is_persistent)

# ---------------------
# Allocation
Expand Down Expand Up @@ -385,7 +386,8 @@ def matmul(a, b, bias,
# which is too big.
can_use_tma = False
has_gather_tma = has_gather and target_info.has_tma_gather()
can_use_split_k = scatter_indx is None and not a_has_mx and not b_has_mx and ragged_dimension != "K"
is_ragged_mx = (a_has_mx or b_has_mx) and (is_a_ragged or is_b_ragged)
can_use_split_k = scatter_indx is None and not is_ragged_mx and ragged_dimension != "K" and c_acc_in is None and precision_config.c_mx_scale is None
block_k = None
if ragged_dimension == "K":
block_k = a_ragged_metadata.slice_sizes_divisibility or b_ragged_metadata.slice_sizes_divisibility
Expand Down
60 changes: 36 additions & 24 deletions python/triton_kernels/triton_kernels/matmul_details/opt_flags.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# isort: off
# fmt: off
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass

import triton
Expand Down Expand Up @@ -373,36 +374,46 @@ def _is_layout_strided(layout: Layout | None) -> bool:
# User Interface
# --------------

_opt_flags_constraints: dict = dict()
_opt_flags: OptFlags | None = None
_opt_flags_constraints: ContextVar[dict | None] = ContextVar("opt_flags_constraints", default=None)
_opt_flags: ContextVar[OptFlags | None] = ContextVar("opt_flags", default=None)

def _get_opt_flags_constraints() -> dict:
constraints = _opt_flags_constraints.get()
return {} if constraints is None else constraints

def update_opt_flags_constraints(constraints: dict[str, int]):
global _opt_flags_constraints
_opt_flags_constraints.update(constraints)
updated = _get_opt_flags_constraints().copy()
updated.update(constraints)
_opt_flags_constraints.set(updated)

def reset_opt_flags_constraints():
global _opt_flags_constraints
_opt_flags_constraints = dict()
_opt_flags_constraints.set(None)

@contextmanager
def scoped_opt_flags_constraints(constraints):
saved = dict(_opt_flags_constraints)
_opt_flags_constraints.update(constraints)
updated = _get_opt_flags_constraints().copy()
updated.update(constraints)
token = _opt_flags_constraints.set(updated)
try:
yield
finally:
_opt_flags_constraints.clear()
_opt_flags_constraints.update(saved)
_opt_flags_constraints.reset(token)

def reset_opt_flags():
global _opt_flags
_opt_flags = None
_opt_flags.set(None)

def set_opt_flags(opt_flags: OptFlags):
global _opt_flags
assert not _opt_flags_constraints, "setting constraints is incompatible with manual flags override"
assert not _opt_flags, "opt_flags already set; please reset to None first"
_opt_flags = opt_flags
assert not _get_opt_flags_constraints(), "setting constraints is incompatible with manual flags override"
assert _opt_flags.get() is None, "opt_flags already set; please reset to None first"
_opt_flags.set(opt_flags)

@contextmanager
def scoped_opt_flags(opt_flags: OptFlags):
token = _opt_flags.set(opt_flags)
try:
yield
finally:
_opt_flags.reset(token)

class InapplicableConstraint(Exception):
pass
Expand All @@ -426,19 +437,20 @@ def make_opt_flags(
mx_block_size=None,
x_uses_tma_when_persistent=True,
):
if _opt_flags_constraints.get("is_persistent", False) and not can_use_persistent_tma:
opt_flags_constraints = _get_opt_flags_constraints()
if opt_flags_constraints.get("is_persistent", False) and not can_use_persistent_tma:
raise InapplicableConstraint("cannot enforce `is_persistent=True` constraint")
if _opt_flags_constraints.get("split_k") is not None and _opt_flags_constraints.get("split_k") > 1 and not can_use_split_k:
if opt_flags_constraints.get("split_k") is not None and opt_flags_constraints.get("split_k") > 1 and not can_use_split_k:
raise InapplicableConstraint("cannot enforce `split_k=True` constraint")
if _opt_flags_constraints.get("max_allowable_mn"):
if not _opt_flags_constraints.get("split_k"):
if opt_flags_constraints.get("max_allowable_mn"):
if not opt_flags_constraints.get("split_k"):
raise InapplicableConstraint("split_k also needs to be provided with max_allowable_mn")
enforce_bitwise_invariance = precision_config.enforce_bitwise_invariance
if _opt_flags is not None:
assert not _opt_flags_constraints
opt_flags = _opt_flags.get()
if opt_flags is not None:
assert not opt_flags_constraints
assert block_k is None
return _opt_flags
opt_flags_constraints = _opt_flags_constraints
return opt_flags
if block_k is not None:
opt_flags_constraints = opt_flags_constraints.copy()
opt_flags_constraints.update(block_k=block_k, split_k=1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,20 @@ def is_x_scale_swizzled(precision_config):
and isinstance(precision_config.a_mx_scale.storage.layout, BlackwellActMXScaleLayout))


def compute_swap_xw(precision_config, block_m, is_persistent):
    """Decide the swap-XW flag for the given precision config and tile setup.

    The decision depends on the CUDA capability reported by ``target_info``:

    * capability >= 10.0: requires ``is_persistent`` and a small ``block_m``.
      The bound is ``block_m <= 64`` when ``precision_config.b_mx_scale`` is
      set and ``block_m < 64`` when it is ``None``.
    * capability >= 9.0 (but below 10.0): true exactly when ``b_mx_scale`` is
      a ``Tensor`` whose ``storage.layout`` is a ``HopperMXScaleLayout``;
      ``block_m`` and ``is_persistent`` are ignored.
    * anything else: ``False``.

    NOTE(review): presumably "swap XW" means the kernel swaps its X and W
    operands - confirm against the kernel that consumes this flag.
    """
    b_scale = precision_config.b_mx_scale

    if target_info.cuda_capability_geq(10, 0):
        # The block_m bound is inclusive only when a B scale is present.
        block_m_ok = block_m < 64 if b_scale is None else block_m <= 64
        return block_m_ok and is_persistent

    if target_info.cuda_capability_geq(9, 0):
        return isinstance(b_scale, Tensor) and isinstance(b_scale.storage.layout, HopperMXScaleLayout)

    return False


def compute_grid_size(routing_data, batch_size, m, n, block_m, block_n):
if routing_data is not None and batch_size == 1:
grid_m = routing_data.n_blocks(routing_data.n_slices, m, block_m)
Expand Down Expand Up @@ -146,7 +160,17 @@ def compute_num_stages(
# pipelined TMA store local to global, or
# pipelined layout conversion before store of the accumulator
# note: layout conversion has some padding
smem_capacity -= int((block_m + 4) * acc_block_n * acc_size)
epilogue_smem = int((block_m + 4) * acc_block_n * acc_size)
if compute_swap_xw(precision_config, block_m, is_persistent):
# SWAP_XW Blackwell kernels stage the full transposed TMEM
# accumulator tile through fp32 smem before converting/storing it.
# If the output is narrower, the final TMA-store tile is a separate
# smem allocation.
acc_smem = block_m * block_n * (FP32.bitwidth // 8)
if out_itemsize < (FP32.bitwidth // 8):
acc_smem += int(block_m * acc_block_n * out_itemsize)
epilogue_smem = max(epilogue_smem, acc_smem)
smem_capacity -= epilogue_smem
if x_transpose:
smem_capacity -= block_m * block_k * (max(8, lhs_dtype.bitwidth) // 8)

Expand All @@ -155,7 +179,8 @@ def compute_num_stages(
if is_persistent and (lhs_dtype == FP32 or rhs_dtype == FP32):
smem_capacity -= 32 * 1024
smem_capacity = max(smem_capacity, 0)
num_stages = min(smem_capacity // int(stage_size), 4)
max_stages = 5 if rhs_dtype == FP4 else 4 # maybe 5 everywhere; just haven't tested
num_stages = min(smem_capacity // int(stage_size), max_stages)
# Keep one stage of headroom for persistent fp32 to avoid launch-time OOR.
if is_persistent and (lhs_dtype == FP32 or rhs_dtype == FP32):
num_stages = min(num_stages, 3)
Expand Down
Loading
Loading