Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 63 additions & 1 deletion flashinfer/autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,7 @@ def autotune(
cache: Optional[str] = None,
tuning_buckets: Optional[Tuple[int, ...]] = None,
round_up: Optional[bool] = None,
skip_ops: Optional[Union[str, Set[str]]] = None,
):
"""Context manager for autotuning with optional file-based caching.

Expand Down Expand Up @@ -519,6 +520,22 @@ def autotune(
kernel for a larger bucket also performs well at nearby smaller
sizes (see the PR discussion for benchmark data on cuDNN plans).

skip_ops: Optional set of ``custom_op`` names to exclude from
autotuning. Operations whose ``custom_op`` string matches an
entry in this set will skip profiling entirely and use their
fallback (heuristic) tactic instead. This is useful when a
framework runs a single dummy forward pass inside
``autotune()`` but wants to avoid the compilation cost of
autotuning specific ops whose heuristics are already
near-optimal. For example,
``skip_ops={"fp4_gemm"}`` skips autotuning for ``mm_fp4``
while still tuning MoE and other operations.
Nested contexts **union** their skip sets: an inner
``autotune(skip_ops={"B"})`` inside an outer
``autotune(skip_ops={"A"})`` skips both ``"A"`` and ``"B"``.
Common op names: ``"fp4_gemm"``, ``"bf16_gemm"``,
``"fp8_gemm"``, ``"mxfp8_gemm"``.

Raises:
ValueError: If ``tuning_buckets`` is provided but empty.

Expand Down Expand Up @@ -572,6 +589,10 @@ def autotune(
with autotune(True, tuning_buckets=(64, 512)):
model(inputs) # uses (64, 512)
model(inputs) # back to (128, 256)

# Skip autotuning for specific ops (use heuristic fallback)
with autotune(True, skip_ops={"fp4_gemm"}):
model(inputs) # mm_fp4 uses heuristic, other ops are autotuned
"""
tuner = AutoTuner.get()

Expand All @@ -593,6 +614,14 @@ def autotune(
if os.path.isfile(cache):
cache_valid = tuner.load_configs(cache)

# Push skip_ops onto per-thread stack. Each entry is the cumulative
# union so that _effective_skip_ops is an O(1) read from the top.
skip_ops_stack = tuner._get_skip_ops_stack()
if skip_ops is not None:
skip_ops_set = {skip_ops} if isinstance(skip_ops, str) else skip_ops
current = skip_ops_stack[-1] if skip_ops_stack else frozenset()
skip_ops_stack.append(current | frozenset(skip_ops_set))

# Push tuning bucket overrides onto per-thread stack. Inherits from the
# current top-of-stack when a parameter is not explicitly supplied.
override_stack = tuner._get_override_stack()
Expand Down Expand Up @@ -623,6 +652,8 @@ def autotune(
except BaseException:
if pushed:
override_stack.pop()
if skip_ops is not None:
skip_ops_stack.pop()
raise

try:
Expand All @@ -633,9 +664,11 @@ def autotune(
tuner._active_tuning_contexts -= 1
tuner.is_tuning_mode = tuner._active_tuning_contexts > 0

# Pop the override we pushed (thread-local, no lock needed).
# Pop the overrides we pushed (thread-local, no lock needed).
if pushed:
override_stack.pop()
if skip_ops is not None:
skip_ops_stack.pop()

if autotune_enabled:
logger.info("[Autotuner]: Autotuning process ends")
Expand Down Expand Up @@ -812,6 +845,8 @@ def __init__(self, warmup=3, repeat=10, stream_delay_micro_secs=1000):
# autotune() context manager. Using threading.local ensures concurrent
# autotune() contexts on different threads don't clobber each other.
self._override_local = threading.local()
# Per-thread stack of frozenset[str] for skip_ops overrides.
self._skip_ops_local = threading.local()
# Cache overridden TuningConfig objects to keep stable object identity
# for _find_nearest_profile's LRU cache.
# Two-level: WeakKeyDictionary[TuningConfig, Dict[(buckets, round_up), TuningConfig]]
Expand All @@ -828,6 +863,19 @@ def _get_override_stack(self) -> List:
local.stack = []
return local.stack

def _get_skip_ops_stack(self) -> List:
    """Fetch this thread's skip_ops stack, lazily creating an empty one.

    The stack lives as the ``stack`` attribute of ``self._skip_ops_local``.
    The same list object is returned on every call from a given holder, so
    callers may push and pop on it directly.

    Returns:
        The mutable list stored on ``self._skip_ops_local.stack``.
    """
    tls = self._skip_ops_local
    try:
        return tls.stack
    except AttributeError:
        # First access through this local object: install an empty stack.
        tls.stack = []
        return tls.stack

@property
def _effective_skip_ops(self) -> frozenset:
    """Skip set currently in force: the top of the skip_ops stack.

    Returns:
        The last entry of the list returned by
        ``self._get_skip_ops_stack()``, or an empty ``frozenset`` when that
        list is empty.
    """
    entries = self._get_skip_ops_stack()
    if not entries:
        return frozenset()
    return entries[-1]

@property
def _override_tuning_buckets(self) -> Optional[Tuple[int, ...]]:
"""Currently active tuning-bucket override for this thread, or ``None``."""
Expand Down Expand Up @@ -1077,6 +1125,20 @@ def choose_one(
# Note: this is a single global lock, so multi-threaded tuning on
# separate GPUs is serialized. Use multi-process (one per GPU) for
# parallel multi-GPU tuning.
# Skip profiling for ops in the skip_ops set — return fallback
# immediately. The fallback runner (runners[0], tactic=-1) uses
# the op's built-in heuristic, avoiding kernel compilation.
# Checked before acquiring the lock since _effective_skip_ops is
# thread-local and does not touch shared state.
if custom_op in self._effective_skip_ops:
logger.debug(
f"[AutoTuner]: Skipping autotuning for '{custom_op}' "
f"(in skip_ops). Using fallback tactic."
)
if not runners:
raise ValueError(f"No runners provided for op '{custom_op}'")
return runners[0], -1
Comment on lines +1128 to +1140
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Validate runners once at method entry, not only in the skip path.

Line 1136 adds an empty-list guard only for skipped ops. For non-skipped ops, Line 1152 still does runners[runner_id], so runners=[] can raise IndexError in inference mode.

Suggested fix
 def choose_one(
     self,
     custom_op: str,
     runners: List[TunableRunner],
     tuning_config: TuningConfig,
     inputs: List[torch.Tensor],
     **kwargs,
 ) -> Tuple[TunableRunner, int]:
+    if not runners:
+        raise ValueError(f"No runners provided for op '{custom_op}'")
+    if not all(isinstance(r, TunableRunner) for r in runners):
+        raise TypeError("All given runners must be subclasses of TunableRunner")
+
     # Skip profiling for ops in the skip_ops set — return fallback
     # immediately.
     if custom_op in self._effective_skip_ops:
         logger.debug(
             f"[AutoTuner]: Skipping autotuning for '{custom_op}' "
             f"(in skip_ops). Using fallback tactic."
         )
-        if not runners:
-            raise ValueError(f"No runners provided for op '{custom_op}'")
         return runners[0], -1
 ...
-    assert len(runners) > 0, "At least one runner is required"
-    assert all([isinstance(r, TunableRunner) for r in runners]), (
-        "All Given runners must be subclass of TunableRunner"
-    )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/autotuner.py` around lines 1126 - 1138, The method currently only
checks for an empty runners list inside the skip_ops branch, which leaves
non-skipped paths trying to index runners (e.g., runners[runner_id]) and can
raise IndexError; at the start of the enclosing method (the function that
references custom_op, _effective_skip_ops, runners and runner_id), validate that
runners is non-empty and raise a clear ValueError if empty (same style as the
skip path) before any use of runners or early returns, so both skipped and
non-skipped flows are safe.


with self._lock:
# Apply tuning bucket / rounding overrides from autotune() context.
if self._override_tuning_buckets is not None or self._override_round_up:
Expand Down
177 changes: 177 additions & 0 deletions tests/autotuner/test_autotuner_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,3 +838,180 @@ def test_choose_one_with_none_input_no_crash():
)
assert chosen_runner is runner
assert tactic == -1


# ---------------------------------------------------------------------------
# Tests: skip_ops
# ---------------------------------------------------------------------------


def test_skip_ops_prevents_profiling(monkeypatch):
    """An op named in ``skip_ops`` gets the fallback and is never profiled."""
    tuner = reset_autotuner()
    dummy = DummyRunner(valid_tactics=(0, 1, 2))
    tensors = [torch.empty((16, 32), dtype=torch.float32)]
    cfg = TuningConfig()

    # Records every tactic the (patched) profiler is asked to time.
    profiled_tactics = []

    def fake_profile(
        self, runner_obj, prof_inputs, tactic, tuning_config=None, **kwargs
    ):
        profiled_tactics.append(tactic)
        return 1.0

    monkeypatch.setattr(AutoTuner, "_profile_single_kernel", fake_profile)

    with autotune(tune_mode=True, skip_ops={"skip_me"}):
        result = tuner.choose_one("skip_me", [dummy], cfg, tensors)

    picked, tactic = result
    assert picked is dummy
    assert tactic == -1
    assert not profiled_tactics


def test_skip_ops_does_not_affect_other_ops(monkeypatch):
    """An op that is not in ``skip_ops`` is still profiled and tuned."""
    tuner = reset_autotuner()
    dummy = DummyRunner(valid_tactics=(0, 1))
    tensors = [torch.empty((16, 32), dtype=torch.float32)]
    cfg = TuningConfig()

    def fake_profile(
        self, runner_obj, prof_inputs, tactic, tuning_config=None, **kwargs
    ):
        # Tactic 1 is reported as the faster of the two.
        timings = {0: 5.0, 1: 1.0}
        return timings[tactic]

    monkeypatch.setattr(AutoTuner, "_profile_single_kernel", fake_profile)

    with autotune(tune_mode=True, skip_ops={"some_other_op"}):
        result = tuner.choose_one("tune_me", [dummy], cfg, tensors)

    picked, tactic = result
    assert picked is dummy
    assert tactic == 1  # best tactic selected via profiling


def test_skip_ops_nested_union(monkeypatch):
    """An inner ``autotune`` context adds its skip_ops to the outer one's."""
    tuner = reset_autotuner()
    dummy = DummyRunner(valid_tactics=(0,))
    tensors = [torch.empty((4, 8), dtype=torch.float32)]
    cfg = TuningConfig()

    def fake_profile(
        self, runner_obj, prof_inputs, tactic, tuning_config=None, **kwargs
    ):
        return 1.0

    monkeypatch.setattr(AutoTuner, "_profile_single_kernel", fake_profile)

    def tactic_for(op_name):
        # Only the tactic matters here; the chosen runner is ignored.
        return tuner.choose_one(op_name, [dummy], cfg, tensors)[1]

    with autotune(tune_mode=True, skip_ops={"op_a"}):
        # Outer context alone: op_a is skipped.
        assert tactic_for("op_a") == -1

        with autotune(tune_mode=True, skip_ops={"op_b"}):
            # Inner context: both op_a and op_b are skipped.
            assert tactic_for("op_a") == -1
            assert tactic_for("op_b") == -1

        # Back in the outer context: op_a is still skipped.
        assert tactic_for("op_a") == -1


def test_skip_ops_returns_first_runner():
    """With several runners, a skipped op is handed the first one."""
    tuner = reset_autotuner()
    first = DummyRunner(valid_tactics=(0,))
    second = DummyRunner(valid_tactics=(1,))
    tensors = [torch.empty((4, 8), dtype=torch.float32)]
    cfg = TuningConfig()
    candidates = [first, second]

    with autotune(tune_mode=True, skip_ops={"multi_runner_op"}):
        result = tuner.choose_one("multi_runner_op", candidates, cfg, tensors)

    picked, tactic = result
    assert picked is first
    assert tactic == -1


def test_skip_ops_empty_set_is_noop(monkeypatch):
"""skip_ops=set() should not skip anything."""
tuner = reset_autotuner()
runner = DummyRunner(valid_tactics=(0, 1))
inputs = [torch.empty((16, 32), dtype=torch.float32)]
config = TuningConfig()

def fake_profile(
self, runner_obj, prof_inputs, tactic, tuning_config=None, **kwargs
):
return {0: 5.0, 1: 1.0}[tactic]

monkeypatch.setattr(AutoTuner, "_profile_single_kernel", fake_profile)

with autotune(tune_mode=True, skip_ops=set()):
chosen, tactic = tuner.choose_one("should_tune", [runner], config, inputs)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Fix Ruff RUF059 by discarding unused unpacked values.

Line 957 (chosen) and Line 986 (tactic_b2) are assigned but unused.

Suggested cleanup
-    with autotune(tune_mode=True, skip_ops=set()):
-        chosen, tactic = tuner.choose_one("should_tune", [runner], config, inputs)
+    with autotune(tune_mode=True, skip_ops=set()):
+        _, tactic = tuner.choose_one("should_tune", [runner], config, inputs)
...
-        _, tactic_b2 = tuner.choose_one("op_b", [runner], config, inputs)
+        _, _ = tuner.choose_one("op_b", [runner], config, inputs)

Also applies to: 986-986

🧰 Tools
🪛 Ruff (0.15.14)

[warning] 957-957: Unpacked variable chosen is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/autotuner/test_autotuner_core.py` at line 957, The test assigns unused
unpacked values causing Ruff RUF059: when calling tuner.choose_one in
tests/autotuner/test_autotuner_core.py (the call that currently does "chosen,
tactic = tuner.choose_one(...)" and the similar assignment that creates
"tactic_b2"), discard the unused value(s) by replacing the unused variable with
an underscore (e.g., use "_, tactic = tuner.choose_one(...)" or "_, _ =
tuner.choose_one(...)" as appropriate) or by only assigning the needed element (e.g., assign the
return to a single name and index into it), keeping the call to choose_one
intact but removing the unused symbol(s).


assert tactic == 1 # profiled and selected best


def test_skip_ops_nested_inner_op_resumes_after_exit(monkeypatch):
    """op_b, skipped only by the inner context, is profiled once it exits.

    The profiler is replaced by a stub that records each invocation, so the
    test can tell a skipped call (no invocations) from a tuned call (at
    least one invocation).
    """
    tuner = reset_autotuner()
    runner = DummyRunner(valid_tactics=(0,))
    inputs = [torch.empty((4, 8), dtype=torch.float32)]
    config = TuningConfig()

    profile_calls = []

    def fake_profile(
        self, runner_obj, prof_inputs, tactic, tuning_config=None, **kwargs
    ):
        profile_calls.append(1)
        return 1.0

    monkeypatch.setattr(AutoTuner, "_profile_single_kernel", fake_profile)

    with autotune(tune_mode=True, skip_ops={"op_a"}):
        with autotune(tune_mode=True, skip_ops={"op_b"}):
            _, tactic_b = tuner.choose_one("op_b", [runner], config, inputs)
            assert tactic_b == -1  # skipped in inner
            # Pin the skip explicitly; the clear() below would otherwise
            # hide a profiling call made while op_b was skipped.
            assert not profile_calls

        # After inner exits, op_b should be profiled
        profile_calls.clear()
        # The return value is irrelevant here; only the profiling side
        # effect is checked, so nothing is unpacked into unused names.
        tuner.choose_one("op_b", [runner], config, inputs)
        assert profile_calls  # was profiled


def test_skip_ops_does_not_pollute_cache():
    """Choosing a skipped op leaves the size of ``profiling_cache`` unchanged."""
    tuner = reset_autotuner()
    dummy = DummyRunner()
    tensors = [torch.empty((4, 8), dtype=torch.float32)]
    cfg = TuningConfig()

    entries_before = len(tuner.profiling_cache)

    with autotune(tune_mode=True, skip_ops={"no_cache_op"}):
        tuner.choose_one("no_cache_op", [dummy], cfg, tensors)

    entries_after = len(tuner.profiling_cache)
    assert entries_after == entries_before


def test_skip_ops_restored_after_context():
    """Leaving the ``autotune`` context empties the effective skip set."""
    tuner = reset_autotuner()
    dummy = DummyRunner()
    tensors = [torch.empty((4, 8), dtype=torch.float32)]
    cfg = TuningConfig()

    with autotune(tune_mode=False, skip_ops={"some_op"}):
        result = tuner.choose_one("some_op", [dummy], cfg, tensors)
        assert result[1] == -1

    # Once the context has exited, no skip_ops remain in effect.
    leftover = tuner._effective_skip_ops
    assert leftover == frozenset()
Loading