90 changes: 85 additions & 5 deletions tensorrt_llm/_torch/autotuner.py
@@ -8,6 +8,7 @@
from typing import Any, Callable, Dict, List, Set, Tuple, Union

import torch
from cuda.bindings import driver

from tensorrt_llm.bindings.internal.runtime import delay_kernel
from tensorrt_llm.logger import logger
@@ -86,10 +87,12 @@ class TuningConfig:
any value is provided to the choose_one function, the input tensor will be saturated
with the provided value.
If not provided, the autotuner will not consider the max num tokens.
use_cold_l2_cache (bool): If True, rotate through multiple cloned sets of input tensors during profiling so that each timed run starts with a cold L2 cache.
"""
dynamic_tensor_specs: Tuple[DynamicTensorSpec, ...] = ()
constraint_specs: Tuple[ConstraintSpec, ...] = ()
tune_max_num_tokens: int = None
use_cold_l2_cache: bool = False
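Since the flag defaults to False, cold-L2 profiling is purely opt-in. A minimal sketch of enabling it (the empty specs tuple is a placeholder, not from this PR):

tuning_config = TuningConfig(
    dynamic_tensor_specs=(),  # placeholder; real ops supply their own specs
    use_cold_l2_cache=True,   # opt in to cold-L2 profiling (default: False)
)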


@dataclass(unsafe_hash=True)
@@ -466,7 +469,7 @@ def _profile_runners(
for tac in valid_tactics:
try:
time_measured = self._profile_single_kernel(
runner, input_tensors, tac, **kwargs)
runner, input_tensors, tac, tuning_config, **kwargs)
except Exception as e:
# Handle None tensors for optional inputs
shapes = self._get_input_sizes(input_tensors)
@@ -508,6 +511,7 @@ def _profile_single_kernel(
runner: TunableRunner,
inputs: List[torch.Tensor],
tactic: Any,
tuning_config: TuningConfig,
**kwargs,
) -> float:
"""Profile a single kernel implementation for performance measurement.
@@ -525,10 +529,17 @@
to get an average execution time. Stream synchronization and delays
are used to ensure accurate timing.
"""
input_tensor_batches = self._prepare_input_tensors_with_batches(
inputs, tuning_config)

stream = torch.cuda.current_stream()
# warm up, no timing
for _ in range(self.warmup):
runner(inputs, tactic=tactic, **kwargs)
runner(
input_tensor_batches[-1],
Collaborator: I don't understand this one, why [-1]?
Collaborator: I assume this is so that we don't warm up the cache for the first iteration?
Collaborator: Exactly. Always using the last buffer set is a simple way to guarantee that every profiling run starts with a cold cache.
tactic=tactic,
**kwargs,
)
stream.synchronize()

# Delay the profiled kernel launch to eliminate the effects of host-time overhead in profiling.
@@ -539,14 +550,18 @@
end = torch.cuda.Event(enable_timing=True)

start.record(stream=stream)
for _ in range(self.repeat):
runner(inputs, tactic=tactic, **kwargs)
for r in range(self.repeat):
runner(
input_tensor_batches[r % len(input_tensor_batches)],
tactic=tactic,
**kwargs,
)
end.record(stream=stream)
stream.synchronize()

avg_time = start.elapsed_time(end) / self.repeat

shapes = self._get_input_sizes(inputs)
shapes = self._get_input_sizes(input_tensor_batches[-1])
logger.debug(
f"[Autotuner] Profiled runner={runner}, tactic={tactic}, shapes={shapes}: {avg_time:.6f}ms."
)
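The measurement pattern above (warm up, synchronize, delay the launch, then time repeat iterations with CUDA events) can be read as a standalone helper. A sketch under the assumption that fn launches work on the current stream; the TRT-LLM-internal delay_kernel is omitted:

import torch

def time_kernel_ms(fn, warmup: int = 3, repeat: int = 10) -> float:
    """Average GPU time of fn() in milliseconds, measured with CUDA events."""
    stream = torch.cuda.current_stream()
    for _ in range(warmup):  # warm up; not timed
        fn()
    stream.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record(stream=stream)
    for _ in range(repeat):
        fn()
    end.record(stream=stream)
    stream.synchronize()  # wait for the timed work before reading the events
    return start.elapsed_time(end) / repeat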
@@ -733,6 +748,39 @@ def _prepare_input_tensors(
tensors.append(tensor)
return tensors

def _prepare_input_tensors_with_batches(
self,
inputs: List[torch.Tensor],
tuning_config: TuningConfig,
) -> List[List[torch.Tensor]]:
if not tuning_config.use_cold_l2_cache:
return [inputs]

one_buffer_bytes = sum(
input.numel() *
input.element_size() if isinstance(input, torch.Tensor) else 0
for input in inputs)
if one_buffer_bytes <= 0:
logger.info(
"[Autotuner] No tensor inputs or zero-sized tensors; falling back to single-batch profiling."
)
return [inputs]

num_buffers = self._get_l2_cache_size_in_bytes(
) * 3 // one_buffer_bytes + 1
Collaborator: This should be at least 2, otherwise we will use the same weights, which might still be in the cache.
num_buffers = min(num_buffers, self.repeat + 1)

inputs_list = [inputs]
for _ in range(num_buffers - 1):
Collaborator: I think this will increase GPU memory a lot, by num_buffers x; can we clear the L2 cache another way, without increasing memory usage much? This matters for TRT-LLM because we use a trace-based method to record peak memory and estimate KV-cache usage.
Collaborator (author): But use_cold_l2_cache is opt-in. If this config is not set, this MR changes nothing. Still curious why some tests failed.
Collaborator: I think a better approach may be an L2-clearing kernel that runs between iterations. You could initialize an auxiliary buffer to maybe 2-3x the cache size and launch a kernel that writes random values between iterations (I think it needs to be random).
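A minimal sketch of the L2-clearing alternative suggested above (illustrative only, not part of this PR; the 50 MiB size is an assumption, and a real implementation would query _get_l2_cache_size_in_bytes instead):

import torch

L2_BYTES = 50 * 1024 * 1024  # assumed L2 size; query the driver in real code
# ~3x the L2 capacity, expressed in float32 elements (4 bytes each).
_flush_buf = torch.empty(3 * L2_BYTES // 4, dtype=torch.float32, device="cuda")

def flush_l2() -> None:
    # Overwriting a buffer larger than L2 between timed calls evicts whatever
    # the previous kernel left resident, approximating a cold cache without
    # cloning the inputs num_buffers times.
    _flush_buf.uniform_()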

inputs_list.append(
list(t.clone() if isinstance(t, torch.Tensor) else t
for t in inputs))
Collaborator: This is not going to work for MoE: for MoE we have pointers referring to the weights in the workspace, so we would need to call with do_preparation=True separately for every buffer. That would also require changes to the preparation logic to support multiple internal workspaces. For this reason I think it might be better to go for an L2-clearing approach, which also reduces memory (see the other comment).

logger.info(
f"[Autotuner] use_cold_l2_cache={tuning_config.use_cold_l2_cache}: rotating {num_buffers} input buffer sets during profiling"
)
return inputs_list
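To make the buffer-count arithmetic concrete, a worked example with assumed figures (none of these numbers come from the PR):

l2_bytes = 50 * 1024 * 1024          # ~50 MiB L2, assumed
one_buffer_bytes = 16 * 1024 * 1024  # one full set of input tensors, assumed
repeat = 100
num_buffers = l2_bytes * 3 // one_buffer_bytes + 1  # 150 // 16 + 1 = 10
num_buffers = min(num_buffers, repeat + 1)          # still 10; capped so no
                                                    # buffer set goes unused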

def clear_cache(self) -> None:
"""Clear the profiling cache."""
self.profiling_cache.clear()
@@ -750,3 +798,35 @@ def print_profiling_cache(self):
runner_id, tactic, profile = value
logger.debug(
f"[Autotuner] {key}: (runner_id={runner_id}, tactic={tactic})")

def _get_l2_cache_size_in_bytes(self, device_id: int = 0) -> int:
device = self._checkCudaErrors(driver.cuDeviceGet(device_id))
return self._checkCudaErrors(
driver.cuDeviceGetAttribute(
driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE,
device,
))

def _checkCudaErrors(self, result) -> Any:
# CUDA driver APIs return the status as the first element of the result tuple.
status = result[0]
if status != driver.CUresult.CUDA_SUCCESS:
code = getattr(status, "value", status)
raise RuntimeError(
f"CUDA error code={code} ({self._cudaGetErrorEnum(status)})")
if len(result) == 1:
return None
elif len(result) == 2:
return result[1]
else:
return result[1:]

def _cudaGetErrorEnum(self, error) -> str:
from cuda.bindings import nvrtc
if isinstance(error, driver.CUresult):
err, name = driver.cuGetErrorName(error)
return name if err == driver.CUresult.CUDA_SUCCESS else "<unknown>"
elif isinstance(error, nvrtc.nvrtcResult):
return nvrtc.nvrtcGetErrorString(error)[1]
else:
raise RuntimeError("Unknown error type: {}".format(error))