diff --git a/tensorrt_llm/_torch/autotuner.py b/tensorrt_llm/_torch/autotuner.py
index 01f83b08096..0715c6d5f6c 100644
--- a/tensorrt_llm/_torch/autotuner.py
+++ b/tensorrt_llm/_torch/autotuner.py
@@ -99,6 +99,7 @@ class TuningConfig:
     constraint_specs: Tuple[ConstraintSpec, ...] = ()
     tune_max_num_tokens: int = None
     inputs_pre_hook: Callable = None
+    use_cuda_graph: bool = False  # profile tactics by capturing/replaying them in a CUDA graph
 
 
 @dataclass(unsafe_hash=True)
@@ -522,6 +523,7 @@ class AutoTuner:
         repeat (int): Number of profiling iterations for averaging (default: 10)
         stream_delay_micro_secs (int): Delay on CUDA stream before the profiled kernel runs in microseconds (default: 1000)
     """
+    _CUDA_GRAPH_DELAY_MICRO_SECS = 100  # stream delay (microseconds) used when profiling via CUDA graph replay
    _instance = None
 
    def __init__(self, warmup=3, repeat=10, stream_delay_micro_secs=1000):
@@ -534,8 +536,6 @@ def __init__(self, warmup=3, repeat=10, stream_delay_micro_secs=1000):
         # Add statistics tracking
         self.stats = AutoTunerStatistics()
 
-        self.profiling_debug = True
-
         # Current captured choose_one() contexts
         self._active_capture: Optional['AutoTuner.TacticsCapture'] = None
         # Last captured choose_one() contexts
@@ -727,10 +727,10 @@ def choose_one(
 
         new_tuning_failure_occured = False
         for p in profiles:
+            tensors = self._prepare_input_tensors(p, inputs)
             is_cache_hit, *_ = self.profiling_cache.search_cache(
                 custom_op, runners, p.get_opt_shapes(), tuning_config)
             if not is_cache_hit:
-                tensors = self._prepare_input_tensors(p, inputs)
                 # Initialize runner and tactic as None in case of no valid tactic or runners are found
                 best_runner_id, best_tactic, min_time, has_tuning_failure_occured = self._profile_runners(
                     custom_op, runners, tensors, p, tuning_config, **kwargs)
@@ -811,7 +811,12 @@ def _profile_runners(
             for tac in valid_tactics:
                 try:
                     time_measured = self._profile_single_kernel(
-                        runner, input_tensors, tac, **kwargs)
+                        runner=runner,
+                        inputs=input_tensors,
+                        tactic=tac,
+                        use_cuda_graph=tuning_config.use_cuda_graph,
+                        **kwargs,
+                    )
                 except Exception as e:
                     # Handle None tensors for optional inputs
                     shapes = self._get_input_sizes(input_tensors)
@@ -857,6 +862,7 @@ def _profile_single_kernel(
         runner: TunableRunner,
         inputs: List[torch.Tensor],
         tactic: Any,
+        use_cuda_graph: bool = False,
         **kwargs,
     ) -> float:
         """Profile a single kernel implementation for performance measurement.
@@ -875,22 +881,40 @@
             are used to ensure accurate timing.
         """
         stream = torch.cuda.current_stream()
-        # warm up, no timing
-        for _ in range(self.warmup):
-            runner(inputs, tactic=tactic, **kwargs)
-        stream.synchronize()
-
-        # Delay the profiled kernel launch to eliminate affects of host time overhead in profiling.
-        # TODO: This is build time sensitive, O(tactic_num * impl_num * num_profile * tunable_ops)
-        # Consider apply a preprofiling to estimate the kernel execution time, then decide the necessity.
-        delay_kernel(self.stream_delay_micro_secs, stream)
+        graph = torch.cuda.CUDAGraph()  # captured and replayed only when use_cuda_graph is True
 
         start = torch.cuda.Event(enable_timing=True)
         end = torch.cuda.Event(enable_timing=True)
-        start.record(stream=stream)
-        for _ in range(self.repeat):
-            runner(inputs, tactic=tactic, **kwargs)
-        end.record(stream=stream)
+        with torch.cuda.stream(stream):
+            # warm up, no timing
+            for _ in range(self.warmup):
+                runner(inputs, tactic=tactic, **kwargs)
+
+            if use_cuda_graph:
+                with torch.cuda.graph(graph):
+                    for _ in range(self.repeat):
+                        runner(inputs, tactic=tactic, **kwargs)
+
+            stream.synchronize()
+
+            # Delay the profiled kernel launch to eliminate the effect of host-side launch overhead on the measurement.
+            # TODO: This is build-time sensitive, O(tactic_num * impl_num * num_profile * tunable_ops).
+            # Consider applying a pre-profiling pass to estimate the kernel execution time, then decide whether the delay is necessary.
+            if use_cuda_graph:
+                delay_kernel(self._CUDA_GRAPH_DELAY_MICRO_SECS, stream)
+            else:
+                delay_kernel(self.stream_delay_micro_secs, stream)
+
+            start.record()
+
+            if use_cuda_graph:
+                graph.replay()
+            else:
+                for _ in range(self.repeat):
+                    runner(inputs, tactic=tactic, **kwargs)
+
+            end.record()
+
         stream.synchronize()
         avg_time = start.elapsed_time(end) / self.repeat
diff --git a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
index 0ec96073256..e67a1512c11 100644
--- a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
+++ b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
@@ -36,6 +36,7 @@ class CuteDSLNVFP4BlackwellLinear(TunableRunner):
             0, 0, get_last_power_of_2_num_tokens_buckets,
             last_positive_power_of_2), ),
         constraint_specs=(ConstraintSpec(2, 0, fp4_scale_infer_shape), ),
+        use_cuda_graph=True,
     )
 
     def __init__(self, alpha: float, output_dtype: torch.dtype):
diff --git a/tests/unittest/_torch/misc/test_autotuner.py b/tests/unittest/_torch/misc/test_autotuner.py
index ab11ee0df24..81e62fdf137 100644
--- a/tests/unittest/_torch/misc/test_autotuner.py
+++ b/tests/unittest/_torch/misc/test_autotuner.py
@@ -57,19 +57,19 @@ def test_multi_dynamic_dims():
     # add sleep to simulate bad perf
     def gemm_0(x, w):
         if x.shape[0] > M // 2:
-            delay_kernel(10000, torch.cuda.current_stream())
+            delay_kernel(100, torch.cuda.current_stream())
         return x @ w
 
     def gemm_1(x, w):
         if x.shape[0] <= M // 2:
-            delay_kernel(10000, torch.cuda.current_stream())
+            delay_kernel(100, torch.cuda.current_stream())
         return x @ w
 
     def gemm_fallback(x, w) -> torch.Tensor:
         # always the slowest
-        delay_kernel(100000, torch.cuda.current_stream())
+        delay_kernel(500, torch.cuda.current_stream())
         return x @ w
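Below is a minimal standalone sketch of the CUDA-graph profiling pattern this patch wires into AutoTuner._profile_single_kernel, using only stock PyTorch on a CUDA device; profile_kernel and fn are illustrative names, not TensorRT-LLM API. It shows why the graph path can use the much shorter _CUDA_GRAPH_DELAY_MICRO_SECS: the repeat loop is captured once and graph.replay() re-issues all of its launches with near-zero host overhead, so there is little host-side lag left to hide before timing starts.

    import torch

    def profile_kernel(fn, warmup=3, repeat=10, use_cuda_graph=False):
        # Illustrative helper, not TensorRT-LLM API: returns the average
        # time of fn() in milliseconds, optionally via CUDA graph replay.
        stream = torch.cuda.current_stream()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.stream(stream):
            for _ in range(warmup):
                fn()  # eager warm-up; also primes allocator/workspace state
            if use_cuda_graph:
                # Capture the whole repeat loop once; replay() re-issues
                # every launch without per-launch Python/driver overhead.
                with torch.cuda.graph(graph):
                    for _ in range(repeat):
                        fn()
            stream.synchronize()
            start.record()
            if use_cuda_graph:
                graph.replay()
            else:
                for _ in range(repeat):
                    fn()
            end.record()
        stream.synchronize()
        return start.elapsed_time(end) / repeat

    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")
    print(profile_kernel(lambda: a @ b, use_cuda_graph=True))

One caveat of the graph path: replay() reuses the device memory addresses recorded at capture time, so the profiled inputs must stay at fixed addresses between capture and replay; the eager path has no such restriction.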