diff --git a/flashinfer/autotuner.py b/flashinfer/autotuner.py index 6ee1f7bb5e..e8fa129e93 100644 --- a/flashinfer/autotuner.py +++ b/flashinfer/autotuner.py @@ -467,6 +467,7 @@ def autotune( cache: Optional[str] = None, tuning_buckets: Optional[Tuple[int, ...]] = None, round_up: Optional[bool] = None, + skip_ops: Optional[Union[str, Set[str]]] = None, ): """Context manager for autotuning with optional file-based caching. @@ -519,6 +520,22 @@ def autotune( kernel for a larger bucket also performs well at nearby smaller sizes (see the PR discussion for benchmark data on cuDNN plans). + skip_ops: Optional set of ``custom_op`` names to exclude from + autotuning. Operations whose ``custom_op`` string matches an + entry in this set will skip profiling entirely and use their + fallback (heuristic) tactic instead. This is useful when a + framework runs a single dummy forward pass inside + ``autotune()`` but wants to avoid the compilation cost of + autotuning specific ops whose heuristics are already + near-optimal. For example, + ``skip_ops={"fp4_gemm"}`` skips autotuning for ``mm_fp4`` + while still tuning MoE and other operations. + Nested contexts **union** their skip sets: an inner + ``autotune(skip_ops={"B"})`` inside an outer + ``autotune(skip_ops={"A"})`` skips both ``"A"`` and ``"B"``. + Common op names: ``"fp4_gemm"``, ``"bf16_gemm"``, + ``"fp8_gemm"``, ``"mxfp8_gemm"``. + Raises: ValueError: If ``tuning_buckets`` is provided but empty. @@ -572,6 +589,10 @@ def autotune( with autotune(True, tuning_buckets=(64, 512)): model(inputs) # uses (64, 512) model(inputs) # back to (128, 256) + + # Skip autotuning for specific ops (use heuristic fallback) + with autotune(True, skip_ops={"fp4_gemm"}): + model(inputs) # mm_fp4 uses heuristic, other ops are autotuned """ tuner = AutoTuner.get() @@ -593,6 +614,14 @@ def autotune( if os.path.isfile(cache): cache_valid = tuner.load_configs(cache) + # Push skip_ops onto per-thread stack. Each entry is the cumulative + # union so that _effective_skip_ops is an O(1) read from the top. + skip_ops_stack = tuner._get_skip_ops_stack() + if skip_ops is not None: + skip_ops_set = {skip_ops} if isinstance(skip_ops, str) else skip_ops + current = skip_ops_stack[-1] if skip_ops_stack else frozenset() + skip_ops_stack.append(current | frozenset(skip_ops_set)) + # Push tuning bucket overrides onto per-thread stack. Inherits from the # current top-of-stack when a parameter is not explicitly supplied. override_stack = tuner._get_override_stack() @@ -623,6 +652,8 @@ def autotune( except BaseException: if pushed: override_stack.pop() + if skip_ops is not None: + skip_ops_stack.pop() raise try: @@ -633,9 +664,11 @@ def autotune( tuner._active_tuning_contexts -= 1 tuner.is_tuning_mode = tuner._active_tuning_contexts > 0 - # Pop the override we pushed (thread-local, no lock needed). + # Pop the overrides we pushed (thread-local, no lock needed). if pushed: override_stack.pop() + if skip_ops is not None: + skip_ops_stack.pop() if autotune_enabled: logger.info("[Autotuner]: Autotuning process ends") @@ -812,6 +845,8 @@ def __init__(self, warmup=3, repeat=10, stream_delay_micro_secs=1000): # autotune() context manager. Using threading.local ensures concurrent # autotune() contexts on different threads don't clobber each other. self._override_local = threading.local() + # Per-thread stack of frozenset[str] for skip_ops overrides. + self._skip_ops_local = threading.local() # Cache overridden TuningConfig objects to keep stable object identity # for _find_nearest_profile's LRU cache. # Two-level: WeakKeyDictionary[TuningConfig, Dict[(buckets, round_up), TuningConfig]] @@ -828,6 +863,19 @@ def _get_override_stack(self) -> List: local.stack = [] return local.stack + def _get_skip_ops_stack(self) -> List: + """Return the per-thread skip_ops stack, creating it on first access.""" + local = self._skip_ops_local + if not hasattr(local, "stack"): + local.stack = [] + return local.stack + + @property + def _effective_skip_ops(self) -> frozenset: + """Cumulative union of all skip_ops from the current thread's stack.""" + stack = self._get_skip_ops_stack() + return stack[-1] if stack else frozenset() + @property def _override_tuning_buckets(self) -> Optional[Tuple[int, ...]]: """Currently active tuning-bucket override for this thread, or ``None``.""" @@ -1077,6 +1125,20 @@ def choose_one( # Note: this is a single global lock, so multi-threaded tuning on # separate GPUs is serialized. Use multi-process (one per GPU) for # parallel multi-GPU tuning. + # Skip profiling for ops in the skip_ops set — return fallback + # immediately. The fallback runner (runners[0], tactic=-1) uses + # the op's built-in heuristic, avoiding kernel compilation. + # Checked before acquiring the lock since _effective_skip_ops is + # thread-local and does not touch shared state. + if custom_op in self._effective_skip_ops: + logger.debug( + f"[AutoTuner]: Skipping autotuning for '{custom_op}' " + f"(in skip_ops). Using fallback tactic." + ) + if not runners: + raise ValueError(f"No runners provided for op '{custom_op}'") + return runners[0], -1 + with self._lock: # Apply tuning bucket / rounding overrides from autotune() context. if self._override_tuning_buckets is not None or self._override_round_up: diff --git a/tests/autotuner/test_autotuner_core.py b/tests/autotuner/test_autotuner_core.py index ad377c2ea6..b07c26cc75 100644 --- a/tests/autotuner/test_autotuner_core.py +++ b/tests/autotuner/test_autotuner_core.py @@ -838,3 +838,180 @@ def test_choose_one_with_none_input_no_crash(): ) assert chosen_runner is runner assert tactic == -1 + + +# --------------------------------------------------------------------------- +# Tests: skip_ops +# --------------------------------------------------------------------------- + + +def test_skip_ops_prevents_profiling(monkeypatch): + """Skipped ops should return fallback immediately without profiling.""" + tuner = reset_autotuner() + runner = DummyRunner(valid_tactics=(0, 1, 2)) + inputs = [torch.empty((16, 32), dtype=torch.float32)] + config = TuningConfig() + + profile_calls = [] + + def fake_profile( + self, runner_obj, prof_inputs, tactic, tuning_config=None, **kwargs + ): + profile_calls.append(tactic) + return 1.0 + + monkeypatch.setattr(AutoTuner, "_profile_single_kernel", fake_profile) + + with autotune(tune_mode=True, skip_ops={"skip_me"}): + chosen_runner, tactic = tuner.choose_one("skip_me", [runner], config, inputs) + + assert chosen_runner is runner + assert tactic == -1 + assert len(profile_calls) == 0 + + +def test_skip_ops_does_not_affect_other_ops(monkeypatch): + """Non-skipped ops should still be profiled normally.""" + tuner = reset_autotuner() + runner = DummyRunner(valid_tactics=(0, 1)) + inputs = [torch.empty((16, 32), dtype=torch.float32)] + config = TuningConfig() + + def fake_profile( + self, runner_obj, prof_inputs, tactic, tuning_config=None, **kwargs + ): + return {0: 5.0, 1: 1.0}[tactic] + + monkeypatch.setattr(AutoTuner, "_profile_single_kernel", fake_profile) + + with autotune(tune_mode=True, skip_ops={"some_other_op"}): + chosen_runner, tactic = tuner.choose_one("tune_me", [runner], config, inputs) + + assert chosen_runner is runner + assert tactic == 1 # best tactic selected via profiling + + +def test_skip_ops_nested_union(monkeypatch): + """Nested autotune contexts should union their skip_ops sets.""" + tuner = reset_autotuner() + runner = DummyRunner(valid_tactics=(0,)) + inputs = [torch.empty((4, 8), dtype=torch.float32)] + config = TuningConfig() + + def fake_profile( + self, runner_obj, prof_inputs, tactic, tuning_config=None, **kwargs + ): + return 1.0 + + monkeypatch.setattr(AutoTuner, "_profile_single_kernel", fake_profile) + + with autotune(tune_mode=True, skip_ops={"op_a"}): + # op_a should be skipped + _, tactic_a = tuner.choose_one("op_a", [runner], config, inputs) + assert tactic_a == -1 + + with autotune(tune_mode=True, skip_ops={"op_b"}): + # Both op_a and op_b should be skipped in inner context + _, tactic_a2 = tuner.choose_one("op_a", [runner], config, inputs) + assert tactic_a2 == -1 + _, tactic_b = tuner.choose_one("op_b", [runner], config, inputs) + assert tactic_b == -1 + + # After inner context exits, only op_a should still be skipped + _, tactic_a3 = tuner.choose_one("op_a", [runner], config, inputs) + assert tactic_a3 == -1 + + +def test_skip_ops_returns_first_runner(): + """Skipped ops should always return runners[0], even with multiple runners.""" + tuner = reset_autotuner() + runner_a = DummyRunner(valid_tactics=(0,)) + runner_b = DummyRunner(valid_tactics=(1,)) + inputs = [torch.empty((4, 8), dtype=torch.float32)] + config = TuningConfig() + + with autotune(tune_mode=True, skip_ops={"multi_runner_op"}): + chosen, tactic = tuner.choose_one( + "multi_runner_op", [runner_a, runner_b], config, inputs + ) + + assert chosen is runner_a + assert tactic == -1 + + +def test_skip_ops_empty_set_is_noop(monkeypatch): + """skip_ops=set() should not skip anything.""" + tuner = reset_autotuner() + runner = DummyRunner(valid_tactics=(0, 1)) + inputs = [torch.empty((16, 32), dtype=torch.float32)] + config = TuningConfig() + + def fake_profile( + self, runner_obj, prof_inputs, tactic, tuning_config=None, **kwargs + ): + return {0: 5.0, 1: 1.0}[tactic] + + monkeypatch.setattr(AutoTuner, "_profile_single_kernel", fake_profile) + + with autotune(tune_mode=True, skip_ops=set()): + chosen, tactic = tuner.choose_one("should_tune", [runner], config, inputs) + + assert tactic == 1 # profiled and selected best + + +def test_skip_ops_nested_inner_op_resumes_after_exit(monkeypatch): + """op_b added by inner context should be profiled again after inner exits.""" + tuner = reset_autotuner() + runner = DummyRunner(valid_tactics=(0,)) + inputs = [torch.empty((4, 8), dtype=torch.float32)] + config = TuningConfig() + + profile_calls = [] + + def fake_profile( + self, runner_obj, prof_inputs, tactic, tuning_config=None, **kwargs + ): + profile_calls.append(1) + return 1.0 + + monkeypatch.setattr(AutoTuner, "_profile_single_kernel", fake_profile) + + with autotune(tune_mode=True, skip_ops={"op_a"}): + with autotune(tune_mode=True, skip_ops={"op_b"}): + _, tactic_b = tuner.choose_one("op_b", [runner], config, inputs) + assert tactic_b == -1 # skipped in inner + + # After inner exits, op_b should be profiled + profile_calls.clear() + _, tactic_b2 = tuner.choose_one("op_b", [runner], config, inputs) + assert len(profile_calls) > 0 # was profiled + + +def test_skip_ops_does_not_pollute_cache(): + """Skipped ops should not create entries in profiling_cache.""" + tuner = reset_autotuner() + runner = DummyRunner() + inputs = [torch.empty((4, 8), dtype=torch.float32)] + config = TuningConfig() + + cache_before = len(tuner.profiling_cache) + + with autotune(tune_mode=True, skip_ops={"no_cache_op"}): + tuner.choose_one("no_cache_op", [runner], config, inputs) + + assert len(tuner.profiling_cache) == cache_before + + +def test_skip_ops_restored_after_context(): + """skip_ops should be fully cleared after context exits.""" + tuner = reset_autotuner() + runner = DummyRunner() + inputs = [torch.empty((4, 8), dtype=torch.float32)] + config = TuningConfig() + + with autotune(tune_mode=False, skip_ops={"some_op"}): + _, tactic = tuner.choose_one("some_op", [runner], config, inputs) + assert tactic == -1 + + # After context, skip_ops should be empty — op goes through normal path + assert tuner._effective_skip_ops == frozenset()