Commit c7bf852
committed: changes to account for the base branch change
1 parent: 56b80db

File tree: 7 files changed, +11 -275 lines changed

examples/distributed_inference/tensor_parallel_initialize_dist.py

Lines changed: 1 addition & 0 deletions

@@ -17,6 +17,7 @@
 from torch.distributed._tensor.device_mesh import init_device_mesh


+# this is kept at the application level, when mpirun is used to run the application
 def initialize_distributed_env(rank=0, world_size=1, port=29500):
     local_rank = int(
         os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())

examples/distributed_inference/tensor_parallel_rotary_embedding.py

Lines changed: 3 additions & 2 deletions

@@ -26,17 +26,18 @@
 import torch_tensorrt
 from torch_tensorrt.dynamo.distributed.utils import (
     get_tensor_parallel_device_mesh,
-    initialize_logger,
+    initialize_distributed_logger,
 )

 device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh()
-logger = initialize_logger(_rank, "tensor_parallel_rotary_embedding")
+logger = initialize_distributed_logger(_rank, "tensor_parallel_rotary_embedding")

 from rotary_embedding import RotaryAttention, parallel_rotary_block

 """
 This example covers the rotary embedding in Llama3 model and is derived from https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
 Command to run with single GPU: mpirun -n 1 --allow-run-as-root python tensor_parallel_rotary_embedding.py
+Command to run with 2 GPUs: mpirun -n 2 --allow-run-as-root python tensor_parallel_rotary_embedding.py
 """

 BATCH = 2

examples/distributed_inference/tensor_parallel_simple_example.py

Lines changed: 2 additions & 2 deletions

@@ -41,11 +41,11 @@
 )
 from torch_tensorrt.dynamo.distributed.utils import (
     get_tensor_parallel_device_mesh,
-    initialize_logger,
+    initialize_distributed_logger,
 )

 device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh()
-logger = initialize_logger(_rank, "tensor_parallel_simple_example")
+logger = initialize_distributed_logger(_rank, "tensor_parallel_simple_example")


 """

py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py

Lines changed: 0 additions & 1 deletion

@@ -12,7 +12,6 @@
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     dynamo_tensorrt_converter,
 )
-from torch_tensorrt.dynamo.distributed.utils import load_tensorrt_llm_for_nccl
 from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
     tensorrt_fused_nccl_all_gather_op,
     tensorrt_fused_nccl_reduce_scatter_op,

Lines changed: 1 addition & 269 deletions

@@ -1,18 +1,8 @@
-import ctypes
-import getpass
 import logging
 import os
-import platform
-import tempfile
-import urllib.request
-from pathlib import Path
-from typing import Optional

 import torch
 from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh
-from torch_tensorrt._version import __tensorrt_llm_version__
-
-_WHL_CPYTHON_VERSION = "cp310"

 logger = logging.getLogger(__name__)

@@ -42,268 +32,10 @@ def get_tensor_parallel_device_mesh(
     return device_mesh, world_size, rank


-def initialize_logger(rank: int, logger_file_name: str) -> logging.Logger:
+def initialize_distributed_logger(rank: int, logger_file_name: str) -> logging.Logger:
     logger = logging.getLogger()
     logger.setLevel(logging.INFO)
     fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
     fh.setLevel(logging.INFO)
     logger.addHandler(fh)
     return logger
-
-
-def is_platform_supported_for_trtllm() -> bool:
-    """
-    Checks if the current platform supports TensorRT-LLM plugins for the NCCL backend.
-
-    Returns:
-        bool: True if supported, False otherwise.
-
-    Unsupported:
-        - Windows platforms
-        - Jetson/Orin/Xavier (aarch64 architecture + 'tegra' in platform release)
-        - CUDA 13 not supported
-    """
-    system = platform.system().lower()
-    machine = platform.machine().lower()
-    release = platform.release().lower()
-
-    if "windows" in system:
-        logger.info(
-            "TensorRT-LLM plugins for NCCL backend are not supported on Windows."
-        )
-        return False
-
-    if machine == "aarch64" and "tegra" in release:
-        logger.info(
-            "TensorRT-LLM plugins for NCCL backend are not supported on Jetson/Orin/Xavier (Tegra) devices."
-        )
-        return False
-
-    try:
-        cuda_version = torch.version.cuda  # e.g., "12.4" or "13.0"
-        if cuda_version is None:
-            logger.warning("No CUDA runtime detected — TRT-LLM plugins unavailable.")
-            return False
-
-        major, minor = map(int, cuda_version.split("."))
-        if major != 12:
-            logger.warning("CUDA 13 is not supported for TRT-LLM plugins.")
-            return False
-
-        return True
-
-    except Exception as e:
-        logger.warning(f"Failed to detect CUDA version: {e}")
-        return False
-
-    return True
-
-
-def _cache_root() -> Path:
-    username = getpass.getuser()
-    return Path(tempfile.gettempdir()) / f"torch_tensorrt_{username}"
-
-
-def _extracted_dir_trtllm(platform_system: str, platform_machine: str) -> Path:
-    return (
-        _cache_root()
-        / "trtllm"
-        / f"{__tensorrt_llm_version__}_{platform_system}_{platform_machine}"
-    )
-
-
-def extract_wheel_file(wheel_path: Path, extract_dir: Path) -> None:
-    from torch.distributed import barrier, get_rank, is_initialized
-
-    if not is_initialized():
-        # Single process case, just unzip
-        is_master = True
-    else:
-        is_master = get_rank() == 0  # only rank 0 does the unzip
-
-    if is_master:
-        try:
-            import zipfile
-        except ImportError as e:
-            raise ImportError(
-                "zipfile module is required but not found. Please install zipfile"
-            )
-        try:
-            with zipfile.ZipFile(wheel_path) as zip_ref:
-                zip_ref.extractall(extract_dir)
-                logger.debug(f"Extracted wheel to {extract_dir}")
-
-        except FileNotFoundError as e:
-            # This should capture the errors in the download failure above
-            logger.error(f"Wheel file not found at {wheel_path}: {e}")
-            raise RuntimeError(
-                f"Failed to find downloaded wheel file at {wheel_path}"
-            ) from e
-        except zipfile.BadZipFile as e:
-            logger.error(f"Invalid or corrupted wheel file: {e}")
-            raise RuntimeError(
-                "Downloaded wheel file is corrupted or not a valid zip archive"
-            ) from e
-        except Exception as e:
-            logger.error(f"Unexpected error while extracting wheel: {e}")
-            raise RuntimeError(
-                "Unexpected error during extraction of TensorRT-LLM wheel"
-            ) from e
-
-    # Make sure others wait until unzip is done
-    if is_initialized():
-        barrier()
-
-
-def download_and_get_plugin_lib_path() -> Optional[str]:
-    """
-    Returns the path to the TensorRT‑LLM shared library, downloading and extracting if necessary.
-
-    Args:
-        platform (str): Platform identifier (e.g., 'linux_x86_64')
-
-    Returns:
-        Optional[str]: Path to shared library or None if operation fails.
-    """
-    platform_system = platform.system().lower()
-    platform_machine = platform.machine().lower()
-    wheel_filename = (
-        f"tensorrt_llm-{__tensorrt_llm_version__}-{_WHL_CPYTHON_VERSION}-"
-        f"{_WHL_CPYTHON_VERSION}-{platform_system}_{platform_machine}.whl"
-    )
-    wheel_path = _cache_root() / wheel_filename
-    extract_dir = _extracted_dir_trtllm(platform_system, platform_machine)
-    # else will never be met though
-    lib_filename = (
-        "libnvinfer_plugin_tensorrt_llm.so"
-        if "linux" in platform_system
-        else "libnvinfer_plugin_tensorrt_llm.dll"
-    )
-    # eg: /tmp/torch_tensorrt_<username>/trtllm/0.17.0.post1_linux_x86_64/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so
-    plugin_lib_path = extract_dir / "tensorrt_llm" / "libs" / lib_filename
-
-    if plugin_lib_path.exists():
-        return str(plugin_lib_path)
-
-    wheel_path.parent.mkdir(parents=True, exist_ok=True)
-    extract_dir.mkdir(parents=True, exist_ok=True)
-
-    if not wheel_path.exists():
-        base_url = "https://pypi.nvidia.com/tensorrt-llm/"
-        download_url = base_url + wheel_filename
-        try:
-            logger.debug(f"Downloading {download_url} ...")
-            urllib.request.urlretrieve(download_url, wheel_path)
-            logger.debug("Download succeeded and TRT-LLM wheel is now present")
-        except urllib.error.HTTPError as e:
-            logger.error(
-                f"HTTP error {e.code} when trying to download {download_url}: {e.reason}"
-            )
-        except urllib.error.URLError as e:
-            logger.error(
-                f"URL error when trying to download {download_url}: {e.reason}"
-            )
-        except OSError as e:
-            logger.error(f"Local file write error: {e}")
-
-    extract_wheel_file(wheel_path, extract_dir)
-
-    try:
-        wheel_path.unlink(missing_ok=True)
-        logger.debug(f"Deleted wheel file: {wheel_path}")
-    except Exception as e:
-        logger.warning(f"Could not delete wheel file {wheel_path}: {e}")
-    if not plugin_lib_path.exists():
-        logger.error(
-            f"Plugin library not found at expected location: {plugin_lib_path}"
-        )
-        return None
-
-    return str(plugin_lib_path)
-
-
-def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool:
-    """
-    Loads and initializes the TensorRT-LLM plugin from the given shared library path.
-
-    Args:
-        plugin_lib_path (str): Path to the shared TensorRT-LLM plugin library.
-
-    Returns:
-        bool: True if successful, False otherwise.
-    """
-    try:
-        handle = ctypes.CDLL(plugin_lib_path)
-        logger.info(f"Successfully loaded plugin library: {plugin_lib_path}")
-    except OSError as e_os_error:
-        if "libmpi" in str(e_os_error):
-            logger.warning(
-                f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}, got error {e_os_error} (hint: libmpi.so is a necessary dependency; ensure that OpenMPI or MPICH is installed on your system)",
-                exc_info=e_os_error,
-            )
-        else:
-            logger.warning(
-                f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}. "
-                f"Ensure the path is correct and the library is compatible.",
-                exc_info=e_os_error,
-            )
-        return False
-
-    try:
-        handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
-        handle.initTrtLlmPlugins.restype = ctypes.c_bool
-    except AttributeError as e_plugin_unavailable:
-        logger.warning(
-            "Unable to initialize the TensorRT-LLM plugin library",
-            exc_info=e_plugin_unavailable,
-        )
-        return False
-
-    try:
-        if handle.initTrtLlmPlugins(None, b"tensorrt_llm"):
-            logger.info("TensorRT-LLM plugin successfully initialized")
-            return True
-        else:
-            logger.warning("TensorRT-LLM plugin library failed in initialization")
-            return False
-    except Exception as e_initialization_error:
-        logger.warning(
-            "Exception occurred during TensorRT-LLM plugin library initialization",
-            exc_info=e_initialization_error,
-        )
-        return False
-    return False
-
-
-def load_tensorrt_llm_for_nccl() -> bool:
-    """
-    Attempts to load the TensorRT-LLM plugin and initialize it.
-    Either the env variable TRTLLM_PLUGINS_PATH can specify the path
-    Or the user can specify USE_TRTLLM_PLUGINS as either of (1, true, yes, on) to download the TRT-LLM distribution and load it
-
-    Returns:
-        bool: True if the plugin was successfully loaded and initialized, False otherwise.
-    """
-    if not is_platform_supported_for_trtllm():
-        return False
-    plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
-
-    if plugin_lib_path:
-        return load_and_initialize_trtllm_plugin(plugin_lib_path)
-    else:
-        # this option can be used by user if TRTLLM_PLUGINS_PATH is not set by user
-        use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in (
-            "1",
-            "true",
-            "yes",
-            "on",
-        )
-        if not use_trtllm_plugin:
-            logger.warning(
-                "Neither TRTLLM_PLUGIN_PATH is set nor is it directed to download the shared library. Please set either of the two to use TRT-LLM libraries in torchTRT"
-            )
-            return False
-
-        plugin_lib_path = download_and_get_plugin_lib_path()
-        return load_and_initialize_trtllm_plugin(plugin_lib_path)  # type: ignore[arg-type]
-    return False
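The renamed logger helper is otherwise unchanged by this commit. A minimal usage sketch, mirroring the two example scripts above (the "my_job" tag is a placeholder, not something in the source):

from torch_tensorrt.dynamo.distributed.utils import (
    get_tensor_parallel_device_mesh,
    initialize_distributed_logger,
)

device_mesh, world_size, rank = get_tensor_parallel_device_mesh()
# Each rank writes to its own file: my_job_0.log, my_job_1.log, ...
logger = initialize_distributed_logger(rank, "my_job")
logger.info(f"rank {rank}/{world_size} attached to mesh {device_mesh}")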

tests/py/dynamo/distributed/distributed_utils.py

Lines changed: 2 additions & 0 deletions

@@ -9,6 +9,8 @@
 from torch.distributed._tensor.device_mesh import init_device_mesh


+# the below two functions are used to set the environment variables for the pytest single and multi process
+# this is for the github CI where we use pytest
 def set_environment_variables_pytest_single_process():
     port = 29500 + random.randint(1, 1000)
     os.environ["WORLD_SIZE"] = str(1)

tests/py/dynamo/distributed/test_nccl_ops.py

Lines changed: 2 additions & 1 deletion

@@ -6,7 +6,7 @@
 import torch.nn as nn
 from conversion.harness import DispatchTestCase

-# The distributed env initialization has to be before torchTRT import since it uses barrier
+# The distributed env initialization has to be before import of torchTRT, since it uses barrier for installation
 from distributed_utils import (
     set_environment_variables_pytest_multi_process,
     set_environment_variables_pytest_single_process,
@@ -26,6 +26,7 @@
     init_method="env://",
 )

+
 class DistributedGatherModel(nn.Module):
     def __init__(self, input_dim, world_size, group_name):
         super().__init__()
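The rewritten comment in the first hunk is the key constraint here: the distributed environment has to be set up before torch_tensorrt is imported, because that import path uses a barrier. A sketch of the resulting ordering for a single-process run; only the init_method="env://" line comes from the hunk above, while the backend choice and surrounding calls are assumptions:

from distributed_utils import set_environment_variables_pytest_single_process

set_environment_variables_pytest_single_process()  # set WORLD_SIZE and related variables first

import torch.distributed as dist

dist.init_process_group(backend="nccl", init_method="env://")

import torch_tensorrt  # imported only after the process group exists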
