58 changes: 41 additions & 17 deletions tensorrt_llm/_torch/autotuner.py
@@ -99,6 +99,7 @@ class TuningConfig:
constraint_specs: Tuple[ConstraintSpec, ...] = ()
tune_max_num_tokens: int = None
inputs_pre_hook: Callable = None
use_cuda_graph: bool = False


@dataclass(unsafe_hash=True)
@@ -522,6 +523,7 @@ class AutoTuner:
repeat (int): Number of profiling iterations for averaging (default: 10)
stream_delay_micro_secs (int): Delay on CUDA stream before the profiled kernel runs in microseconds (default: 1000)
"""
_CUDA_GRAPH_DELAY_MICRO_SECS = 100
_instance = None

def __init__(self, warmup=3, repeat=10, stream_delay_micro_secs=1000):
@@ -534,8 +536,6 @@ def __init__(self, warmup=3, repeat=10, stream_delay_micro_secs=1000):
# Add statistics tracking
self.stats = AutoTunerStatistics()

self.profiling_debug = True

# Current captured choose_one() contexts
self._active_capture: Optional['AutoTuner.TacticsCapture'] = None
# Last captured choose_one() contexts
@@ -727,10 +727,10 @@ def choose_one(
new_tuning_failure_occured = False

for p in profiles:
tensors = self._prepare_input_tensors(p, inputs)
is_cache_hit, *_ = self.profiling_cache.search_cache(
custom_op, runners, p.get_opt_shapes(), tuning_config)
if not is_cache_hit:
tensors = self._prepare_input_tensors(p, inputs)
# Initialize runner and tactic as None in case no valid tactic or runner is found
best_runner_id, best_tactic, min_time, has_tuning_failure_occured = self._profile_runners(
custom_op, runners, tensors, p, tuning_config, **kwargs)
@@ -811,7 +811,12 @@ def _profile_runners(
for tac in valid_tactics:
try:
time_measured = self._profile_single_kernel(
runner, input_tensors, tac, **kwargs)
runner=runner,
inputs=input_tensors,
tactic=tac,
use_cuda_graph=tuning_config.use_cuda_graph,
**kwargs,
)
except Exception as e:
# Handle None tensors for optional inputs
shapes = self._get_input_sizes(input_tensors)
@@ -857,6 +862,7 @@ def _profile_single_kernel(
runner: TunableRunner,
inputs: List[torch.Tensor],
tactic: Any,
use_cuda_graph: bool = False,
**kwargs,
) -> float:
"""Profile a single kernel implementation for performance measurement.
@@ -875,22 +881,40 @@ def _profile_single_kernel(
are used to ensure accurate timing.
"""
stream = torch.cuda.current_stream()
# warm up, no timing
for _ in range(self.warmup):
runner(inputs, tactic=tactic, **kwargs)
stream.synchronize()

# Delay the profiled kernel launch to eliminate the effect of host-side overhead on profiling.
# TODO: This is build-time sensitive, O(tactic_num * impl_num * num_profile * tunable_ops)
# Consider applying a pre-profiling pass to estimate the kernel execution time, then decide whether the delay is necessary.
delay_kernel(self.stream_delay_micro_secs, stream)
graph = torch.cuda.CUDAGraph()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record(stream=stream)
for _ in range(self.repeat):
runner(inputs, tactic=tactic, **kwargs)
end.record(stream=stream)
with torch.cuda.stream(stream):
# warm up, no timing
for _ in range(self.warmup):
runner(inputs, tactic=tactic, **kwargs)

if use_cuda_graph:
with torch.cuda.graph(graph):
for _ in range(self.repeat):
runner(inputs, tactic=tactic, **kwargs)

stream.synchronize()

# Delay the profiled kernel launch to eliminate the effect of host-side overhead on profiling.
# TODO: This is build-time sensitive, O(tactic_num * impl_num * num_profile * tunable_ops)
# Consider applying a pre-profiling pass to estimate the kernel execution time, then decide whether the delay is necessary.
if use_cuda_graph:
delay_kernel(self._CUDA_GRAPH_DELAY_MICRO_SECS, stream)
else:
delay_kernel(self.stream_delay_micro_secs, stream)

start.record()

if use_cuda_graph:
graph.replay()
else:
for _ in range(self.repeat):
runner(inputs, tactic=tactic, **kwargs)

end.record()

stream.synchronize()

avg_time = start.elapsed_time(end) / self.repeat
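For reference, a minimal self-contained sketch (not the AutoTuner code itself) of the timing scheme the new use_cuda_graph branch implements: warm up, capture the repeated launches into a single torch.cuda.CUDAGraph, then time one graph.replay() between CUDA events, as _profile_single_kernel now does when use_cuda_graph=True. The helper name and the warmup/repeat/matmul sizes below are illustrative assumptions, not values from this PR.

```python
import torch


def time_with_cuda_graph(fn, warmup: int = 3, repeat: int = 10) -> float:
    """Average time per launch of fn() in ms, measured via CUDA graph replay."""
    stream = torch.cuda.current_stream()
    graph = torch.cuda.CUDAGraph()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    with torch.cuda.stream(stream):
        # Warm up outside the capture so lazy initialization is not recorded.
        for _ in range(warmup):
            fn()
        # Record `repeat` launches into a single graph.
        with torch.cuda.graph(graph):
            for _ in range(repeat):
                fn()
    stream.synchronize()

    start.record()
    graph.replay()  # one replay re-issues all captured launches
    end.record()
    stream.synchronize()
    return start.elapsed_time(end) / repeat


if __name__ == "__main__":
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")
    print(time_with_cuda_graph(lambda: a @ b))
```

Because a single replay re-issues all captured kernels with negligible host work, the graph path can also use the much shorter spacing delay, which is presumably why _CUDA_GRAPH_DELAY_MICRO_SECS was introduced.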
1 change: 1 addition & 0 deletions tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
@@ -36,6 +36,7 @@ class CuteDSLNVFP4BlackwellLinear(TunableRunner):
0, 0, get_last_power_of_2_num_tokens_buckets,
last_positive_power_of_2), ),
constraint_specs=(ConstraintSpec(2, 0, fp4_scale_infer_shape), ),
use_cuda_graph=True,
)

def __init__(self, alpha: float, output_dtype: torch.dtype):
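Since use_cuda_graph defaults to False, all other ops keep the existing eager profiling loop; only runners that set the flag, such as the CuteDSL NVFP4 linear above, take the capture-and-replay path. A hypothetical opt-in for some other graph-safe runner could look like the snippet below; the import path follows the files touched in this PR, while the variable name and the omitted spec fields are placeholders.

```python
from tensorrt_llm._torch.autotuner import TuningConfig

# Hypothetical config for a CUDA-graph-safe op: every field except
# use_cuda_graph keeps its default; a real op would also pass its
# dynamic-shape and constraint specs, as the CuteDSL runner above does.
graph_profiled_config = TuningConfig(use_cuda_graph=True)
```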
6 changes: 3 additions & 3 deletions tests/unittest/_torch/misc/test_autotuner.py
@@ -57,19 +57,19 @@ def test_multi_dynamic_dims():
# add sleep to simulate bad perf
def gemm_0(x, w):
if x.shape[0] > M // 2:
delay_kernel(10000, torch.cuda.current_stream())
delay_kernel(100, torch.cuda.current_stream())
return x @ w


def gemm_1(x, w):
if x.shape[0] <= M // 2:
delay_kernel(10000, torch.cuda.current_stream())
delay_kernel(100, torch.cuda.current_stream())
return x @ w


def gemm_fallback(x, w) -> torch.Tensor:
# always the slowest
delay_kernel(100000, torch.cuda.current_stream())
delay_kernel(500, torch.cuda.current_stream())
return x @ w


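A quick back-of-the-envelope check on the shortened test delays (assuming delay_kernel takes microseconds, as the autotuner's stream_delay_micro_secs usage suggests, and the default warmup=3, repeat=10): the relative ordering is unchanged, since the fallback still delays 5x longer than the other tactics, while the simulated work per profiled tactic drops by roughly two orders of magnitude for shapes that hit the delay branch.

```python
# Simulated delay per profiled tactic: (warmup + repeat) launches hit the delay.
warmup, repeat = 3, 10
for name, old_us, new_us in [("gemm_0 / gemm_1", 10_000, 100),
                             ("gemm_fallback", 100_000, 500)]:
    old_ms = (warmup + repeat) * old_us / 1000
    new_ms = (warmup + repeat) * new_us / 1000
    print(f"{name}: {old_ms:.0f} ms -> {new_ms:.1f} ms per profiled tactic")
```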