|  | 
| 1 |  | -import ctypes | 
| 2 |  | -import getpass | 
| 3 | 1 | import logging | 
| 4 | 2 | import os | 
| 5 |  | -import platform | 
| 6 |  | -import tempfile | 
| 7 |  | -import urllib.request | 
| 8 |  | -from pathlib import Path | 
| 9 |  | -from typing import Optional | 
| 10 | 3 | 
 | 
| 11 | 4 | import torch | 
| 12 | 5 | from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh | 
| 13 |  | -from torch_tensorrt._version import __tensorrt_llm_version__ | 
| 14 |  | - | 
| 15 |  | -_WHL_CPYTHON_VERSION = "cp310" | 
| 16 | 6 | 
 | 
| 17 | 7 | logger = logging.getLogger(__name__) | 
| 18 | 8 | 
 | 
| @@ -42,268 +32,10 @@ def get_tensor_parallel_device_mesh( | 
| 42 | 32 |     return device_mesh, world_size, rank | 
| 43 | 33 | 
 | 
| 44 | 34 | 
 | 
def initialize_distributed_logger(rank: int, logger_file_name: str) -> logging.Logger:
    """Configure the root logger to write INFO records to a per-rank log file.

    Args:
        rank: Distributed rank; appended to the log file name so each rank
            writes to its own file.
        logger_file_name: Base path of the log file; ``_<rank>.log`` is appended.

    Returns:
        logging.Logger: The configured root logger.
    """
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    file_handler = logging.FileHandler(f"{logger_file_name}_{rank}.log", mode="w")
    file_handler.setLevel(logging.INFO)
    root_logger.addHandler(file_handler)
    return root_logger
| 52 |  | - | 
| 53 |  | - | 
def is_platform_supported_for_trtllm() -> bool:
    """
    Checks if the current platform supports TensorRT-LLM plugins for the NCCL backend.

    Returns:
        bool: True if supported, False otherwise.

    Unsupported:
        - Windows platforms
        - Jetson/Orin/Xavier (aarch64 architecture + 'tegra' in platform release)
        - Any CUDA major version other than 12 (e.g. CUDA 13)
    """
    system = platform.system().lower()
    machine = platform.machine().lower()
    release = platform.release().lower()

    if "windows" in system:
        logger.info(
            "TensorRT-LLM plugins for NCCL backend are not supported on Windows."
        )
        return False

    if machine == "aarch64" and "tegra" in release:
        logger.info(
            "TensorRT-LLM plugins for NCCL backend are not supported on Jetson/Orin/Xavier (Tegra) devices."
        )
        return False

    try:
        cuda_version = torch.version.cuda  # e.g., "12.4" or "13.0"; None on CPU-only builds
        if cuda_version is None:
            logger.warning("No CUDA runtime detected — TRT-LLM plugins unavailable.")
            return False

        major, _minor = map(int, cuda_version.split("."))
        if major != 12:
            # Only CUDA 12.x builds ship compatible TRT-LLM plugin binaries.
            logger.warning(
                f"CUDA {cuda_version} is not supported for TRT-LLM plugins; CUDA 12.x is required."
            )
            return False

        return True

    except Exception as e:
        logger.warning(f"Failed to detect CUDA version: {e}")
        return False
| 100 |  | - | 
| 101 |  | - | 
| 102 |  | -def _cache_root() -> Path: | 
| 103 |  | -    username = getpass.getuser() | 
| 104 |  | -    return Path(tempfile.gettempdir()) / f"torch_tensorrt_{username}" | 
| 105 |  | - | 
| 106 |  | - | 
def _extracted_dir_trtllm(platform_system: str, platform_machine: str) -> Path:
    """Return the cache directory holding the extracted TRT-LLM wheel for this platform/version."""
    version_tag = f"{__tensorrt_llm_version__}_{platform_system}_{platform_machine}"
    return _cache_root() / "trtllm" / version_tag
| 113 |  | - | 
| 114 |  | - | 
def extract_wheel_file(wheel_path: Path, extract_dir: Path) -> None:
    """
    Extract a wheel (zip) archive into ``extract_dir``.

    In a distributed run only rank 0 performs the extraction; all other ranks
    wait at a barrier until it has finished, so no rank ever reads a
    half-extracted directory.

    Args:
        wheel_path: Path to the downloaded ``.whl`` file.
        extract_dir: Directory to extract the archive into.

    Raises:
        RuntimeError: If the wheel is missing, corrupted, or extraction fails.
    """
    # zipfile is part of the standard library; no availability guard needed.
    import zipfile

    from torch.distributed import barrier, get_rank, is_initialized

    if not is_initialized():
        # Single process case, just unzip
        is_master = True
    else:
        is_master = get_rank() == 0  # only rank 0 does the unzip

    if is_master:
        try:
            with zipfile.ZipFile(wheel_path) as zip_ref:
                zip_ref.extractall(extract_dir)
                logger.debug(f"Extracted wheel to {extract_dir}")

        except FileNotFoundError as e:
            # This should capture the errors in the download failure above
            logger.error(f"Wheel file not found at {wheel_path}: {e}")
            raise RuntimeError(
                f"Failed to find downloaded wheel file at {wheel_path}"
            ) from e
        except zipfile.BadZipFile as e:
            logger.error(f"Invalid or corrupted wheel file: {e}")
            raise RuntimeError(
                "Downloaded wheel file is corrupted or not a valid zip archive"
            ) from e
        except Exception as e:
            logger.error(f"Unexpected error while extracting wheel: {e}")
            raise RuntimeError(
                "Unexpected error during extraction of TensorRT-LLM wheel"
            ) from e

    # Make sure others wait until unzip is done
    if is_initialized():
        barrier()
| 156 |  | - | 
| 157 |  | - | 
def download_and_get_plugin_lib_path() -> Optional[str]:
    """
    Return the path to the TensorRT-LLM shared library, downloading and
    extracting the TRT-LLM wheel into the per-user cache first if necessary.

    Returns:
        Optional[str]: Path to the shared library, or None if it could not be
        obtained (download failure, library missing from the wheel, ...).
    """
    # Explicit import: HTTPError/URLError live in urllib.error, not
    # urllib.request; relying on urllib.request's internal import is fragile.
    import urllib.error

    platform_system = platform.system().lower()
    platform_machine = platform.machine().lower()
    wheel_filename = (
        f"tensorrt_llm-{__tensorrt_llm_version__}-{_WHL_CPYTHON_VERSION}-"
        f"{_WHL_CPYTHON_VERSION}-{platform_system}_{platform_machine}.whl"
    )
    wheel_path = _cache_root() / wheel_filename
    extract_dir = _extracted_dir_trtllm(platform_system, platform_machine)
    # Windows is rejected earlier by is_platform_supported_for_trtllm(), so the
    # .dll branch is effectively unreachable; kept for completeness.
    lib_filename = (
        "libnvinfer_plugin_tensorrt_llm.so"
        if "linux" in platform_system
        else "libnvinfer_plugin_tensorrt_llm.dll"
    )
    # eg: /tmp/torch_tensorrt_<username>/trtllm/0.17.0.post1_linux_x86_64/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so
    plugin_lib_path = extract_dir / "tensorrt_llm" / "libs" / lib_filename

    # Fast path: a previous run already downloaded and extracted the wheel.
    if plugin_lib_path.exists():
        return str(plugin_lib_path)

    wheel_path.parent.mkdir(parents=True, exist_ok=True)
    extract_dir.mkdir(parents=True, exist_ok=True)

    if not wheel_path.exists():
        base_url = "https://pypi.nvidia.com/tensorrt-llm/"
        download_url = base_url + wheel_filename
        try:
            logger.debug(f"Downloading {download_url} ...")
            urllib.request.urlretrieve(download_url, wheel_path)
            logger.debug("Download succeeded and TRT-LLM wheel is now present")
        except urllib.error.HTTPError as e:
            # HTTPError must be caught before URLError (it is a subclass).
            logger.error(
                f"HTTP error {e.code} when trying to download {download_url}: {e.reason}"
            )
        except urllib.error.URLError as e:
            logger.error(
                f"URL error when trying to download {download_url}: {e.reason}"
            )
        except OSError as e:
            logger.error(f"Local file write error: {e}")

    extract_wheel_file(wheel_path, extract_dir)

    # The wheel is only needed for extraction; free the disk space afterwards.
    try:
        wheel_path.unlink(missing_ok=True)
        logger.debug(f"Deleted wheel file: {wheel_path}")
    except Exception as e:
        logger.warning(f"Could not delete wheel file {wheel_path}: {e}")

    if not plugin_lib_path.exists():
        logger.error(
            f"Plugin library not found at expected location: {plugin_lib_path}"
        )
        return None

    return str(plugin_lib_path)
| 223 |  | - | 
| 224 |  | - | 
def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool:
    """
    Loads and initializes the TensorRT-LLM plugin from the given shared library path.

    Args:
        plugin_lib_path (str): Path to the shared TensorRT-LLM plugin library.

    Returns:
        bool: True if successful, False otherwise.
    """
    # Step 1: load the shared library itself.
    try:
        handle = ctypes.CDLL(plugin_lib_path)
        logger.info(f"Successfully loaded plugin library: {plugin_lib_path}")
    except OSError as e_os_error:
        if "libmpi" in str(e_os_error):
            logger.warning(
                f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}, got error {e_os_error} (hint: libmpi.so is a necessary dependency; ensure that OpenMPI or MPICH is installed on your system)",
                exc_info=e_os_error,
            )
        else:
            logger.warning(
                f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}. "
                f"Ensure the path is correct and the library is compatible.",
                exc_info=e_os_error,
            )
        return False

    # Step 2: declare the C signature of the init entry point before calling it.
    try:
        handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
        handle.initTrtLlmPlugins.restype = ctypes.c_bool
    except AttributeError as e_plugin_unavailable:
        logger.warning(
            "Unable to initialize the TensorRT-LLM plugin library",
            exc_info=e_plugin_unavailable,
        )
        return False

    # Step 3: run the plugin's own initialization (registers its TRT plugins).
    try:
        if handle.initTrtLlmPlugins(None, b"tensorrt_llm"):
            logger.info("TensorRT-LLM plugin successfully initialized")
            return True
        else:
            logger.warning("TensorRT-LLM plugin library failed in initialization")
            return False
    except Exception as e_initialization_error:
        logger.warning(
            "Exception occurred during TensorRT-LLM plugin library initialization",
            exc_info=e_initialization_error,
        )
        return False
| 276 |  | - | 
| 277 |  | - | 
def load_tensorrt_llm_for_nccl() -> bool:
    """
    Attempts to load the TensorRT-LLM plugin and initialize it.

    The plugin location is resolved in two ways:
      - the env variable TRTLLM_PLUGINS_PATH can specify the path directly, or
      - the user can set USE_TRTLLM_PLUGINS to one of (1, true, yes, on) to
        download the TRT-LLM distribution and load the library from it.

    Returns:
        bool: True if the plugin was successfully loaded and initialized, False otherwise.
    """
    if not is_platform_supported_for_trtllm():
        return False

    plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
    if plugin_lib_path:
        return load_and_initialize_trtllm_plugin(plugin_lib_path)

    # this option can be used by user if TRTLLM_PLUGINS_PATH is not set by user
    use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in (
        "1",
        "true",
        "yes",
        "on",
    )
    if not use_trtllm_plugin:
        logger.warning(
            "Neither TRTLLM_PLUGINS_PATH is set nor is it directed to download the shared library. Please set either of the two to use TRT-LLM libraries in torchTRT"
        )
        return False

    plugin_lib_path = download_and_get_plugin_lib_path()
    return load_and_initialize_trtllm_plugin(plugin_lib_path)  # type: ignore[arg-type]
0 commit comments