Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions python/triton_kernels/tests/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,13 @@ def test_set_idle_sms():
from triton_kernels.matmul_details.opt_flags import make_opt_flags
num_idle_sms = 24
matmul_set_idle_sms(num_idle_sms)
flags = make_opt_flags(FP32, FP32, FP32, PrecisionConfig(), \
1, 1024, 1024, 1024, None, True, False, 1, False, False, None)
assert flags.idle_sms == num_idle_sms
try:
flags = make_opt_flags(FP32, FP32, FP32, PrecisionConfig(), \
1, 1024, 1024, 1024, None, True, False, 1, False, False, None)
assert flags.idle_sms == num_idle_sms
with opt_flags.scoped_opt_flags_constraints({"idle_sms": num_idle_sms + 1}):
flags = make_opt_flags(FP32, FP32, FP32, PrecisionConfig(), \
1, 1024, 1024, 1024, None, True, False, 1, False, False, None)
assert flags.idle_sms == num_idle_sms + 1
finally:
matmul_set_idle_sms(0)
4 changes: 2 additions & 2 deletions python/triton_kernels/triton_kernels/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
make_opt_flags,
scoped_opt_flags as scoped_opt_flags,
scoped_opt_flags_constraints as scoped_opt_flags_constraints,
update_opt_flags_constraints,
set_idle_sms,
)
from .matmul_details.opt_flags_details import opt_flags_nvidia
from .specialize import FnSpecs, SpecializationModule, ClosureArg
Expand Down Expand Up @@ -235,7 +235,7 @@ def matmul_set_idle_sms(num_idle_sms):
"""
persistent kernels will leave `num_idle_sms` idle
"""
update_opt_flags_constraints({"idle_sms": num_idle_sms})
set_idle_sms(num_idle_sms)

def matmul(a, b, bias,
a_ragged_metadata: RaggedTensorMetadata | None = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def make_default_opt_flags_amd(
has_y_acc_in,
constraints,
):
constraints_supported = {"block_m", "block_n", "block_k", "split_k", "is_persistent", "epilogue_subtile", "max_allowable_mn", "num_warps", "disable_mx4_block_swap"}
constraints_supported = {"block_m", "block_n", "block_k", "split_k", "is_persistent", "epilogue_subtile", "idle_sms", "max_allowable_mn", "num_warps", "disable_mx4_block_swap"}
unsupported = set(constraints.keys()) - constraints_supported
assert not unsupported, f"Given unsupported constraint: {unsupported}"
# tokens per slice
Expand Down Expand Up @@ -166,7 +166,7 @@ def replace_with_valid_constraint(k: str, v):
w_cache_modifier=w_cache_modifier,
split_k=split_k,
is_persistent=is_persistent,
idle_sms=0,
idle_sms=constraints.get("idle_sms", _get_idle_sms()),
Comment thread
aeng-openai marked this conversation as resolved.
epilogue_subtile=epilogue_subtile,
arch=None,
target_kernel_kwargs=target_kernel_kwargs,
Expand Down Expand Up @@ -388,7 +388,7 @@ def _is_layout_strided(layout: Layout | None) -> bool:
# For some reason, overlapping the epilogue is slower for hopper bf16 x mxfp4
FLATTEN_LOOPS=not is_hopper_scale,
),
idle_sms=constraints.get("idle_sms", 0),
idle_sms=constraints.get("idle_sms", _get_idle_sms()),
occupancy_target=occupancy_target,
)
# check constraints
Expand All @@ -401,6 +401,14 @@ def _is_layout_strided(layout: Layout | None) -> bool:

_opt_flags_constraints: ContextVar[dict | None] = ContextVar("opt_flags_constraints", default=None)
_opt_flags: ContextVar[OptFlags | None] = ContextVar("opt_flags", default=None)
_idle_sms = 0

def _get_idle_sms() -> int:
return _idle_sms

def set_idle_sms(num_idle_sms: int):
global _idle_sms
_idle_sms = num_idle_sms

def _get_opt_flags_constraints() -> dict:
constraints = _opt_flags_constraints.get()
Expand Down
Loading