Commit da0fe55

hyukn authored and dominicshanshan committed
[TRTLLM-4501][feat] Add input tensor pre-hook function API for the tuning process. (NVIDIA#6924)
Some tunable ops require a more realistic data distribution during tuning, for instance a tensor whose contents are tied to its shape. A customizable pre-hook function can therefore be declared in the tuning config to modify the input tensors before the tuning process runs. Signed-off-by: Yukun He <[email protected]>
1 parent 105d546 commit da0fe55

File tree

3 files changed: +45 -9 lines changed


tensorrt_llm/_torch/autotuner.py

Lines changed: 6 additions & 0 deletions
@@ -92,10 +92,13 @@ class TuningConfig:
             any value is provided to the choose_one function, the input tensor will be saturated
             with the provided value.
             If not provided, the autotuner will not consider the max num tokens.
+        inputs_pre_hook (Callable): A function that takes a list of input tensors and returns a list of modified input tensors.
+            It is called before the input tensors are profiled during the tuning process, so that they match the real data distribution.
     """
     dynamic_tensor_specs: Tuple[DynamicTensorSpec, ...] = ()
     constraint_specs: Tuple[ConstraintSpec, ...] = ()
     tune_max_num_tokens: int = None
+    inputs_pre_hook: Callable = None


 @dataclass(unsafe_hash=True)
@@ -662,6 +665,9 @@ def _profile_runners(
         min_time = float('inf')
         has_tuning_failure_occured = False
         best_runner_id, best_tactic = None, None
+        # If the inputs_pre_hook is provided, it will be called before profiling.
+        if tuning_config.inputs_pre_hook is not None:
+            input_tensors = tuning_config.inputs_pre_hook(input_tensors)
         for runner_id, runner in enumerate(runners):
             # TODO: use FakeTensor here.
             runner_arg_names = {

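As a usage illustration (not part of this commit): a minimal sketch of declaring a tuning config with the new field. Only TuningConfig and inputs_pre_hook come from this diff; the import path, the hook body, and the assumed two-tensor input layout are illustrative assumptions.

from typing import List

import torch

# Import path assumed from the file shown above.
from tensorrt_llm._torch.autotuner import TuningConfig


def shape_associated_pre_hook(inputs: List[torch.Tensor]) -> List[torch.Tensor]:
    # Hypothetical hook: suppose the first tensor must carry valid row indices,
    # so the autotuner's random or saturated data would not be representative.
    x, w = inputs
    x_hooked = x.clone()
    x_hooked[:, 0] = torch.arange(x.shape[0], device=x.device).to(x.dtype)
    return [x_hooked, w]


# All other TuningConfig fields keep their defaults.
tuning_config = TuningConfig(inputs_pre_hook=shape_associated_pre_hook)
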
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 2 additions & 2 deletions
@@ -29,8 +29,8 @@ class MoERunner(TunableRunner):
     runner_dict = dict()
     tuning_config = TuningConfig(
         dynamic_tensor_specs=(DynamicTensorSpec(
-            0, 0, get_last_power_of_2_num_tokens_buckets(8192),
-            lambda x: min(last_positive_power_of_2(x), 8192)), ),
+            0, 0, get_last_power_of_2_num_tokens_buckets,
+            last_positive_power_of_2), ),
         tune_max_num_tokens=8192,
     )

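For readability, the same spec written with the keyword names used in the test file below; a sketch only, assuming the positional and keyword forms line up as in that test and that the helpers live at the import path shown.

# Keyword names taken from the DynamicTensorSpec call in the test below.
from tensorrt_llm._torch.autotuner import DynamicTensorSpec, TuningConfig
# Helper import path assumed; the functions themselves are referenced in this diff.
from tensorrt_llm._torch.utils import (get_last_power_of_2_num_tokens_buckets,
                                       last_positive_power_of_2)

tuning_config = TuningConfig(
    dynamic_tensor_specs=(DynamicTensorSpec(
        input_idx=0,
        dim_idx=0,
        # The bucket generator and bucket-mapping helper are now passed as
        # callables; the 8192 cap comes from tune_max_num_tokens rather than
        # being baked into the arguments.
        gen_tuning_buckets=get_last_power_of_2_num_tokens_buckets,
        map_to_tuning_buckets=last_positive_power_of_2,
    ), ),
    tune_max_num_tokens=8192,
)
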
tests/unittest/_torch/misc/test_autotuner.py

Lines changed: 37 additions & 7 deletions
@@ -322,12 +322,24 @@ def test_multiple_dynamic_shapes_cache():
         f"Expected 12 cache entries for 3x4 shape combinations, got {len(cache_entries)}"


-class GemmRunnerWithTacticConfigs(TunableRunner):
+class GemmRunnerComplexTuningConfigs(TunableRunner):
     valid_tactic_ids = [-1, 0, 1]
+    tune_max_num_tokens = 32
+
+    def get_valid_tactics(
+        self,
+        inputs: List[FakeTensor],
+        profile: OptimizationProfile,
+        **kwargs,
+    ) -> List[Dict[str, int]]:
+        # During the tuning process, verify that the tuning config behaves as expected.
+
+        assert inputs[0].shape[0] <= self.tune_max_num_tokens, \
+            f"Input shape {inputs[0].shape[0]} is larger than the max num tokens {self.tune_max_num_tokens}"
+
+        assert inputs[0][-1, 0] == inputs[0].shape[0], \
+            f"Input shape {inputs[0].shape[0]} is not set through the pre_hook correctly"

-    def get_valid_tactics(self, inputs: List[FakeTensor],
-                          profile: OptimizationProfile,
-                          **kwargs) -> List[Dict[str, int]]:
         # The simulated delay is not deterministic, so we need to return specific tactics here
         return [{
             "block_size": block_size,
@@ -350,12 +362,30 @@ def forward(
         assert tactic_id in self.valid_tactic_ids
         return [gemm_0, gemm_1, gemm_fallback][tactic_id](*inputs)

+    @staticmethod
+    def inputs_pre_hook(inputs: List[torch.Tensor]):
+        # Zero x and store the token count in its last row so get_valid_tactics can check the hook ran.
+        x, w = inputs
+        x_hooked = torch.zeros_like(x)
+        x_hooked[-1, 0] = x.shape[0]
+        return [x_hooked, w]
+

-def test_autotuner_tactic_configs():
-    runner_0 = GemmRunnerWithTacticConfigs()
+def test_autotuner_tuning_configs():
+    runner_0 = GemmRunnerComplexTuningConfigs()
     runners = [runner_0]
     x, w = torch.randn(64, 64), torch.randn(64, 128)
-    tuning_config = TuningConfig()
+    tuning_config = TuningConfig(
+        dynamic_tensor_specs=(DynamicTensorSpec(
+            input_idx=0,
+            dim_idx=0,
+            gen_tuning_buckets=get_power_of_2_num_tokens_buckets,
+            map_to_tuning_buckets=next_positive_power_of_2,
+        ), ),
+        # Test if the number of tuning tokens is clipped to 32
+        tune_max_num_tokens=GemmRunnerComplexTuningConfigs.tune_max_num_tokens,
+        inputs_pre_hook=GemmRunnerComplexTuningConfigs.inputs_pre_hook,
+    )
     with autotune():
         tuner = AutoTuner.get()
         runner, tactic = tuner.choose_one("test_autotuner_tactic_configs",
