From 93c9c8af23e7e0307fe729878a733202029ccac6 Mon Sep 17 00:00:00 2001 From: Austin Eng Date: Tue, 26 May 2026 10:42:10 -0700 Subject: [PATCH 1/2] [KERNELS] make setting idle sms process-global; they can still separately be overridden in the opt flags constraints --- python/triton_kernels/tests/test_matmul.py | 13 ++++++++++--- python/triton_kernels/triton_kernels/matmul.py | 4 ++-- .../triton_kernels/matmul_details/opt_flags.py | 12 ++++++++++-- 3 files changed, 22 insertions(+), 7 deletions(-) diff --git a/python/triton_kernels/tests/test_matmul.py b/python/triton_kernels/tests/test_matmul.py index 62c1e257fcd2..a4acc435a19d 100644 --- a/python/triton_kernels/tests/test_matmul.py +++ b/python/triton_kernels/tests/test_matmul.py @@ -585,6 +585,13 @@ def test_set_idle_sms(): from triton_kernels.matmul_details.opt_flags import make_opt_flags num_idle_sms = 24 matmul_set_idle_sms(num_idle_sms) - flags = make_opt_flags(FP32, FP32, FP32, PrecisionConfig(), \ - 1, 1024, 1024, 1024, None, True, False, 1, False, False, None) - assert flags.idle_sms == num_idle_sms + try: + flags = make_opt_flags(FP32, FP32, FP32, PrecisionConfig(), \ + 1, 1024, 1024, 1024, None, True, False, 1, False, False, None) + assert flags.idle_sms == num_idle_sms + with opt_flags.scoped_opt_flags_constraints({"idle_sms": num_idle_sms + 1}): + flags = make_opt_flags(FP32, FP32, FP32, PrecisionConfig(), \ + 1, 1024, 1024, 1024, None, True, False, 1, False, False, None) + assert flags.idle_sms == num_idle_sms + 1 + finally: + matmul_set_idle_sms(0) diff --git a/python/triton_kernels/triton_kernels/matmul.py b/python/triton_kernels/triton_kernels/matmul.py index 4ff4cc2ea360..aee2fd1bb14b 100644 --- a/python/triton_kernels/triton_kernels/matmul.py +++ b/python/triton_kernels/triton_kernels/matmul.py @@ -27,7 +27,7 @@ make_opt_flags, scoped_opt_flags as scoped_opt_flags, scoped_opt_flags_constraints as scoped_opt_flags_constraints, - update_opt_flags_constraints, + set_idle_sms, ) from 
.matmul_details.opt_flags_details import opt_flags_nvidia from .specialize import FnSpecs, SpecializationModule, ClosureArg @@ -235,7 +235,7 @@ def matmul_set_idle_sms(num_idle_sms): """ persistent kernels will leave `num_idle_sms` idle """ - update_opt_flags_constraints({"idle_sms": num_idle_sms}) + set_idle_sms(num_idle_sms) def matmul(a, b, bias, a_ragged_metadata: RaggedTensorMetadata | None = None, diff --git a/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py b/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py index fb0223722ddc..3c2f293931d4 100644 --- a/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py +++ b/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py @@ -166,7 +166,7 @@ def replace_with_valid_constraint(k: str, v): w_cache_modifier=w_cache_modifier, split_k=split_k, is_persistent=is_persistent, - idle_sms=0, + idle_sms=constraints.get("idle_sms", _get_idle_sms()), epilogue_subtile=epilogue_subtile, arch=None, target_kernel_kwargs=target_kernel_kwargs, @@ -388,7 +388,7 @@ def _is_layout_strided(layout: Layout | None) -> bool: # For some reason, overlapping the epilogue is slower for hopper bf16 x mxfp4 FLATTEN_LOOPS=not is_hopper_scale, ), - idle_sms=constraints.get("idle_sms", 0), + idle_sms=constraints.get("idle_sms", _get_idle_sms()), occupancy_target=occupancy_target, ) # check constraints @@ -401,6 +401,14 @@ def _is_layout_strided(layout: Layout | None) -> bool: _opt_flags_constraints: ContextVar[dict | None] = ContextVar("opt_flags_constraints", default=None) _opt_flags: ContextVar[OptFlags | None] = ContextVar("opt_flags", default=None) +_idle_sms = 0 + +def _get_idle_sms() -> int: + return _idle_sms + +def set_idle_sms(num_idle_sms: int): + global _idle_sms + _idle_sms = num_idle_sms def _get_opt_flags_constraints() -> dict: constraints = _opt_flags_constraints.get() From febcaf139e8c0eaa76f006bcb5556e3121783367 Mon Sep 17 00:00:00 2001 From: Austin Eng Date: Tue, 26 May 2026 
10:50:19 -0700 Subject: [PATCH 2/2] fix amd --- .../triton_kernels/triton_kernels/matmul_details/opt_flags.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py b/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py index 3c2f293931d4..964c4a93af0a 100644 --- a/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py +++ b/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py @@ -68,7 +68,7 @@ def make_default_opt_flags_amd( has_y_acc_in, constraints, ): - constraints_supported = {"block_m", "block_n", "block_k", "split_k", "is_persistent", "epilogue_subtile", "max_allowable_mn", "num_warps", "disable_mx4_block_swap"} + constraints_supported = {"block_m", "block_n", "block_k", "split_k", "is_persistent", "epilogue_subtile", "idle_sms", "max_allowable_mn", "num_warps", "disable_mx4_block_swap"} unsupported = set(constraints.keys()) - constraints_supported assert not unsupported, f"Given unsupported constraint: {unsupported}" # tokens per slice