104 changes: 65 additions & 39 deletions tensorrt_llm/_torch/autotuner.py
@@ -98,6 +98,7 @@ class TuningConfig:
use_cold_l2_cache (bool): Whether to use cold L2 cache.
When enabled, a circular buffer of input tensors is created so that repeated profiling runs avoid L2 cache hits, simulating a cold L2 cache.
Note that not all tuning processes benefit from this feature.
use_cuda_graph (bool): Whether to use CUDA graph for the tuning process.
"""
dynamic_tensor_specs: Tuple[DynamicTensorSpec, ...] = ()
constraint_specs: Tuple[ConstraintSpec, ...] = ()
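
For illustration, a minimal sketch of enabling the two flags documented above (all other fields left at their defaults):

# Hypothetical tuning configuration: profile with a cold L2 cache and
# capture profiling work in a CUDA graph.
config = TuningConfig(
    use_cold_l2_cache=True,
    use_cuda_graph=True,
)
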
@@ -211,8 +212,16 @@ def forward(
"""
raise NotImplementedError

def __hash__(self):
return hash(tuple(self.__dict__.values()))
def unique_id(self):
"""
Returns the unique id of the runner. The id is converted to a string and used as part of the cache key.
A common practice is to return a tuple of the runner's attributes, for example:
return (self.output_dtype, self.attribute_1, ...)

Returns:
Any: The unique id of the runner, which can be converted to a string for the cache key.
"""
return tuple(self.__dict__.values())
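
For example, a concrete runner might implement it like this (class and attribute names hypothetical):

class MyGemmRunner(TunableRunner):
    def __init__(self, output_dtype, use_fast_acc):
        self.output_dtype = output_dtype
        self.use_fast_acc = use_fast_acc

    def unique_id(self):
        # Include exactly the attributes that influence tactic selection;
        # values that cannot change the chosen kernel should be omitted.
        return (self.output_dtype, self.use_fast_acc)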


@contextlib.contextmanager
@@ -226,7 +235,6 @@ def autotune(tune_mode: bool = True, cache_path: str = None, rank: int = 0):
# if the rank-specific file exists, load it
file_exists = os.path.exists(cache_path_no_ext_rank)
# if the rank-specific file exists, do not enable tuning mode
tune_required = tune_required and not os.path.exists(cache_path)
if file_exists:
logger.info(
f"[Autotuner] Loading cache from {cache_path_no_ext_rank}")
@@ -259,17 +267,17 @@ class AutoTunerStatistics:
cache_misses (int): Number of cache misses requiring fallback
cache_miss_config_collection (Dict[str, Set[OptimizationProfile]]): Collection of configs that caused cache misses
failed_profiling_count (Dict[str, int]): Number of failed profiling attempts per operation
tuned_op_total_configs (Dict[str, int]): Total configurations tried per operation
tuned_op_successful_configs (Dict[str, int]): Successful configurations per operation
tuned_op_profiled_configs (Dict[str, int]): Profiled configurations per operation
tuned_op_time_cost (Dict[str, float]): Time cost per operation
"""
cache_misses: int = 0
cache_miss_config_collection: Dict[str,
Set[tuple]] = field(default_factory=dict)
failed_profiling_count: Dict[str, Set[Tuple[str, TunableRunner,
OptimizationProfile]]] = field(
default_factory=dict)
tuned_op_total_configs: Dict[str, int] = field(default_factory=dict)
tuned_op_successful_configs: Dict[str, int] = field(default_factory=dict)
tuned_op_profiled_configs: Dict[str, int] = field(default_factory=dict)
tuned_op_time_cost: Dict[str, float] = field(default_factory=dict)

def __str__(self) -> str:
"""Return a string representation of collected statistics.
@@ -284,22 +292,23 @@ def __str__(self) -> str:
for profile in sorted(profiles, key=str):
stats_str += f" - Config: {profile}\n"

if self.tuned_op_total_configs:
if self.tuned_op_profiled_configs:
stats_str += "Tuned operations:\n"
for op in sorted(self.tuned_op_total_configs.keys()):
total = self.tuned_op_total_configs[op]
successful = self.tuned_op_successful_configs.get(op, 0)
failed = len(self.failed_profiling_count.get(op, set()))
success_rate = (successful / total * 100) if total > 0 else 0
for op in sorted(self.tuned_op_profiled_configs.keys()):
successful = self.tuned_op_profiled_configs[op]
failed = len(self.failed_profiling_count[op])
stats_str += f" {op}:\n"
stats_str += f" - Total configs tried: {total}\n"
stats_str += f" - Successful configs: {successful}\n"
stats_str += f" - Failed profiling count: {failed}\n"
if failed > 0:
stats_str += f" - Failed profiling combinations:\n"
for failed_key in self.failed_profiling_count[op]:
stats_str += f" - {failed_key}\n"
stats_str += f" - Success rate: {success_rate:.1f}%\n"

if self.tuned_op_time_cost:
stats_str += "Tuned operations time cost:\n"
for op in sorted(self.tuned_op_time_cost.keys()):
stats_str += f" {op}: {self.tuned_op_time_cost[op] * 1000:.4f} milliseconds\n"

return stats_str
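
After tuning, the assembled report can be printed directly; the output shape below is illustrative only (op name hypothetical):

stats = AutoTuner.get().stats
print(stats)
# Tuned operations:
#   my_custom_op:
#     - Successful configs: 12
#     - Failed profiling count: 0
# Tuned operations time cost:
#   my_custom_op: 154.2031 milliseconds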

@@ -374,7 +383,7 @@ def get_cache_key(
return (
custom_op,
runner.__class__.__name__,
hash(runner),
str(runner.unique_id()),
AutoTuner.get()._find_nearest_profile(
input_shapes,
tuning_config.dynamic_tensor_specs,
@@ -546,6 +555,11 @@ def __init__(self, warmup=2, repeat=10, stream_delay_micro_secs=1000):
# Last captured choose_one() contexts
self._last_capture: Optional['AutoTuner.TacticsCapture'] = None

# Optionally promote AutoTuner debug logs to info level
self._log_level_to_info = os.getenv(
"TLLM_AUTOTUNER_LOG_LEVEL_DEBUG_TO_INFO", '0') == '1'
self._debug_logger = logger.info if self._log_level_to_info else logger.debug
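
The toggle must be present in the environment before the AutoTuner singleton is constructed, e.g. in a hypothetical launcher:

import os

# Surface the autotuner's debug messages at info level without raising
# the global log level to DEBUG.
os.environ["TLLM_AUTOTUNER_LOG_LEVEL_DEBUG_TO_INFO"] = "1"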

@classmethod
def get(cls):
if cls._instance is None:
@@ -726,11 +740,14 @@ def choose_one(
assert all([isinstance(r, TunableRunner) for r in runners]), \
"All Given runners must be subclass of TunableRunner"

tuning_start_time = time.perf_counter()
profiles = self._optimization_profiles(tuning_config, inputs)

# Record the total configs to try
self.stats.tuned_op_total_configs[custom_op] = len(profiles)

# Initialize the statistics for the custom_op
if custom_op not in self.stats.tuned_op_profiled_configs:
self.stats.tuned_op_profiled_configs[custom_op] = 0
if custom_op not in self.stats.failed_profiling_count:
self.stats.failed_profiling_count[custom_op] = set()
new_tuning_failure_occured = False

for p in profiles:
@@ -746,16 +763,15 @@
cache_key = self.profiling_cache.get_cache_key(
custom_op, runners[best_runner_id], p.get_opt_shapes(),
tuning_config)

self._debug_logger(
f"[Autotuner] Profiling runner={runners[best_runner_id]}, tactic={best_tactic} for cache_key={cache_key}."
)
# inspect call stack
self.profiling_cache[cache_key] = (best_runner_id,
best_tactic, min_time)

self.stats.tuned_op_successful_configs[
custom_op] = self.stats.tuned_op_successful_configs.get(
custom_op, 0) + 1
logger.debug(
f"[Autotuner] Profiling runner={runners[best_runner_id]}, tactic={best_tactic} for cache_key={cache_key}."
)
self.stats.tuned_op_profiled_configs[custom_op] += 1
else:
logger.warning_once(
f"[Autotuner] No valid runner/tactic was found for custom_op={custom_op}, input_shapes={input_shapes}. "
@@ -782,6 +798,10 @@ choose_one(
_, runner_id, tactic, _ = self.profiling_cache.search_cache(
custom_op, runners, input_shapes, tuning_config)

tuning_end_time = time.perf_counter()
self.stats.tuned_op_time_cost[
custom_op] = self.stats.tuned_op_time_cost.get(
custom_op, 0) + tuning_end_time - tuning_start_time
return (runners[runner_id], tactic)
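
Caller-side, the selection is typically consumed like this; a hedged sketch in which all names are placeholders and the argument order follows the fragments visible above:

runner, tactic = AutoTuner.get().choose_one(
    custom_op="my_custom_op",      # hypothetical op name
    runners=[my_runner],           # TunableRunner instances to compare
    tuning_config=tuning_config,
    inputs=sample_inputs,
)
output = runner(sample_inputs, tactic=tactic)  # hypothetical invocation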

def _profile_runners(
@@ -832,14 +852,13 @@ def _profile_runners(
f"[Autotuner] Failed when profiling runner={runner}, tactic={tac}, shapes={shapes}. Set TLLM_LOG_LEVEL=DEBUG for more details.",
key=(custom_op, "warning_autotuning_profile_failure"),
)
logger.debug_once(
f"[Autotuner] Exception captured: {e}",
key=(custom_op, "debug_autotuning_exception"),
)
(logger.info_once
if self._log_level_to_info else logger.debug_once)(
f"[Autotuner] Exception captured: {e}",
key=(custom_op, "debug_autotuning_exception"),
)

# Record the failed profiling combinations
if custom_op not in self.stats.failed_profiling_count:
self.stats.failed_profiling_count[custom_op] = set()
self.stats.failed_profiling_count[custom_op].add(
self.profiling_cache.get_cache_key(
custom_op, runner, profile.get_opt_shapes(),
@@ -957,7 +976,7 @@ def pure_profile(stream: torch.cuda.Stream, repeat: int):
avg_time = pure_profile(stream, self.repeat)

shapes = self._get_input_sizes(inputs)
logger.debug(
self._debug_logger(
f"[Autotuner] Profiled runner={runner}, tactic={tactic}, shapes={shapes}: {avg_time:.6f}ms."
)
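
The timing helper's body is not shown in this hunk; one common shape for such a helper is CUDA-event timing. A sketch follows, with the assumption that the real helper closes over the kernel launch, which is passed explicitly here as fn:

import torch

def pure_profile_sketch(fn, stream: torch.cuda.Stream, repeat: int) -> float:
    # Time `repeat` launches of `fn` with CUDA events and return the
    # average latency in milliseconds.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    with torch.cuda.stream(stream):
        start.record(stream)
        for _ in range(repeat):
            fn()
        end.record(stream)
    end.synchronize()
    return start.elapsed_time(end) / repeat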

@@ -1043,7 +1062,7 @@ def _optimization_profiles(
p.shapes[spec.input_idx][spec.dim_idx] = DynamicDim(
min_value, opt_value, max_value)
generated_profiles.append(p)
logger.debug(f"[Autotuner] Generated profile: {p}")
self._debug_logger(f"[Autotuner] Generated profile: {p}")
return generated_profiles
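
For illustration, a dynamic dimension might be declared as follows; the exact DynamicTensorSpec fields beyond input_idx and dim_idx are assumptions:

# Hypothetical: let dim 0 of input 0 vary, profiling at a few bucket sizes.
spec = DynamicTensorSpec(
    input_idx=0,
    dim_idx=0,
    gen_tuning_buckets=(1, 16, 256, 4096),
    map_to_tuning_buckets=lambda x: x,
)
tuning_config = TuningConfig(dynamic_tensor_specs=(spec,))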

@classmethod
@@ -1159,7 +1178,7 @@ def _prepare_input_tensors_with_batches(
input.element_size() if isinstance(input, torch.Tensor) else 0
for input in inputs)
if one_buffer_bytes <= 0:
logger.debug(
self._debug_logger(
"[Autotuner] No tensor inputs or zero-sized tensors; falling back to single-batch profiling."
)
return [inputs]
@@ -1174,7 +1193,7 @@
list(t.clone() if isinstance(t, torch.Tensor) else t
for t in inputs))

logger.debug(
self._debug_logger(
f"[Autotuner] use_cold_l2_cache={tuning_config.use_cold_l2_cache}, use {num_buffers} different tensors for profiling"
)
return inputs_list
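
Reduced to a standalone sketch, the cold-L2 rotation looks like this (the L2 budget constant and the kernel are stand-ins):

import math
import torch

tensors = [torch.randn(4096, 4096, device="cuda")]  # example profiling inputs
L2_BYTES = 64 * 1024 * 1024  # assumed budget; the real size is device-specific
one_buffer_bytes = sum(t.numel() * t.element_size() for t in tensors)
num_buffers = max(1, math.ceil(L2_BYTES / one_buffer_bytes))
# Rotate through cloned input sets so consecutive iterations touch different
# memory, keeping L2 cold between launches.
buffers = [[t.clone() for t in tensors] for _ in range(num_buffers)]
for i in range(10):
    torch.relu(buffers[i % num_buffers][0])  # stand-in for the tuned op
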
@@ -1188,16 +1207,23 @@ def reset_statistics(self) -> None:
self.stats = AutoTunerStatistics()

def print_profiling_cache(self):
logger.debug(f"[Autotuner] The profiling_cache entries:")
logger.debug(
self._debug_logger(f"[Autotuner] The profiling_cache entries:")
self._debug_logger(
f"[Autotuner] Cache contents: (custom_op, runner, hash(attributes), shape_profiles) -> (runner_id, tactic, shape_profile(ignored))"
)
for key, value in self.profiling_cache.cache.items():
runner_id, tactic, min_time = value
logger.debug(
self._debug_logger(
f"[Autotuner] {key}: (runner_id={runner_id}, tactic={tactic}, min_time={min_time})"
)

self.print_statistics()

def print_statistics(self):
self._debug_logger(f"[Autotuner] The statistics:")
for line in self.stats.__str__().split("\n"):
self._debug_logger(line)

@contextlib.contextmanager
def capture(self):
"""Context manager for capturing execution contexts for testing.
@@ -1271,7 +1297,7 @@ def replay(self, *config: Tuple[Tuple[TunableRunner, int], ...]):
runner_idx = runners.index(runner)
runner_tactic_list.append((runner_idx, tactic))

logger.debug(
self._debug_logger(
f"[Autotuner][replay]: Testing configuration: {runner_tactic_list}")

# Replay the contexts with given (runner, tactic) pairs
9 changes: 2 additions & 7 deletions tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
@@ -55,13 +55,8 @@ def __init__(self, alpha: float, output_dtype: torch.dtype):
)

# Override unique_id because the value of self.alpha doesn't affect the chosen tactic.
def __hash__(self):
return hash((self.output_dtype, ))

def __eq__(self, other):
if not isinstance(other, self.__class__):
return False
return self.output_dtype == other.output_dtype
def unique_id(self):
return (self.output_dtype, )
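
The effect, sketched (the runner's class name is hypothetical):

import torch

# Two runners that differ only in alpha map to the same cache key,
# since unique_id() deliberately omits alpha.
r1 = CuteDslGemmRunner(alpha=1.0, output_dtype=torch.float16)
r2 = CuteDslGemmRunner(alpha=0.5, output_dtype=torch.float16)
assert r1.unique_id() == r2.unique_id()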

def get_valid_tactics(
self,