# Reconstructed from a patch hunk against vllm/env_override.py; the call
# directly below is the hunk's pre-existing context line.
_patch_cpp_indirect_assert_if_needed()

# ===================================================
# Triton Autotuner disable
# ===================================================
# Opt-in switch (VLLM_TRITON_FORCE_FIRST_CONFIG set to "1" or "true") that
# swaps Autotuner.run for a variant which always launches configs[0] and
# never benchmarks. Intended to remove autotuning variability when
# measuring kernel perf.
from vllm.triton_utils import HAS_TRITON  # noqa: E402


def _disable_triton_autotuner():
    """Monkey-patch Triton's ``Autotuner.run`` to always use ``configs[0]``.

    Does nothing unless Triton is available (``HAS_TRITON``) and the
    ``VLLM_TRITON_FORCE_FIRST_CONFIG`` environment variable, after stripping
    whitespace and lower-casing, equals ``"1"`` or ``"true"``.
    """
    # Short-circuit keeps the original ordering: the environment is only
    # consulted when Triton is actually present.
    enabled = HAS_TRITON and os.environ.get(
        "VLLM_TRITON_FORCE_FIRST_CONFIG", "0"
    ).strip().lower() in ("1", "true")
    if not enabled:
        return

    from importlib import import_module

    autotuner_cls = import_module("triton.runtime.autotuner").Autotuner
    # Kernel names already reported, so each one is logged at most once.
    already_logged: set[str] = set()

    def _run_first_config(self, *args, **kwargs):
        """Launch the wrapped kernel with the first config, no benchmarking."""
        chosen = self.configs[0]
        self.best_config = chosen

        name = getattr(self.fn, "__name__", repr(self.fn))
        if name not in already_logged:
            already_logged.add(name)
            logger.info(
                "[triton-autotune-disabled] kernel=%s configs=%d picked=%s",
                name,
                len(self.configs),
                chosen,
            )

        if chosen.pre_hook is not None:
            # Merge order matters: positional args mapped by name first,
            # then caller kwargs, then the config's own kwargs on top.
            hook_args = dict(zip(self.arg_names, args))
            hook_args.update(kwargs)
            hook_args.update(chosen.all_kwargs())
            chosen.pre_hook(hook_args)

        return self.fn.run(*args, **kwargs, **chosen.all_kwargs())

    autotuner_cls.run = _run_first_config


_disable_triton_autotuner()