-
Notifications
You must be signed in to change notification settings - Fork 1k
feat(autotuner): enable per-op autotune bypass for faster framework warmup #3396
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -838,3 +838,180 @@ def test_choose_one_with_none_input_no_crash(): | |
| ) | ||
| assert chosen_runner is runner | ||
| assert tactic == -1 | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Tests: skip_ops | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| def test_skip_ops_prevents_profiling(monkeypatch): | ||
| """Skipped ops should return fallback immediately without profiling.""" | ||
| tuner = reset_autotuner() | ||
| runner = DummyRunner(valid_tactics=(0, 1, 2)) | ||
| inputs = [torch.empty((16, 32), dtype=torch.float32)] | ||
| config = TuningConfig() | ||
|
|
||
| profile_calls = [] | ||
|
|
||
| def fake_profile( | ||
| self, runner_obj, prof_inputs, tactic, tuning_config=None, **kwargs | ||
| ): | ||
| profile_calls.append(tactic) | ||
| return 1.0 | ||
|
|
||
| monkeypatch.setattr(AutoTuner, "_profile_single_kernel", fake_profile) | ||
|
|
||
| with autotune(tune_mode=True, skip_ops={"skip_me"}): | ||
| chosen_runner, tactic = tuner.choose_one("skip_me", [runner], config, inputs) | ||
|
|
||
| assert chosen_runner is runner | ||
| assert tactic == -1 | ||
| assert len(profile_calls) == 0 | ||
|
|
||
|
|
||
| def test_skip_ops_does_not_affect_other_ops(monkeypatch): | ||
| """Non-skipped ops should still be profiled normally.""" | ||
| tuner = reset_autotuner() | ||
| runner = DummyRunner(valid_tactics=(0, 1)) | ||
| inputs = [torch.empty((16, 32), dtype=torch.float32)] | ||
| config = TuningConfig() | ||
|
|
||
| def fake_profile( | ||
| self, runner_obj, prof_inputs, tactic, tuning_config=None, **kwargs | ||
| ): | ||
| return {0: 5.0, 1: 1.0}[tactic] | ||
|
|
||
| monkeypatch.setattr(AutoTuner, "_profile_single_kernel", fake_profile) | ||
|
|
||
| with autotune(tune_mode=True, skip_ops={"some_other_op"}): | ||
| chosen_runner, tactic = tuner.choose_one("tune_me", [runner], config, inputs) | ||
|
|
||
| assert chosen_runner is runner | ||
| assert tactic == 1 # best tactic selected via profiling | ||
|
|
||
|
|
||
| def test_skip_ops_nested_union(monkeypatch): | ||
| """Nested autotune contexts should union their skip_ops sets.""" | ||
| tuner = reset_autotuner() | ||
| runner = DummyRunner(valid_tactics=(0,)) | ||
| inputs = [torch.empty((4, 8), dtype=torch.float32)] | ||
| config = TuningConfig() | ||
|
|
||
| def fake_profile( | ||
| self, runner_obj, prof_inputs, tactic, tuning_config=None, **kwargs | ||
| ): | ||
| return 1.0 | ||
|
|
||
| monkeypatch.setattr(AutoTuner, "_profile_single_kernel", fake_profile) | ||
|
|
||
| with autotune(tune_mode=True, skip_ops={"op_a"}): | ||
| # op_a should be skipped | ||
| _, tactic_a = tuner.choose_one("op_a", [runner], config, inputs) | ||
| assert tactic_a == -1 | ||
|
|
||
| with autotune(tune_mode=True, skip_ops={"op_b"}): | ||
| # Both op_a and op_b should be skipped in inner context | ||
| _, tactic_a2 = tuner.choose_one("op_a", [runner], config, inputs) | ||
| assert tactic_a2 == -1 | ||
| _, tactic_b = tuner.choose_one("op_b", [runner], config, inputs) | ||
| assert tactic_b == -1 | ||
|
|
||
| # After inner context exits, only op_a should still be skipped | ||
| _, tactic_a3 = tuner.choose_one("op_a", [runner], config, inputs) | ||
| assert tactic_a3 == -1 | ||
|
|
||
|
|
||
| def test_skip_ops_returns_first_runner(): | ||
| """Skipped ops should always return runners[0], even with multiple runners.""" | ||
| tuner = reset_autotuner() | ||
| runner_a = DummyRunner(valid_tactics=(0,)) | ||
| runner_b = DummyRunner(valid_tactics=(1,)) | ||
| inputs = [torch.empty((4, 8), dtype=torch.float32)] | ||
| config = TuningConfig() | ||
|
|
||
| with autotune(tune_mode=True, skip_ops={"multi_runner_op"}): | ||
| chosen, tactic = tuner.choose_one( | ||
| "multi_runner_op", [runner_a, runner_b], config, inputs | ||
| ) | ||
|
|
||
| assert chosen is runner_a | ||
| assert tactic == -1 | ||
|
|
||
|
|
||
| def test_skip_ops_empty_set_is_noop(monkeypatch): | ||
| """skip_ops=set() should not skip anything.""" | ||
| tuner = reset_autotuner() | ||
| runner = DummyRunner(valid_tactics=(0, 1)) | ||
| inputs = [torch.empty((16, 32), dtype=torch.float32)] | ||
| config = TuningConfig() | ||
|
|
||
| def fake_profile( | ||
| self, runner_obj, prof_inputs, tactic, tuning_config=None, **kwargs | ||
| ): | ||
| return {0: 5.0, 1: 1.0}[tactic] | ||
|
|
||
| monkeypatch.setattr(AutoTuner, "_profile_single_kernel", fake_profile) | ||
|
|
||
| with autotune(tune_mode=True, skip_ops=set()): | ||
| chosen, tactic = tuner.choose_one("should_tune", [runner], config, inputs) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix Ruff RUF059 by discarding unused unpacked values. Line 957 ( Suggested cleanup- with autotune(tune_mode=True, skip_ops=set()):
- chosen, tactic = tuner.choose_one("should_tune", [runner], config, inputs)
+ with autotune(tune_mode=True, skip_ops=set()):
+ _, tactic = tuner.choose_one("should_tune", [runner], config, inputs)
...
- _, tactic_b2 = tuner.choose_one("op_b", [runner], config, inputs)
+ _, _ = tuner.choose_one("op_b", [runner], config, inputs)Also applies to: 986-986 🧰 Tools🪛 Ruff (0.15.14)[warning] 957-957: Unpacked variable Prefix it with an underscore or any other dummy variable pattern (RUF059) 🤖 Prompt for AI Agents |
||
|
|
||
| assert tactic == 1 # profiled and selected best | ||
|
|
||
|
|
||
| def test_skip_ops_nested_inner_op_resumes_after_exit(monkeypatch): | ||
| """op_b added by inner context should be profiled again after inner exits.""" | ||
| tuner = reset_autotuner() | ||
| runner = DummyRunner(valid_tactics=(0,)) | ||
| inputs = [torch.empty((4, 8), dtype=torch.float32)] | ||
| config = TuningConfig() | ||
|
|
||
| profile_calls = [] | ||
|
|
||
| def fake_profile( | ||
| self, runner_obj, prof_inputs, tactic, tuning_config=None, **kwargs | ||
| ): | ||
| profile_calls.append(1) | ||
| return 1.0 | ||
|
|
||
| monkeypatch.setattr(AutoTuner, "_profile_single_kernel", fake_profile) | ||
|
|
||
| with autotune(tune_mode=True, skip_ops={"op_a"}): | ||
| with autotune(tune_mode=True, skip_ops={"op_b"}): | ||
| _, tactic_b = tuner.choose_one("op_b", [runner], config, inputs) | ||
| assert tactic_b == -1 # skipped in inner | ||
|
|
||
| # After inner exits, op_b should be profiled | ||
| profile_calls.clear() | ||
| _, tactic_b2 = tuner.choose_one("op_b", [runner], config, inputs) | ||
| assert len(profile_calls) > 0 # was profiled | ||
|
|
||
|
|
||
| def test_skip_ops_does_not_pollute_cache(): | ||
| """Skipped ops should not create entries in profiling_cache.""" | ||
| tuner = reset_autotuner() | ||
| runner = DummyRunner() | ||
| inputs = [torch.empty((4, 8), dtype=torch.float32)] | ||
| config = TuningConfig() | ||
|
|
||
| cache_before = len(tuner.profiling_cache) | ||
|
|
||
| with autotune(tune_mode=True, skip_ops={"no_cache_op"}): | ||
| tuner.choose_one("no_cache_op", [runner], config, inputs) | ||
|
|
||
| assert len(tuner.profiling_cache) == cache_before | ||
|
|
||
|
|
||
| def test_skip_ops_restored_after_context(): | ||
| """skip_ops should be fully cleared after context exits.""" | ||
| tuner = reset_autotuner() | ||
| runner = DummyRunner() | ||
| inputs = [torch.empty((4, 8), dtype=torch.float32)] | ||
| config = TuningConfig() | ||
|
|
||
| with autotune(tune_mode=False, skip_ops={"some_op"}): | ||
| _, tactic = tuner.choose_one("some_op", [runner], config, inputs) | ||
| assert tactic == -1 | ||
|
|
||
| # After context, skip_ops should be empty — op goes through normal path | ||
| assert tuner._effective_skip_ops == frozenset() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Validate
runnersonce at method entry, not only in the skip path.Line 1136 adds an empty-list guard only for skipped ops. For non-skipped ops, Line 1152 still does
runners[runner_id], sorunners=[]can raiseIndexErrorin inference mode.Suggested fix
def choose_one( self, custom_op: str, runners: List[TunableRunner], tuning_config: TuningConfig, inputs: List[torch.Tensor], **kwargs, ) -> Tuple[TunableRunner, int]: + if not runners: + raise ValueError(f"No runners provided for op '{custom_op}'") + if not all(isinstance(r, TunableRunner) for r in runners): + raise TypeError("All given runners must be subclasses of TunableRunner") + # Skip profiling for ops in the skip_ops set — return fallback # immediately. if custom_op in self._effective_skip_ops: logger.debug( f"[AutoTuner]: Skipping autotuning for '{custom_op}' " f"(in skip_ops). Using fallback tactic." ) - if not runners: - raise ValueError(f"No runners provided for op '{custom_op}'") return runners[0], -1 ... - assert len(runners) > 0, "At least one runner is required" - assert all([isinstance(r, TunableRunner) for r in runners]), ( - "All Given runners must be subclass of TunableRunner" - )🤖 Prompt for AI Agents