From d0bd9d4ebd0eb7f4c7e5dc0889faf0d7873c102a Mon Sep 17 00:00:00 2001 From: Yukun He <23156053+hyukn@users.noreply.github.com> Date: Fri, 5 Sep 2025 11:56:39 +0000 Subject: [PATCH 1/9] [None][feat] Achieve a cold L2 cache for each kernel profiling repetition in tactic tuning. Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com> --- tensorrt_llm/_torch/autotuner.py | 36 ++++++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/tensorrt_llm/_torch/autotuner.py b/tensorrt_llm/_torch/autotuner.py index aa1b250b3a1..c0fac3f5511 100644 --- a/tensorrt_llm/_torch/autotuner.py +++ b/tensorrt_llm/_torch/autotuner.py @@ -525,10 +525,14 @@ def _profile_single_kernel( to get an average execution time. Stream synchronization and delays are used to ensure accurate timing. """ + + tensor_lists = self._prepare_input_tensors_with_batches(inputs) + stream = torch.cuda.current_stream() # warm up, no timing + # always use the last batch for warmup for _ in range(self.warmup): - runner(inputs, tactic=tactic, **kwargs) + runner(tensor_lists[-1], tactic=tactic, **kwargs) stream.synchronize() # Delay the profiled kernel launch to eliminate affects of host time overhead in profiling. @@ -539,14 +543,14 @@ end = torch.cuda.Event(enable_timing=True) start.record(stream=stream) - for _ in range(self.repeat): - runner(inputs, tactic=tactic, **kwargs) + for r in range(self.repeat): + runner(tensor_lists[r], tactic=tactic, **kwargs) end.record(stream=stream) stream.synchronize() avg_time = start.elapsed_time(end) / self.repeat - shapes = self._get_input_sizes(inputs) + shapes = self._get_input_sizes(tensor_lists[-1]) logger.debug( f"[Autotuner] Profiled runner={runner}, tactic={tactic}, shapes={shapes}: {avg_time:.6f}ms."
) @@ -733,6 +737,23 @@ def _prepare_input_tensors( tensors.append(tensor) return tensors + def _prepare_input_tensors_with_batches( + self, + inputs: List[torch.Tensor], + ) -> List[List[torch.Tensor]]: + if not self._all_tensors_smaller_than_l2_cache(inputs): + print(f"[Autotuner] All tensors are larger than L2 cache, use the same tensor for profiling") + return [inputs] * (self.repeat + 1) + + inputs_list = [inputs] + # The last batch is for warmup + for _ in range(self.repeat): + inputs_list.append(list(t.clone() for t in inputs)) + + print(f"[Autotuner] All tensors are smaller than L2 cache, use {len(inputs_list)} different tensors for profiling") + return inputs_list + + def clear_cache(self) -> None: """Clear the profiling cache.""" self.profiling_cache.clear() @@ -750,3 +771,10 @@ def print_profiling_cache(self): runner_id, tactic, profile = value logger.debug( f"[Autotuner] {key}: (runner_id={runner_id}, tactic={tactic})") + + def _all_tensors_smaller_than_l2_cache(self, inputs: List[torch.Tensor]) -> bool: + return all(input.numel() * input.element_size() <= self._get_l2_cache_size_in_bytes() for input in inputs) + + def _get_l2_cache_size_in_bytes(self) -> int: + # TODO: Only consider Blackwell L2 cache + return 96 * 1024 * 1024 \ No newline at end of file From 4dbcaad6dfcc48267f1cb81c8f5a25dca4d3a542 Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Sat, 6 Sep 2025 04:42:27 -0700 Subject: [PATCH 2/9] improve cold L2 logic Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com> --- tensorrt_llm/_torch/autotuner.py | 89 ++++++++++++++++++++++++-------- 1 file changed, 68 insertions(+), 21 deletions(-) diff --git a/tensorrt_llm/_torch/autotuner.py b/tensorrt_llm/_torch/autotuner.py index c0fac3f5511..1d8ed334b00 100644 --- a/tensorrt_llm/_torch/autotuner.py +++ b/tensorrt_llm/_torch/autotuner.py @@ -8,6 +8,7 @@ from typing import Any, Callable, Dict, List, Set, Tuple, Union import torch +from cuda.bindings import driver from tensorrt_llm.bindings.internal.runtime import delay_kernel from tensorrt_llm.logger import logger @@ -526,13 +527,26 @@ def _profile_single_kernel( are used to ensure accurate timing. """ - tensor_lists = self._prepare_input_tensors_with_batches(inputs) + use_cold_l2 = True + + if use_cold_l2: + tensor_lists, num_buffers = self._prepare_input_tensors_with_batches( + inputs) + buffer_idx = 0 + else: + tensor_lists = [inputs] + num_buffers = 1 + buffer_idx = 0 stream = torch.cuda.current_stream() # warm up, no timing # always use the last batch for warmup for _ in range(self.warmup): - runner(tensor_lists[-1], tactic=tactic, **kwargs) + # runner(tensor_lists[-1], tactic=tactic, **kwargs) + runner(tensor_lists[buffer_idx % num_buffers], + tactic=tactic, + **kwargs) + buffer_idx += 1 stream.synchronize() # Delay the profiled kernel launch to eliminate affects of host time overhead in profiling. 
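# --- Illustrative sketch, not part of the patch series: the cold-L2 rotation
# idea the hunks above and below implement. The profiler clones the inputs into
# enough buffers that one pass over all of them exceeds the L2 cache, then
# cycles through the clones so each timed launch reads data that is no longer
# L2-resident. `run_kernel`, `repeat`, and `l2_cache_bytes` are placeholder
# names for this sketch only, not identifiers from autotuner.py.
import math
from typing import Callable, List

import torch


def profile_with_cold_l2(run_kernel: Callable[[List[torch.Tensor]], None],
                         inputs: List[torch.Tensor], repeat: int,
                         l2_cache_bytes: int) -> float:
    bytes_per_call = sum(t.numel() * t.element_size() for t in inputs)
    # Enough distinct copies that consecutive launches never reuse L2-resident data.
    num_buffers = max(
        1, min(repeat, math.ceil(l2_cache_bytes / max(bytes_per_call, 1)) + 1))
    buffers = [inputs] + [[t.clone() for t in inputs]
                          for _ in range(num_buffers - 1)]

    stream = torch.cuda.current_stream()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record(stream=stream)
    for r in range(repeat):
        run_kernel(buffers[r % num_buffers])  # each launch sees a different copy
    end.record(stream=stream)
    stream.synchronize()
    return start.elapsed_time(end) / repeat  # average latency in milliseconds
# --- end of sketch; the patch diff continues below ---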
@@ -543,8 +557,11 @@ def _profile_single_kernel( end = torch.cuda.Event(enable_timing=True) start.record(stream=stream) - for r in range(self.repeat): - runner(tensor_lists[r], tactic=tactic, **kwargs) + for _ in range(self.repeat): + runner(tensor_lists[buffer_idx % num_buffers], + tactic=tactic, + **kwargs) + buffer_idx += 1 end.record(stream=stream) stream.synchronize() @@ -738,21 +755,29 @@ def _prepare_input_tensors( return tensors def _prepare_input_tensors_with_batches( - self, - inputs: List[torch.Tensor], - ) -> List[List[torch.Tensor]]: - if not self._all_tensors_smaller_than_l2_cache(inputs): - print(f"[Autotuner] All tensors are larger than L2 cache, use the same tensor for profiling") - return [inputs] * (self.repeat + 1) + self, + inputs: List[torch.Tensor], + ) -> Tuple[List[List[torch.Tensor]], int]: + # TODO: only consider tensor parameter? + one_buffer_bytes = sum( + input.numel() * + input.element_size() if isinstance(input, torch.Tensor) else 0 + for input in inputs) + num_buffers = ceil(self._get_l2_cache_size_in_bytes() / + one_buffer_bytes) + num_buffers = min(num_buffers, self.repeat) inputs_list = [inputs] # The last batch is for warmup - for _ in range(self.repeat): - inputs_list.append(list(t.clone() for t in inputs)) - - print(f"[Autotuner] All tensors are smaller than L2 cache, use {len(inputs_list)} different tensors for profiling") - return inputs_list + for _ in range(num_buffers - 1): + inputs_list.append( + list(t.clone() if isinstance(t, torch.Tensor) else t + for t in inputs)) + logger.debug( + f"[Autotuner] To cold L2 cache, use {num_buffers} different tensors for profiling" + ) + return inputs_list, num_buffers def clear_cache(self) -> None: """Clear the profiling cache.""" @@ -772,9 +797,31 @@ def print_profiling_cache(self): logger.debug( f"[Autotuner] {key}: (runner_id={runner_id}, tactic={tactic})") - def _all_tensors_smaller_than_l2_cache(self, inputs: List[torch.Tensor]) -> bool: - return all(input.numel() * input.element_size() <= self._get_l2_cache_size_in_bytes() for input in inputs) - - def _get_l2_cache_size_in_bytes(self) -> int: - # TODO: Only consider Blackwell L2 cache - return 96 * 1024 * 1024 \ No newline at end of file + def _get_l2_cache_size_in_bytes(self, device_id: int = 0) -> int: + device = self._checkCudaErrors(driver.cuDeviceGet(device_id)) + return self._checkCudaErrors( + driver.cuDeviceGetAttribute( + driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE, + device, + )) + + def _checkCudaErrors(self, result) -> None: + if result[0].value: + raise RuntimeError("CUDA error code={}({})".format( + result[0].value, self._cudaGetErrorEnum(result[0]))) + # CUDA APIs always return the status as the first element of the result tuple + if len(result) == 1: + return None + elif len(result) == 2: + return result[1] + else: + return result[1:] + + def _cudaGetErrorEnum(error) -> str: + if isinstance(error, driver.CUresult): + err, name = driver.cuGetErrorName(error) + return name if err == driver.CUresult.CUDA_SUCCESS else "" + elif isinstance(error, nvrtc.nvrtcResult): + return nvrtc.nvrtcGetErrorString(error)[1] + else: + raise RuntimeError("Unknown error type: {}".format(error)) From 8c901d56da5bb85a049b8f38256852abb37923dc Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Sat, 6 Sep 2025 04:48:42 -0700 Subject: [PATCH 3/9] minor Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com> --- tensorrt_llm/_torch/autotuner.py | 2 -- 1 file changed, 2 deletions(-) diff --git 
a/tensorrt_llm/_torch/autotuner.py b/tensorrt_llm/_torch/autotuner.py index 1d8ed334b00..e7c72e96aff 100644 --- a/tensorrt_llm/_torch/autotuner.py +++ b/tensorrt_llm/_torch/autotuner.py @@ -540,9 +540,7 @@ def _profile_single_kernel( stream = torch.cuda.current_stream() # warm up, no timing - # always use the last batch for warmup for _ in range(self.warmup): - # runner(tensor_lists[-1], tactic=tactic, **kwargs) runner(tensor_lists[buffer_idx % num_buffers], tactic=tactic, **kwargs) From cd5a2b649ef776577a6309e2338534686d0a7c73 Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Sat, 6 Sep 2025 05:01:42 -0700 Subject: [PATCH 4/9] minor Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com> --- tensorrt_llm/_torch/autotuner.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/autotuner.py b/tensorrt_llm/_torch/autotuner.py index e7c72e96aff..2793e1ee75a 100644 --- a/tensorrt_llm/_torch/autotuner.py +++ b/tensorrt_llm/_torch/autotuner.py @@ -5,6 +5,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from functools import lru_cache +from math import ceil from typing import Any, Callable, Dict, List, Set, Tuple, Union import torch @@ -756,17 +757,20 @@ def _prepare_input_tensors_with_batches( self, inputs: List[torch.Tensor], ) -> Tuple[List[List[torch.Tensor]], int]: - # TODO: only consider tensor parameter? one_buffer_bytes = sum( input.numel() * input.element_size() if isinstance(input, torch.Tensor) else 0 for input in inputs) + if one_buffer_bytes <= 0: + logger.debug( + "[Autotuner] No tensor inputs or zero-sized tensors; falling back to single-batch profiling." + ) + return [inputs], 1 num_buffers = ceil(self._get_l2_cache_size_in_bytes() / one_buffer_bytes) num_buffers = min(num_buffers, self.repeat) inputs_list = [inputs] - # The last batch is for warmup for _ in range(num_buffers - 1): inputs_list.append( list(t.clone() if isinstance(t, torch.Tensor) else t From f8e9a90f31838d53053e50cdcff20941241a86b8 Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Sat, 6 Sep 2025 05:06:34 -0700 Subject: [PATCH 5/9] fix Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com> --- tensorrt_llm/_torch/autotuner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/autotuner.py b/tensorrt_llm/_torch/autotuner.py index 2793e1ee75a..09132960aaf 100644 --- a/tensorrt_llm/_torch/autotuner.py +++ b/tensorrt_llm/_torch/autotuner.py @@ -819,7 +819,8 @@ def _checkCudaErrors(self, result) -> None: else: return result[1:] - def _cudaGetErrorEnum(error) -> str: + def _cudaGetErrorEnum(self, error) -> str: + from cuda.bindings import nvrtc if isinstance(error, driver.CUresult): err, name = driver.cuGetErrorName(error) return name if err == driver.CUresult.CUDA_SUCCESS else "" From fbcf73f578b15500507f582ad0b4a196e4c0d63e Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Sat, 6 Sep 2025 05:13:08 -0700 Subject: [PATCH 6/9] fix Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com> --- tensorrt_llm/_torch/autotuner.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tensorrt_llm/_torch/autotuner.py b/tensorrt_llm/_torch/autotuner.py index 09132960aaf..d5de96ba768 100644 --- a/tensorrt_llm/_torch/autotuner.py +++ b/tensorrt_llm/_torch/autotuner.py @@ -807,10 +807,12 @@ def _get_l2_cache_size_in_bytes(self, 
device_id: int = 0) -> int: device, )) - def _checkCudaErrors(self, result) -> None: - if result[0].value: - raise RuntimeError("CUDA error code={}({})".format( - result[0].value, self._cudaGetErrorEnum(result[0]))) + def _checkCudaErrors(self, result) -> Any: + status = result[0] + if status != driver.CUresult.CUDA_SUCCESS: + code = getattr(status, "value", status) + raise RuntimeError( + f"CUDA error code={code}({self._cudaGetErrorEnum(status)})") # CUDA APIs always return the status as the first element of the result tuple if len(result) == 1: return None From b026410823b0b445f4486629e00dacae859a14cf Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Mon, 8 Sep 2025 22:52:32 -0700 Subject: [PATCH 7/9] fix. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com> --- tensorrt_llm/_torch/autotuner.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorrt_llm/_torch/autotuner.py b/tensorrt_llm/_torch/autotuner.py index d5de96ba768..84c49e0adb6 100644 --- a/tensorrt_llm/_torch/autotuner.py +++ b/tensorrt_llm/_torch/autotuner.py @@ -5,7 +5,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from functools import lru_cache -from math import ceil from typing import Any, Callable, Dict, List, Set, Tuple, Union import torch @@ -766,9 +765,10 @@ def _prepare_input_tensors_with_batches( "[Autotuner] No tensor inputs or zero-sized tensors; falling back to single-batch profiling." ) return [inputs], 1 - num_buffers = ceil(self._get_l2_cache_size_in_bytes() / - one_buffer_bytes) - num_buffers = min(num_buffers, self.repeat) + num_buffers = (self._get_l2_cache_size_in_bytes() * + 3) // one_buffer_bytes + 1 + num_iterations = self.warmup + self.repeat + num_buffers = num_iterations if num_iterations < num_buffers else num_buffers inputs_list = [inputs] for _ in range(num_buffers - 1): From 1cd7b557d465d99f50d4d73110016cf813c99d70 Mon Sep 17 00:00:00 2001 From: Yukun He <23156053+hyukn@users.noreply.github.com> Date: Tue, 9 Sep 2025 06:39:21 +0000 Subject: [PATCH 8/9] Add use_cold_l2_cache arg in TuningConfig and several minor improvements. Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com> --- tensorrt_llm/_torch/autotuner.py | 66 ++++++++++++++++---------------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/tensorrt_llm/_torch/autotuner.py b/tensorrt_llm/_torch/autotuner.py index 84c49e0adb6..b8fe8914061 100644 --- a/tensorrt_llm/_torch/autotuner.py +++ b/tensorrt_llm/_torch/autotuner.py @@ -87,10 +87,12 @@ class TuningConfig: any value is provided to the choose_one function, the input tensor will be saturated with the provided value. If not provided, the autotuner will not consider the max num tokens. + use_cold_l2_cache (bool): If true, use cold L2 cache for profiling. """ dynamic_tensor_specs: Tuple[DynamicTensorSpec, ...] = () constraint_specs: Tuple[ConstraintSpec, ...] 
= () tune_max_num_tokens: int = None + use_cold_l2_cache: bool = False @@ -467,7 +469,7 @@ def _profile_runners( for tac in valid_tactics: try: time_measured = self._profile_single_kernel( - runner, input_tensors, tac, **kwargs) + runner, input_tensors, tac, tuning_config, **kwargs) except Exception as e: # Handle None tensors for optional inputs shapes = self._get_input_sizes(input_tensors) @@ -509,6 +511,7 @@ def _profile_single_kernel( runner: TunableRunner, inputs: List[torch.Tensor], tactic: Any, + tuning_config: TuningConfig, **kwargs, ) -> float: """Profile a single kernel implementation for performance measurement. @@ -526,25 +529,17 @@ to get an average execution time. Stream synchronization and delays are used to ensure accurate timing. """ - - use_cold_l2 = True - - if use_cold_l2: - tensor_lists, num_buffers = self._prepare_input_tensors_with_batches( - inputs) - buffer_idx = 0 - else: - tensor_lists = [inputs] - num_buffers = 1 - buffer_idx = 0 + input_tensor_batches = self._prepare_input_tensors_with_batches( + inputs, tuning_config) stream = torch.cuda.current_stream() # warm up, no timing for _ in range(self.warmup): - runner(tensor_lists[buffer_idx % num_buffers], - tactic=tactic, - **kwargs) - buffer_idx += 1 + runner( + input_tensor_batches[-1], + tactic=tactic, + **kwargs, + ) stream.synchronize() # Delay the profiled kernel launch to eliminate affects of host time overhead in profiling. @@ -555,17 +550,18 @@ end = torch.cuda.Event(enable_timing=True) start.record(stream=stream) - for _ in range(self.repeat): - runner(tensor_lists[buffer_idx % num_buffers], - tactic=tactic, - **kwargs) - buffer_idx += 1 + for r in range(self.repeat): + runner( + input_tensor_batches[r % len(input_tensor_batches)], + tactic=tactic, + **kwargs, + ) end.record(stream=stream) stream.synchronize() avg_time = start.elapsed_time(end) / self.repeat - shapes = self._get_input_sizes(tensor_lists[-1]) + shapes = self._get_input_sizes(input_tensor_batches[-1]) logger.debug( f"[Autotuner] Profiled runner={runner}, tactic={tactic}, shapes={shapes}: {avg_time:.6f}ms." ) @@ -755,31 +751,35 @@ def _prepare_input_tensors( def _prepare_input_tensors_with_batches( self, inputs: List[torch.Tensor], - ) -> Tuple[List[List[torch.Tensor]], int]: + tuning_config: TuningConfig, + ) -> List[List[torch.Tensor]]: + if not tuning_config.use_cold_l2_cache: + return [inputs] + one_buffer_bytes = sum( input.numel() * input.element_size() if isinstance(input, torch.Tensor) else 0 for input in inputs) if one_buffer_bytes <= 0: - logger.debug( + logger.info( "[Autotuner] No tensor inputs or zero-sized tensors; falling back to single-batch profiling."
) - return [inputs], 1 - num_buffers = (self._get_l2_cache_size_in_bytes() * - 3) // one_buffer_bytes + 1 - num_iterations = self.warmup + self.repeat - num_buffers = num_iterations if num_iterations < num_buffers else num_buffers + return [inputs] + + num_buffers = self._get_l2_cache_size_in_bytes( + ) * 3 // one_buffer_bytes + 1 + num_buffers = min(num_buffers, self.repeat + 1) inputs_list = [inputs] for _ in range(num_buffers - 1): inputs_list.append( - list(t.clone() if isinstance(t, torch.Tensor) else t + list(t.clone() if isinstance(t, torch.Tensor) else tå for t in inputs)) - logger.debug( - f"[Autotuner] To cold L2 cache, use {num_buffers} different tensors for profiling" + logger.info( + f"[Autotuner] use_cold_l2_cache={tuning_config.use_cold_l2_cache}, use {num_buffers} different tensors for profiling" ) - return inputs_list, num_buffers + return inputs_list def clear_cache(self) -> None: """Clear the profiling cache.""" From db11ba4171b68f122b3e52f0682530c995936000 Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Tue, 9 Sep 2025 01:13:57 -0700 Subject: [PATCH 9/9] fix merge error. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com> --- tensorrt_llm/_torch/autotuner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/autotuner.py b/tensorrt_llm/_torch/autotuner.py index b8fe8914061..927850b248c 100644 --- a/tensorrt_llm/_torch/autotuner.py +++ b/tensorrt_llm/_torch/autotuner.py @@ -773,7 +773,7 @@ def _prepare_input_tensors_with_batches( inputs_list = [inputs] for _ in range(num_buffers - 1): inputs_list.append( - list(t.clone() if isinstance(t, torch.Tensor) else tå + list(t.clone() if isinstance(t, torch.Tensor) else t for t in inputs)) logger.info(
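Taken together, the series makes cold-L2 profiling opt-in per op: setting use_cold_l2_cache=True on a TuningConfig makes _prepare_input_tensors_with_batches clone the inputs into a rotation of buffers, while the default False keeps the previous single-batch timing behavior. Below is a rough usage sketch under the API as it stands at the end of the series; the standalone L2-size query mirrors _get_l2_cache_size_in_bytes and is shown only for illustration, and the explicit cuInit call is an assumption for a bare script (the autotuner itself runs with CUDA already initialized).

from cuda.bindings import driver

from tensorrt_llm._torch.autotuner import TuningConfig

# Opt a custom op's tuning into cold-L2 profiling (the field defaults to False).
tuning_config = TuningConfig(use_cold_l2_cache=True)

# Query the L2 cache size the autotuner sizes its buffer rotation against.
(err, ) = driver.cuInit(0)
assert err == driver.CUresult.CUDA_SUCCESS
err, device = driver.cuDeviceGet(0)
assert err == driver.CUresult.CUDA_SUCCESS
err, l2_bytes = driver.cuDeviceGetAttribute(
    driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE, device)
assert err == driver.CUresult.CUDA_SUCCESS
print(f"L2 cache size: {l2_bytes / (1024 * 1024):.0f} MiB")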