diff --git a/tensorrt_llm/_torch/autotuner.py b/tensorrt_llm/_torch/autotuner.py index 02e1acca18b..ee1408b5bfd 100644 --- a/tensorrt_llm/_torch/autotuner.py +++ b/tensorrt_llm/_torch/autotuner.py @@ -98,6 +98,7 @@ class TuningConfig: use_cold_l2_cache (bool): Whether to use cold L2 cache. This flag is to create circular buffer of input tensors to avoid L2 cache hits to simulate cold L2 cache. Notice that not all tuning processes can benefit from this feature. + use_cuda_graph (bool): Whether to use CUDA graph for the tuning process. """ dynamic_tensor_specs: Tuple[DynamicTensorSpec, ...] = () constraint_specs: Tuple[ConstraintSpec, ...] = () @@ -211,8 +212,16 @@ def forward( """ raise NotImplementedError - def __hash__(self): - return hash(tuple(self.__dict__.values())) + def unique_id(self): + """ + Returns a tuple of the unique id of the runner. The unique id will be converted to a string for the cache key. + A common practice is to return a tuple of the runner's attributes, for example: + return (self.output_dtype, self.attribute_1, ...) + + Returns: + Any: The unique id of the runner, which can be converted to a string for the cache key. 
+ """ + return tuple(self.__dict__.values()) @contextlib.contextmanager @@ -226,7 +235,6 @@ def autotune(tune_mode: bool = True, cache_path: str = None, rank: int = 0): # if the rank-specific file exists, load it file_exists = os.path.exists(cache_path_no_ext_rank) # if the rank-specific file exists, do not enable tuning mode - tune_required = tune_required and not os.path.exists(cache_path) if file_exists: logger.info( f"[Autotuner] Loading cache from {cache_path_no_ext_rank}") @@ -259,8 +267,8 @@ class AutoTunerStatistics: cache_misses (int): Number of cache misses requiring fallback cache_miss_config_collection (Dict[str, Set[OptimizationProfile]]): Collection of configs that caused cache misses failed_profiling_count (Dict[str, int]): Number of failed profiling attempts per operation - tuned_op_total_configs (Dict[str, int]): Total configurations tried per operation - tuned_op_successful_configs (Dict[str, int]): Successful configurations per operation + tuned_op_profiled_configs (Dict[str, int]): Profiled configurations per operation + tuned_op_time_cost (Dict[str, float]): Time cost per operation """ cache_misses: int = 0 cache_miss_config_collection: Dict[str, @@ -268,8 +276,8 @@ class AutoTunerStatistics: failed_profiling_count: Dict[str, Set[Tuple[str, TunableRunner, OptimizationProfile]]] = field( default_factory=dict) - tuned_op_total_configs: Dict[str, int] = field(default_factory=dict) - tuned_op_successful_configs: Dict[str, int] = field(default_factory=dict) + tuned_op_profiled_configs: Dict[str, int] = field(default_factory=dict) + tuned_op_time_cost: Dict[str, float] = field(default_factory=dict) def __str__(self) -> str: """Return a string representation of collected statistics. 
@@ -284,22 +292,23 @@ def __str__(self) -> str: for profile in sorted(profiles, key=str): stats_str += f" - Config: {profile}\n" - if self.tuned_op_total_configs: + if self.tuned_op_profiled_configs: stats_str += "Tuned operations:\n" - for op in sorted(self.tuned_op_total_configs.keys()): - total = self.tuned_op_total_configs[op] - successful = self.tuned_op_successful_configs.get(op, 0) - failed = len(self.failed_profiling_count.get(op, set())) - success_rate = (successful / total * 100) if total > 0 else 0 + for op in sorted(self.tuned_op_profiled_configs.keys()): + successful = self.tuned_op_profiled_configs[op] + failed = len(self.failed_profiling_count.get(op, ())) stats_str += f" {op}:\n" - stats_str += f" - Total configs tried: {total}\n" stats_str += f" - Successful configs: {successful}\n" stats_str += f" - Failed profiling count: {failed}\n" if failed > 0: stats_str += f" - Failed profiling combinations:\n" for failed_key in self.failed_profiling_count[op]: stats_str += f" - {failed_key}\n" - stats_str += f" - Success rate: {success_rate:.1f}%\n" + + if self.tuned_op_time_cost: + stats_str += "Tuned operations time cost:\n" + for op in sorted(self.tuned_op_time_cost.keys()): + stats_str += f" {op}: {self.tuned_op_time_cost[op] * 1000:.4f} milliseconds\n" return stats_str @@ -374,7 +383,7 @@ def get_cache_key( return ( custom_op, runner.__class__.__name__, - hash(runner), + str(runner.unique_id()), AutoTuner.get()._find_nearest_profile( input_shapes, tuning_config.dynamic_tensor_specs, @@ -546,6 +555,11 @@ def __init__(self, warmup=2, repeat=10, stream_delay_micro_secs=1000): # Last captured choose_one() contexts self._last_capture: Optional['AutoTuner.TacticsCapture'] = None + # Increase log level for AutoTuner associated logger + self._log_level_to_info = os.getenv( + "TLLM_AUTOTUNER_LOG_LEVEL_DEBUG_TO_INFO", '0') == '1' + self._debug_logger = logger.info if self._log_level_to_info else logger.debug + @classmethod def get(cls): if cls._instance is None: @@ 
-726,11 +740,14 @@ def choose_one( assert all([isinstance(r, TunableRunner) for r in runners]), \ "All Given runners must be subclass of TunableRunner" + tuning_start_time = time.perf_counter() profiles = self._optimization_profiles(tuning_config, inputs) - # Record the total configs to try - self.stats.tuned_op_total_configs[custom_op] = len(profiles) - + # Initialize the statistics for the custom_op + if custom_op not in self.stats.tuned_op_profiled_configs: + self.stats.tuned_op_profiled_configs[custom_op] = 0 + if custom_op not in self.stats.failed_profiling_count: + self.stats.failed_profiling_count[custom_op] = set() new_tuning_failure_occured = False for p in profiles: @@ -746,16 +763,15 @@ def choose_one( cache_key = self.profiling_cache.get_cache_key( custom_op, runners[best_runner_id], p.get_opt_shapes(), tuning_config) + + self._debug_logger( + f"[Autotuner] Profiling runner={runners[best_runner_id]}, tactic={best_tactic} for cache_key={cache_key}." + ) # inspect call stack self.profiling_cache[cache_key] = (best_runner_id, best_tactic, min_time) - self.stats.tuned_op_successful_configs[ - custom_op] = self.stats.tuned_op_successful_configs.get( - custom_op, 0) + 1 - logger.debug( - f"[Autotuner] Profiling runner={runners[best_runner_id]}, tactic={best_tactic} for cache_key={cache_key}." - ) + self.stats.tuned_op_profiled_configs[custom_op] += 1 else: logger.warning_once( f"[Autotuner] No valid runner/tactic was found for custom_op={custom_op}, input_shapes={input_shapes}. 
" @@ -782,6 +798,10 @@ def choose_one( _, runner_id, tactic, _ = self.profiling_cache.search_cache( custom_op, runners, input_shapes, tuning_config) + tuning_end_time = time.perf_counter() + self.stats.tuned_op_time_cost[ + custom_op] = self.stats.tuned_op_time_cost.get( + custom_op, 0) + tuning_end_time - tuning_start_time return (runners[runner_id], tactic) def _profile_runners( @@ -832,14 +852,13 @@ def _profile_runners( f"[Autotuner] Failed when profiling runner={runner}, tactic={tac}, shapes={shapes}. Set TLLM_LOG_LEVEL=DEBUG for more details.", key=(custom_op, "warning_autotuning_profile_failure"), ) - logger.debug_once( - f"[Autotuner] Exception captured: {e}", - key=(custom_op, "debug_autotuning_exception"), - ) + (logger.info_once + if self._log_level_to_info else logger.debug_once)( + f"[Autotuner] Exception captured: {e}", + key=(custom_op, "debug_autotuning_exception"), + ) # Record the failed profiling combinations - if custom_op not in self.stats.failed_profiling_count: - self.stats.failed_profiling_count[custom_op] = set() self.stats.failed_profiling_count[custom_op].add( self.profiling_cache.get_cache_key( custom_op, runner, profile.get_opt_shapes(), @@ -957,7 +976,7 @@ def pure_profile(stream: torch.cuda.Stream, repeat: int): avg_time = pure_profile(stream, self.repeat) shapes = self._get_input_sizes(inputs) - logger.debug( + self._debug_logger( f"[Autotuner] Profiled runner={runner}, tactic={tactic}, shapes={shapes}: {avg_time:.6f}ms." 
) @@ -1043,7 +1062,7 @@ def _optimization_profiles( p.shapes[spec.input_idx][spec.dim_idx] = DynamicDim( min_value, opt_value, max_value) generated_profiles.append(p) - logger.debug(f"[Autotuner] Generated profile: {p}") + self._debug_logger(f"[Autotuner] Generated profile: {p}") return generated_profiles @classmethod @@ -1159,7 +1178,7 @@ def _prepare_input_tensors_with_batches( input.element_size() if isinstance(input, torch.Tensor) else 0 for input in inputs) if one_buffer_bytes <= 0: - logger.debug( + self._debug_logger( "[Autotuner] No tensor inputs or zero-sized tensors; falling back to single-batch profiling." ) return [inputs] @@ -1174,7 +1193,7 @@ def _prepare_input_tensors_with_batches( list(t.clone() if isinstance(t, torch.Tensor) else t for t in inputs)) - logger.debug( + self._debug_logger( f"[Autotuner] use_cold_l2_cache={tuning_config.use_cold_l2_cache}, use {num_buffers} different tensors for profiling" ) return inputs_list @@ -1188,16 +1207,23 @@ def reset_statistics(self) -> None: self.stats = AutoTunerStatistics() def print_profiling_cache(self): - logger.debug(f"[Autotuner] The profiling_cache entries:") - logger.debug( + self._debug_logger(f"[Autotuner] The profiling_cache entries:") + self._debug_logger( f"[Autotuner] Cache contents: (custom_op, runner, hash(attributes), shape_profiles) -> (runner_id, tactic, shape_profile(ignored))" ) for key, value in self.profiling_cache.cache.items(): runner_id, tactic, min_time = value - logger.debug( + self._debug_logger( f"[Autotuner] {key}: (runner_id={runner_id}, tactic={tactic}, min_time={min_time})" ) + self.print_statistics() + + def print_statistics(self): + self._debug_logger(f"[Autotuner] The statistics:") + for line in self.stats.__str__().split("\n"): + self._debug_logger(line) + @contextlib.contextmanager def capture(self): """Context manager for capturing execution contexts for testing. 
@@ -1271,7 +1297,7 @@ def replay(self, *config: Tuple[Tuple[TunableRunner, int], ...]): runner_idx = runners.index(runner) runner_tactic_list.append((runner_idx, tactic)) - logger.debug( + self._debug_logger( f"[Autotuner][replay]: Testing configuration: {runner_tactic_list}") # Replay the contexts with given (runner, tactic) pairs diff --git a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py index eb67b667b23..ee7a2763f7d 100644 --- a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py @@ -55,13 +55,8 @@ def __init__(self, alpha: float, output_dtype: torch.dtype): ) # rewrite the hash function because the value of self.alpha doesn't affect the tactic. - def __hash__(self): - return hash((self.output_dtype, )) - - def __eq__(self, other): - if not isinstance(other, self.__class__): - return False - return self.output_dtype == other.output_dtype + def unique_id(self): + return (self.output_dtype, ) def get_valid_tactics( self, diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py index bed3067ee4b..f32a4aa27d2 100644 --- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py @@ -93,6 +93,29 @@ def get_valid_tactics(self, inputs: List[torch.Tensor], profile: OptimizationProfile, **kwargs) -> List[int]: return range(self.fused_moe_runner.get_tactic_num(kwargs["gemm_idx"])) + def unique_id(self): + return ( + self.x_dtype, + self.weight_dtype, + self.output_dtype, + self.top_k, + self.tp_size, + self.tp_rank, + self.ep_size, + self.ep_rank, + self.cluster_size, + self.cluster_rank, + self.enable_alltoall, + self.use_deepseek_fp8_block_scale, + self.use_w4_group_scaling, + self.use_int8_woq_per_channel, + self.use_mxfp8_act_scaling, + self.min_latency_mode, + self.use_fused_finalize, + self.activation_type, + 
self.unpadded_hidden_size, + ) + def forward( self, inputs: List[torch.Tensor], @@ -316,6 +339,12 @@ def __init__( self.fp8_rowwise_gemm_runner = FP8RowwiseGemmRunner.runner_dict[ instance_key] + def unique_id(self): + return ( + self.to_userbuffers, + self.output_dtype, + ) + def get_valid_tactics(self, inputs: List[torch.Tensor], profile: OptimizationProfile, **kwargs) -> List[int]: return list(range(self.fp8_rowwise_gemm_runner.get_num_configs())) @@ -398,6 +427,12 @@ def __init__( output_dtype, int(fp4_gemm_type)) self.fp4_gemm_runner = FP4GemmRunner.runner_dict[instance_key] + def unique_id(self): + return ( + self.to_userbuffers, + self.output_dtype, + ) + def get_valid_tactics(self, inputs: List[torch.Tensor], profile: OptimizationProfile, **kwargs) -> List[int]: return list(range(self.fp4_gemm_runner.get_num_configs())) @@ -447,6 +482,12 @@ def __init__( self.cublaslt_runner = CublasLtFP4GemmRunner.runner_dict[instance_key] + def unique_id(self): + return ( + self.to_userbuffers, + self.output_dtype, + ) + def get_valid_tactics(self, inputs: List[torch.Tensor], profile: OptimizationProfile, **kwargs) -> List[int]: """Get all valid tactics (algorithms) from cuBLASLt heuristic.""" @@ -592,6 +633,15 @@ def __init__(self, output_dtype: torch.dtype, use_deep_seek_fp8: bool, self.kernel_runner = FP8BatchedGemmRunner.runner_dict[instance_key] + def unique_id(self): + return ( + self.output_dtype, + self.use_deep_seek_fp8, + self.low_latency_kernel, + self.tile_size, + self.epilogue_tile_m, + ) + def forward( self, inputs: List[torch.Tensor], @@ -827,6 +877,12 @@ def __init__( self.weight_only_quant_gemm_runner = WeightOnlyQuantGemmRunner.runner_dict[ instance_key] + def unique_id(self): + return ( + self.output_dtype, + self.to_userbuffers, + ) + def get_valid_tactics(self, inputs: List[torch.Tensor], profile: OptimizationProfile, **kwargs) -> List[int]: return list(range(self.weight_only_quant_gemm_runner.get_num_configs())) @@ -894,6 +950,9 @@ class 
FinegrainedMixedDtypeGemm(TunableRunner): def __init__(self, activation_dtype: torch.dtype, output_dtype: torch.dtype, quant_mode: int): + self.activation_dtype = activation_dtype + self.output_dtype = output_dtype + self.quant_mode = quant_mode instance_key = (activation_dtype, output_dtype, quant_mode) if instance_key not in FinegrainedMixedDtypeGemm._runner_dict: FinegrainedMixedDtypeGemm._runner_dict[ @@ -902,6 +961,13 @@ def __init__(self, activation_dtype: torch.dtype, output_dtype: torch.dtype, self._finegrained_mixed_dtype_gemm_runner = FinegrainedMixedDtypeGemm._runner_dict[ instance_key] + def unique_id(self): + return ( + self.activation_dtype, + self.output_dtype, + self.quant_mode, + ) + def get_valid_tactics(self, inputs: List[torch.Tensor], profile: OptimizationProfile, **kwargs) -> List[int]: return list( @@ -1012,6 +1078,12 @@ def __init__(self, output_dtype: torch.dtype, disable_ue8m0_cast: bool): self.output_dtype = output_dtype self.disable_ue8m0_cast = disable_ue8m0_cast + def unique_id(self): + return ( + self.output_dtype, + self.disable_ue8m0_cast, + ) + def get_valid_tactics( self, inputs: List[torch.Tensor], diff --git a/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py index 24b927214dc..beb2d9c623e 100644 --- a/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py @@ -191,24 +191,10 @@ def __init__(self, num_experts: int, top_k: int, n_group: Optional[int], FP4BlockScaleMoERunner.tuning_config = FP4BlockScaleMoERunner.get_tuning_config( ) - # The hash is used by the autotuner to get the cache key, so we hash on members - # that influence tactic validity here. e.g. 
we are tuning FC1 and FC2 - # so the routing type does not matter - def __hash__(self): - return hash(( - self.top_k, - self.intermediate_size, - self.local_num_experts, - )) - - # __eq__ and __hash__ must agree - def __eq__(self, other): - if not isinstance(other, FP4BlockScaleMoERunner): - return False - - return (self.top_k == other.top_k - and self.intermediate_size == other.intermediate_size - and self.local_num_experts == other.local_num_experts) + # The unique_id is used by the autotuner to get the cache key, so we only include members + # that influence tactic validity here. e.g. we are tuning FC1 and FC2 so the routing type does not matter + def unique_id(self): + return (self.top_k, self.intermediate_size, self.local_num_experts) def get_runner(self): instance_key = () @@ -558,24 +544,11 @@ def __init__( FP8BlockScaleMoERunner.tuning_config = FP8BlockScaleMoERunner.get_tuning_config( ) - # The hash is used by the autotuner to get the cache key, so we hash on members + # The unique_id is used by the autotuner to get the cache key, so we only include members # that influence tactic validity here. e.g. 
we are tuning FC1 and FC2 so the routing # type does not matter - def __hash__(self): - return hash(( - self.top_k, - self.intermediate_size, - self.local_num_experts, - )) - - # __eq__ and __hash__ must agree - def __eq__(self, other): - if not isinstance(other, FP8BlockScaleMoERunner): - return False - - return (self.top_k == other.top_k - and self.intermediate_size == other.intermediate_size - and self.local_num_experts == other.local_num_experts) + def unique_id(self): + return (self.top_k, self.intermediate_size, self.local_num_experts) def get_runner(self): instance_key = () @@ -845,30 +818,18 @@ def __init__(self, num_experts: int, top_k: int, n_group: Optional[int], MxE4m3MxE2m1BlockScaleMoERunner.tuning_config = MxE4m3MxE2m1BlockScaleMoERunner.get_tuning_config( ) - # The hash is used by the autotuner to get the cache key, so we hash on members + # The unique_id is used by the autotuner to get the cache key, so we only include members # that influence tactic validity here. e.g. we are tuning FC1 and FC2 so the routing # type does not matter - def __hash__(self): - return hash(( + def unique_id(self): + return ( self.top_k, self.intermediate_size, self.valid_hidden_size, self.valid_intermediate_size, self.local_num_experts, self.act_type, - )) - - # __eq__ and __hash__ must agree - def __eq__(self, other): - if not isinstance(other, MxE4m3MxE2m1BlockScaleMoERunner): - return False - - return (self.top_k == other.top_k - and self.intermediate_size == other.intermediate_size - and self.valid_hidden_size == other.valid_hidden_size and - self.valid_intermediate_size == other.valid_intermediate_size - and self.local_num_experts == other.local_num_experts - and self.act_type == other.act_type) + ) def get_runner(self): instance_key = (self.act_type, True) @@ -1145,30 +1106,18 @@ def __init__(self, num_experts: int, top_k: int, n_group: Optional[int], E4m3MxE2m1BlockScaleMoERunner.tuning_config = E4m3MxE2m1BlockScaleMoERunner.get_tuning_config( ) - # The hash is used 
by the autotuner to get the cache key, so we hash on members + # The unique_id is used by the autotuner to get the cache key, so we only include members # that influence tactic validity here. e.g. we are tuning FC1 and FC2 so the routing # type does not matter - def __hash__(self): - return hash(( + def unique_id(self): + return ( self.top_k, self.intermediate_size, self.valid_hidden_size, self.valid_intermediate_size, self.local_num_experts, self.act_type, - )) - - # __eq__ and __hash__ must agree - def __eq__(self, other): - if not isinstance(other, E4m3MxE2m1BlockScaleMoERunner): - return False - - return (self.top_k == other.top_k - and self.intermediate_size == other.intermediate_size - and self.valid_hidden_size == other.valid_hidden_size and - self.valid_intermediate_size == other.valid_intermediate_size - and self.local_num_experts == other.local_num_experts - and self.act_type == other.act_type) + ) def get_runner(self): instance_key = (self.act_type, False) @@ -1425,10 +1374,10 @@ def __init__(self, num_experts: int, top_k: int, n_group: Optional[int], Bf16MxE2m1BlockScaleMoERunner.tuning_config = Bf16MxE2m1BlockScaleMoERunner.get_tuning_config( ) - # The hash is used by the autotuner to get the cache key, so we hash on members + # The unique_id is used by the autotuner to get the cache key, so we only include members # that influence tactic validity here. e.g. 
we are tuning FC1 and FC2 so the routing # type does not matter - def __hash__(self): + def unique_id(self): - return hash(( + return ( self.top_k, self.intermediate_size, @@ -1438,18 +1387,6 @@ def __hash__(self): self.act_type, - )) + ) - # __eq__ and __hash__ must agree - def __eq__(self, other): - if not isinstance(other, Bf16MxE2m1BlockScaleMoERunner): - return False - - return (self.top_k == other.top_k - and self.intermediate_size == other.intermediate_size - and self.valid_hidden_size == other.valid_hidden_size and - self.valid_intermediate_size == other.valid_intermediate_size - and self.local_num_experts == other.local_num_experts - and self.act_type == other.act_type) - def get_runner(self): instance_key = (self.act_type, ) if instance_key not in Bf16MxE2m1BlockScaleMoERunner.runner_dict: @@ -1695,26 +1632,13 @@ def __init__(self, num_experts: int, top_k: int, n_group: Optional[int], FP8FP4BlockScaleMoERunner.tuning_config = FP8FP4BlockScaleMoERunner.get_tuning_config( ) - # The hash is used by the autotuner to get the cache key, so we hash on members - # that influence tactic validity here. e.g. 
we are tuning FC1 and FC2 - # so the routing type does not matter - def __hash__(self): - return hash(( + def unique_id(self): + return ( self.top_k, self.intermediate_size, self.local_num_experts, self.act_type, - )) - - # __eq__ and __hash__ must agree - def __eq__(self, other): - if not isinstance(other, FP8FP4BlockScaleMoERunner): - return False - - return (self.top_k == other.top_k - and self.intermediate_size == other.intermediate_size - and self.local_num_experts == other.local_num_experts - and self.act_type == other.act_type) + ) def get_runner(self): instance_key = (self.act_type, ) diff --git a/tests/unittest/_torch/misc/test_autotuner.py b/tests/unittest/_torch/misc/test_autotuner.py index 81e62fdf137..0768a777e0a 100644 --- a/tests/unittest/_torch/misc/test_autotuner.py +++ b/tests/unittest/_torch/misc/test_autotuner.py @@ -140,6 +140,10 @@ def test_autotuner_cache_basic(): with autotune(): torch.ops.autotuner_test.get_best_gemm_tactic(torch.randn(M, 64), w) + # This tests the logic of print_profiling_cache and print_statistics + AutoTuner.get().print_profiling_cache() + AutoTuner.get().print_statistics() + m = M * 2 while m >= 1: best_tactic = torch.ops.autotuner_test.get_best_gemm_tactic(