From 79a7ae1cab9c2c0673c195cd37e667b6b522db91 Mon Sep 17 00:00:00 2001 From: He Jun Date: Mon, 24 Nov 2025 17:31:35 +0800 Subject: [PATCH 1/2] use global TuningConfig, to fix memory leak caused by AutoTuner LRU cache and dynamic lambda TuningConfig Pre-compute runner arg names to avoid calling inspect.signature in the loop --- flashinfer/autotuner.py | 11 +++- flashinfer/gemm/gemm_base.py | 120 +++++++++++++++++++++-------------- 2 files changed, 82 insertions(+), 49 deletions(-) diff --git a/flashinfer/autotuner.py b/flashinfer/autotuner.py index 9f5fb67489..a81c8f2546 100644 --- a/flashinfer/autotuner.py +++ b/flashinfer/autotuner.py @@ -458,6 +458,13 @@ def choose_one( # Record the total configs to try self.stats.tuned_op_total_configs[custom_op] = len(profiles) + # Pre-compute runner arg names to avoid calling inspect.signature in the loop + runner_arg_names_map = {} + for r in runners: + runner_arg_names_map[r] = { + param.name for param in inspect.signature(r.forward).parameters.values() + } + for p in profiles: tensors = self._prepare_input_tensors(p, inputs) is_cache_hit, runner_id, tactic, _ = self.search_cache( @@ -470,9 +477,7 @@ def choose_one( for r_id, r in enumerate(runners): # TODO: use FakeTensor here. valid_tactics = r.get_valid_tactics(tensors, p) - runner_arg_names = { - p.name for p in inspect.signature(r.forward).parameters.values() - } + runner_arg_names = runner_arg_names_map[r] if "do_preparation" in runner_arg_names and len(valid_tactics) > 0: r(tensors, tactic=-1, do_preparation=True, **kwargs) for tac in valid_tactics: diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 251e2a4682..914067f728 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -356,6 +356,25 @@ def forward( ) +_FP8_GEMM_SM100_TUNING_CONFIG = TuningConfig( + dynamic_tensor_specs=( + DynamicTensorSpec( + (0,), # a_tensor_index + (-2,), + get_last_power_of_2_num_tokens_buckets, + last_positive_power_of_2, + ), + ), + constraint_specs=( + ConstraintSpec( + 4, # out_tensor_index + -2, + lambda shapes: shapes[0][-2], + ), + ), +) + + def fp8_gemm_sm100( a: torch.Tensor, b: torch.Tensor, @@ -376,29 +395,12 @@ def fp8_gemm_sm100( runners.append(_cudnn_gemm_fp8_runner()) assert runners, "No suitable runners found" tuner = AutoTuner.get() - a_tensor_index = 0 - out_tensor_index = 4 - tuning_config = TuningConfig( - dynamic_tensor_specs=( - DynamicTensorSpec( - (a_tensor_index,), - (-2,), - get_last_power_of_2_num_tokens_buckets, - last_positive_power_of_2, - ), - ), - constraint_specs=( - ConstraintSpec( - out_tensor_index, -2, lambda shapes: shapes[a_tensor_index][-2] - ), - ), - ) inputs = [a, b, scale_a, scale_b, out, workspace_buffer] runner, tactic = tuner.choose_one( "fp8_gemm", runners, - tuning_config, + _FP8_GEMM_SM100_TUNING_CONFIG, inputs, ) @@ -2019,6 +2021,58 @@ def _heuristic_func_mm_fp4( return [c for c in candidate_backends if c in suitable_backends] +def _pad_up(x, y): + return ((x + y - 1) // y) * y + + +_MM_FP4_TUNING_CONFIG_8x4 = TuningConfig( + dynamic_tensor_specs=( + DynamicTensorSpec( + (0,), # a_tensor_index + (0,), + get_last_power_of_2_num_tokens_buckets, + last_positive_power_of_2, + ), + ), + constraint_specs=( + ConstraintSpec( + 2, # a_scale_tensor_index + 0, + lambda shapes: _pad_up(shapes[0][0], 8), + ), + ConstraintSpec( + 6, # out_tensor_index + 0, + lambda shapes: shapes[0][0], + ), + ), +) + + +_MM_FP4_TUNING_CONFIG_128x4 = TuningConfig( + dynamic_tensor_specs=( + DynamicTensorSpec( + (0,), # a_tensor_index + (0,), + get_last_power_of_2_num_tokens_buckets, + last_positive_power_of_2, + ), + ), + constraint_specs=( + ConstraintSpec( + 2, # a_scale_tensor_index + 0, + lambda shapes: _pad_up(shapes[0][0], 128), + ), + ConstraintSpec( + 6, # out_tensor_index + 0, + lambda shapes: shapes[0][0], + ), + ), +) + + @backend_requirement( { "cudnn": _cudnn_gemm_fp4_requirement, @@ -2138,34 +2192,8 @@ def mm_fp4( # Now we have a list of runners for desired & supported backends. tuner = AutoTuner.get() - a_tensor_index = 0 - a_scale_tensor_index = 2 - out_tensor_index = 6 - - def pad_up(x, y): - return ((x + y - 1) // y) * y - - tuning_config = TuningConfig( - dynamic_tensor_specs=( - DynamicTensorSpec( - (a_tensor_index,), - (0,), - get_last_power_of_2_num_tokens_buckets, - last_positive_power_of_2, - ), - ), - constraint_specs=( - ConstraintSpec( - a_scale_tensor_index, - 0, - lambda shapes: pad_up( - shapes[a_tensor_index][0], 8 if use_8x4_sf_layout else 128 - ), - ), - ConstraintSpec( - out_tensor_index, 0, lambda shapes: shapes[a_tensor_index][0] - ), - ), + tuning_config = ( + _MM_FP4_TUNING_CONFIG_8x4 if use_8x4_sf_layout else _MM_FP4_TUNING_CONFIG_128x4 ) inputs = [ From 79a3721de4cdf0a08e81d3870b57f02b6aa34a13 Mon Sep 17 00:00:00 2001 From: yzh119 Date: Mon, 24 Nov 2025 23:04:13 -0800 Subject: [PATCH 2/2] pre-commit --- flashinfer/gemm/gemm_base.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 914067f728..15b26f02ee 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -367,9 +367,9 @@ def forward( ), constraint_specs=( ConstraintSpec( - 4, # out_tensor_index + 4, # out_tensor_index -2, - lambda shapes: shapes[0][-2], + lambda shapes: shapes[0][-2], ), ), ) @@ -2036,14 +2036,14 @@ def _pad_up(x, y): ), constraint_specs=( ConstraintSpec( - 2, # a_scale_tensor_index + 2, # a_scale_tensor_index 0, - lambda shapes: _pad_up(shapes[0][0], 8), + lambda shapes: _pad_up(shapes[0][0], 8), ), ConstraintSpec( - 6, # out_tensor_index + 6, # out_tensor_index 0, - lambda shapes: shapes[0][0], + lambda shapes: shapes[0][0], ), ), ) @@ -2060,14 +2060,14 @@ def _pad_up(x, y): ), constraint_specs=( ConstraintSpec( - 2, # a_scale_tensor_index + 2, # a_scale_tensor_index 0, - lambda shapes: _pad_up(shapes[0][0], 128), + lambda shapes: _pad_up(shapes[0][0], 128), ), ConstraintSpec( - 6, # out_tensor_index + 6, # out_tensor_index 0, - lambda shapes: shapes[0][0], + lambda shapes: shapes[0][0], ), ), )