diff --git a/cpp/tensorrt_llm/thop/allreduceOp.cpp b/cpp/tensorrt_llm/thop/allreduceOp.cpp
index b1d5aee28ac..21018e241da 100644
--- a/cpp/tensorrt_llm/thop/allreduceOp.cpp
+++ b/cpp/tensorrt_llm/thop/allreduceOp.cpp
@@ -959,6 +959,13 @@ class AllreduceOp
         // MIN_LATENCY.
         if (mStrategy != AllReduceStrategyType::AUTO)
        {
+            // Check TWOSHOT constraint: seq_len >= tp_size
+            if (mStrategy == AllReduceStrategyType::TWOSHOT && seq_len < mGroup.size())
+            {
+                TLLM_LOG_WARNING("TWOSHOT strategy requires seq_len >= tp_size (%zu < %zu), falling back to ONESHOT",
+                    seq_len, mGroup.size());
+                return AllReduceStrategyType::ONESHOT;
+            }
             return mStrategy;
         }
         else
diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml
index fabde3af980..e99c31e8f98 100644
--- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml
+++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -80,6 +80,7 @@ transforms:
     sharding_source: ['heuristic']
     support_partial_config: true
     sharding_dims: ['tp', 'ep', 'bmm']
+    allreduce_strategy: 'AUTO'
     requires_shape_prop: true
   # TODO: (hg) need to ensure run_shape_prop after sharding.
   sharding_transform_executor:
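The allreduceOp.cpp hunk above only guards an explicitly forced TWOSHOT strategy (AUTO selection is unchanged), and the default.yaml hunk exposes the new `allreduce_strategy` knob with its AUTO default. A minimal Python sketch of the fallback rule, purely illustrative: `resolve_strategy` and the string-valued strategies are hypothetical stand-ins, not TRT-LLM API.

```python
def resolve_strategy(requested: str, seq_len: int, tp_size: int) -> str:
    """Mirror of the runtime check: a forced TWOSHOT needs seq_len >= tp_size."""
    if requested == "TWOSHOT" and seq_len < tp_size:
        # Per the warning above, TWOSHOT needs at least one token per rank,
        # so the runtime falls back to ONESHOT instead of failing.
        return "ONESHOT"
    return requested


assert resolve_strategy("TWOSHOT", seq_len=1, tp_size=4) == "ONESHOT"    # decode-like step
assert resolve_strategy("TWOSHOT", seq_len=512, tp_size=4) == "TWOSHOT"  # prefill-like step
```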
diff --git a/tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py b/tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py
index 7d19c9cd624..fa006eb2bc4 100644
--- a/tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py
+++ b/tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py
@@ -13,6 +13,36 @@
     # warmup causes hangs due to workspace allocation with CPU synchronization
     _allreduce_cache = {}

+    # Global AllReduce Strategy Configuration
+    # =========================================
+    # This global variable controls which allreduce implementation is used across
+    # all distributed operations in AutoDeploy. It is set once at initialization
+    # time via set_allreduce_strategy() and remains constant during execution.
+    _global_allreduce_strategy = AllReduceStrategy.AUTO
+
+    def set_allreduce_strategy(strategy: AllReduceStrategy):
+        """Set the global allreduce strategy for all distributed operations.
+
+        This should be called once during initialization, before any distributed
+        operations are executed. All subsequent allreduce calls will use this strategy.
+
+        Note:
+            This clears the allreduce cache to ensure new operations use the updated strategy.
+            Call this before any model compilation or CUDA graph capture.
+        """
+        global _global_allreduce_strategy
+        _global_allreduce_strategy = strategy
+        # Clear cache when strategy changes to force recreation with new strategy
+        _allreduce_cache.clear()
+
+    def get_allreduce_strategy() -> AllReduceStrategy:
+        """Get the current global allreduce strategy.
+
+        Returns:
+            The currently configured AllReduceStrategy enum value.
+        """
+        return _global_allreduce_strategy
+
     def trtllm_allgather(tensor, dim, sizes=None):
         rank, world_size = get_rank_world_size()
         p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
@@ -22,13 +52,13 @@ def trtllm_allreduce(tensor, op, all_reduce_params=None):
         rank, world_size = get_rank_world_size()
         assert op == ReduceOp.SUM, "TRT-LLM all reduce only supports SUM op."

-        # Cache key includes rank, world_size, and dtype to handle different configurations
-        cache_key = (rank, world_size, tensor.dtype)
+        # Cache key includes rank, world_size, dtype, and strategy to handle different configurations
+        cache_key = (rank, world_size, tensor.dtype, _global_allreduce_strategy)
         if cache_key not in _allreduce_cache:
             p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
-            # Use Strategy.AUTO for optimal performance
+            # Use the configured global strategy
             _allreduce_cache[cache_key] = AllReduce(
-                mapping=p_config, strategy=AllReduceStrategy.NCCL, dtype=tensor.dtype
+                mapping=p_config, strategy=_global_allreduce_strategy, dtype=tensor.dtype
             )
         torch_op = _allreduce_cache[cache_key]
@@ -59,6 +89,12 @@ def fused_allreduce_residual_rmsnorm_fake(

     TRTLLM_OP_AVAILABLE = True
 except ImportError:

+    def set_allreduce_strategy(strategy):
+        raise ImportError("TRT-LLM is not available.")
+
+    def get_allreduce_strategy():
+        raise ImportError("TRT-LLM is not available.")
+
     def trtllm_allgather(tensor, dim, sizes=None):
         raise ImportError("TRT-LLM is not available.")
diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
index 292652d1532..0bbbf3979c5 100644
--- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
+++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
@@ -22,9 +22,10 @@
 from typing import DefaultDict, Dict, List, Set, Tuple, Type

 import torch
-from pydantic import Field
+from pydantic import Field, field_validator
 from torch.fx import GraphModule, Node

+from .....functional import AllReduceStrategy
 from ...models.factory import ModelFactory, ShardingConfigSource
 from ...shim.interface import CachedSequenceInterface
 from ...utils.logger import ad_logger
@@ -149,6 +150,32 @@ class ShardingTransformConfig(TransformConfig):
     sharding_dims: List[ShardingDim] = Field(
         default_factory=lambda: [ShardingDim.SSM, ShardingDim.TP, ShardingDim.EP, ShardingDim.BMM]
     )
+    allreduce_strategy: AllReduceStrategy = Field(
+        default=AllReduceStrategy.AUTO,
+        description="AllReduce strategy for distributed operations. Options: AUTO (automatic selection), "
+        "NCCL (NCCL-based), ONESHOT (single-phase fusion kernel), TWOSHOT (two-phase fusion kernel), "
+        "MIN_LATENCY (minimum latency heuristic), LOWPRECISION (low precision allreduce), "
+        "UB (userbuffers), MNNVL (multi-node NVLink), NCCL_SYMMETRIC (NCCL symmetric). "
+        "This is set as a global variable during transform application.",
+    )
+
+    @field_validator("allreduce_strategy", mode="before")
+    @classmethod
+    def _validate_allreduce_strategy(cls, v):
+        """Convert string names like 'AUTO' or 'ONESHOT' to the AllReduceStrategy enum."""
+        if isinstance(v, AllReduceStrategy):
+            return v
+        if isinstance(v, str):
+            try:
+                return AllReduceStrategy[v]
+            except KeyError:
+                raise ValueError(
+                    f"Invalid allreduce strategy: {v}. "
+                    f"Valid options: {', '.join(s.name for s in AllReduceStrategy)}"
+                )
+        if isinstance(v, int):
+            return AllReduceStrategy(v)
+        return v


 @TransformRegistry.register("detect_sharding")
@@ -186,6 +213,23 @@ def _apply(
         local_rank, world_size = shared_config.local_rank, shared_config.world_size
         # world_size = 2

+        # Configure global allreduce strategy from transform config.
+        # This is set once during the sharding transform and used by all distributed operations.
+        if hasattr(self.config, "allreduce_strategy"):
+            try:
+                from ...distributed.trtllm import TRTLLM_OP_AVAILABLE, set_allreduce_strategy
+
+                if TRTLLM_OP_AVAILABLE:
+                    # config.allreduce_strategy is already an AllReduceStrategy enum
+                    set_allreduce_strategy(self.config.allreduce_strategy)
+                    if self.config.allreduce_strategy != AllReduceStrategy.AUTO:
+                        ad_logger.info(
+                            f"Global allreduce strategy configured from transform: "
+                            f"{self.config.allreduce_strategy.name}"
+                        )
+            except (ImportError, AttributeError) as e:
+                ad_logger.warning(f"Failed to set allreduce strategy: {e}")
+
         if world_size < 2:
             ad_logger.info("Skipping sharding for single device")
             return gm, TransformInfo(
diff --git a/tensorrt_llm/plugin/plugin.py b/tensorrt_llm/plugin/plugin.py
index 5010be4ea9d..60e12e98207 100644
--- a/tensorrt_llm/plugin/plugin.py
+++ b/tensorrt_llm/plugin/plugin.py
@@ -581,13 +581,44 @@ def set_workspace_tensor(self,

     @staticmethod
     def max_workspace_size_auto(tp_size: int, support_deterministic=True) -> int:
+        """Calculate the workspace size for the allreduce fusion kernel.
+
+        The workspace is used for the lamport buffers in the fusion kernel.
+        Required size calculation:
+        - Each GPU needs 3 sub-buffers (for triple buffering).
+        - Each sub-buffer stores: max_num_tokens * hidden_size * dtype_size (bf16 = 2).
+        - The lamport allocation multiplies by tp_size, so:
+          lamport_size = 3 * size * tp_size (per GPU)
+
+        Example: Llama 8B (hidden_size=4096), max_num_tokens=8192, bf16, TP=4
+        - Data per sub-buffer: 8192 * 4096 * 2 = 64 MiB
+        - Total lamport allocation: 3 * 64 MiB * 4 = 768 MiB per GPU
+        - Required 'size' parameter: 64 MiB (it gets multiplied by tp_size in the allocation)
+
+        The default (67,108,864 bytes = 64 MiB) supports:
+        - Models up to hidden_size=4096 with max_num_tokens=8192
+        - Or hidden_size=8192 with max_num_tokens=4096
+
+        Override with the TRTLLM_ALLREDUCE_FUSION_WORKSPACE_SIZE env var if needed for larger models.
+        """
         if force_all_reduce_deterministic() and support_deterministic:
             workspace_size = os.getenv("FORCE_ALLREDUCE_KERNEL_WORKSPACE_SIZE",
                                        "1000000000")
             return int(workspace_size)
-        if tp_size <= 2:
-            return 16_000_000
-        return 8_000_000
+
+        # Allow override via environment variable for edge cases
+        workspace_size_env = os.getenv("TRTLLM_ALLREDUCE_FUSION_WORKSPACE_SIZE")
+        if workspace_size_env:
+            size = int(workspace_size_env)
+            logger.info(
+                f"Using custom allreduce fusion workspace size: {size} bytes ({size / (1024**2):.1f} MiB)"
+            )
+            return size
+
+        # Default: 64 MiB - supports most common model configurations.
+        # Increase via the env var if you see CUDA illegal memory access errors with large models.
+        default_size = 67_108_864  # exactly 64 MiB
+        return default_size

     @staticmethod
     def max_workspace_size_lowprecision(tp_size: int) -> int:
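A quick arithmetic check of the sizing formula in the docstring above (a standalone sketch; `fusion_workspace_bytes` is a hypothetical helper, not part of the plugin API):

```python
def fusion_workspace_bytes(max_num_tokens: int, hidden_size: int, dtype_bytes: int = 2) -> int:
    """Per-GPU 'size' parameter: one lamport sub-buffer worth of activations."""
    return max_num_tokens * hidden_size * dtype_bytes


# Llama-8B-like shape from the docstring: 8192 tokens x 4096 hidden x bf16
size = fusion_workspace_bytes(8192, 4096)
assert size == 67_108_864            # 64 MiB, the new default
# Total lamport allocation per GPU: 3 sub-buffers, multiplied by tp_size = 4
assert 3 * size * 4 == 805_306_368   # 768 MiB
```

Shapes whose max_num_tokens * hidden_size product exceeds 2**25 bf16 elements outgrow the 64 MiB default and need the TRTLLM_ALLREDUCE_FUSION_WORKSPACE_SIZE override.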
diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/test_ad_allreduce_strategies.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/test_ad_allreduce_strategies.py
new file mode 100644
index 00000000000..984eb27c839
--- /dev/null
+++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/test_ad_allreduce_strategies.py
@@ -0,0 +1,196 @@
+import signal
+import subprocess
+import tempfile
+from contextlib import contextmanager
+from pathlib import Path
+
+import pytest
+import torch
+import yaml
+from click.testing import CliRunner
+from utils.cpp_paths import llm_root  # noqa: F401
+
+from tensorrt_llm.commands.bench import main
+
+
+class TimeoutError(Exception):
+    """Exception raised when a test times out."""
+
+    pass
+
+
+@contextmanager
+def timeout(seconds):
+    """Context manager that raises TimeoutError if the code block exceeds the time limit.
+
+    Args:
+        seconds: Maximum time in seconds to allow the code block to run
+
+    Raises:
+        TimeoutError: If the code block execution exceeds the time limit
+    """
+
+    def timeout_handler(signum, frame):
+        raise TimeoutError(f"Test execution exceeded {seconds} seconds timeout")
+
+    # Set the signal handler and alarm
+    old_handler = signal.signal(signal.SIGALRM, timeout_handler)
+    signal.alarm(seconds)
+    try:
+        yield
+    finally:
+        # Restore the old signal handler and cancel the alarm
+        signal.alarm(0)
+        signal.signal(signal.SIGALRM, old_handler)
+
+
+@pytest.fixture(scope="module")
+def shared_dataset(llm_root):  # noqa: F811
+    """Prepare the dataset once for all tests in this module."""
+    model_name = "meta-llama/Llama-3.1-8B"
+    with tempfile.TemporaryDirectory() as temp_dir:
+        dataset_path = _prepare_dataset(llm_root, temp_dir, model_name, num_requests=10)
+        # Read the dataset content to return it (temp_dir will be deleted)
+        with open(dataset_path, "r") as f:
+            dataset_content = f.read()
+    yield dataset_content
+
+
+def _prepare_dataset(root_dir: str, temp_dir: str, model_path_or_name: str, num_requests: int = 10):
+    """Prepare a synthetic dataset for benchmarking."""
+    _DATASET_NAME = "synthetic_128_128.txt"
+    dataset_path = Path(temp_dir, _DATASET_NAME)
+    dataset_tool = Path(root_dir, "benchmarks", "cpp", "prepare_dataset.py")
+    script_dir = Path(root_dir, "benchmarks", "cpp")
+
+    # Generate a small dataset to run a test - matching the workload configuration
+    command = [
+        "python3",
+        f"{dataset_tool}",
+        "--stdout",
+        "--tokenizer",
+        model_path_or_name,
+        "token-norm-dist",
+        "--input-mean",
+        "128",
+        "--output-mean",
+        "128",
+        "--input-stdev",
+        "0",
+        "--output-stdev",
+        "0",
+        "--num-requests",
+        str(num_requests),
+    ]
+    print(f"Running command: {' '.join(command)}")
+    result = subprocess.run(
+        command, cwd=str(script_dir), capture_output=True, text=True, timeout=300
+    )
+    if result.returncode != 0:
+        raise RuntimeError(f"Failed to prepare dataset: {result.stderr}")
+    # Grab the stdout and write it to a dataset file for passing to the suite.
+    with open(dataset_path, "w") as dataset:
+        dataset.write(result.stdout)
+    return dataset_path
+
+
+@pytest.mark.parametrize(
+    "allreduce_strategy",
+    [
+        "AUTO",
+        "NCCL",
+        "ONESHOT",
+        pytest.param(
+            "TWOSHOT",
+            marks=pytest.mark.skip(
+                reason="TWOSHOT requires C++ fix for seq_len < tp_size fallback"
+            ),
+        ),
+    ],
+)
+def test_allreduce_strategies(llm_root, shared_dataset, allreduce_strategy):  # noqa: F811
+    """Test the parametrized AllReduceStrategy values with a multi-GPU configuration.
+
+    This test validates that the parametrized allreduce strategies (AUTO, NCCL, ONESHOT,
+    and TWOSHOT once the C++ fallback lands) work correctly with TP=2. The strategy is
+    configured via the detect_sharding transform config and automatically applied as a
+    global variable.
+
+    The test has a 300 second timeout to prevent indefinite hangs and is skipped if
+    fewer than 2 GPUs are available.
+
+    Args:
+        llm_root: Root directory fixture
+        shared_dataset: Shared dataset fixture (prepared once for all test runs)
+        allreduce_strategy: Strategy to test (AllReduceStrategy enum name)
+    """
+    # Fixed timeout for all strategies (5 minutes should be enough)
+    TEST_TIMEOUT_SECONDS = 300
+
+    model_name = "meta-llama/Llama-3.1-8B"
+    tp_size = 2
+    max_batch_size = 256
+    max_num_tokens = 8192
+
+    if not torch.cuda.is_available() or torch.cuda.device_count() < tp_size:
+        pytest.skip(f"Allreduce strategy test requires at least {tp_size} GPUs, skipping")
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        # Write the shared dataset to a temp location
+        dataset_path = Path(temp_dir, "synthetic_128_128.txt")
+        with open(dataset_path, "w") as f:
+            f.write(shared_dataset)
+
+        # Create the configuration with the allreduce strategy in the transform config
+        extra_llm_api_options_path = f"{temp_dir}/extra_llm_api_options.yaml"
+        with open(extra_llm_api_options_path, "w") as f:
+            yaml.dump(
+                {
+                    "model": model_name,
+                    "max_batch_size": max_batch_size,
+                    "max_num_tokens": max_num_tokens,
+                    "max_seq_len": 256,
+                    "transforms": {
+                        "detect_sharding": {
+                            "stage": "sharding",
+                            "allreduce_strategy": allreduce_strategy,
+                        },
+                        "compile_model": {
+                            "stage": "compile",
+                            "backend": "torch-cudagraph",
+                            "cuda_graph_batch_sizes": [1, 16, 256],
+                        },
+                    },
+                },
+                f,
+            )
+
+        # Run the benchmark with the specified allreduce strategy under timeout protection
+        runner = CliRunner()
+        args = [
+            "--model",
+            model_name,
+            "throughput",
+            "--backend",
+            "_autodeploy",
+            "--dataset",
+            str(dataset_path),
+            "--extra_llm_api_options",
+            extra_llm_api_options_path,
+            "--tp",
+            str(tp_size),
+            "--max_batch_size",
+            str(max_batch_size),
+            "--max_num_tokens",
+            str(max_num_tokens),
+        ]
+
+        try:
+            with timeout(TEST_TIMEOUT_SECONDS):
+                result = runner.invoke(main, args, catch_exceptions=False)
+                assert result.exit_code == 0, f"Benchmark failed with output: {result.output}"
+        except TimeoutError as e:
+            pytest.fail(
+                f"Test timed out after {TEST_TIMEOUT_SECONDS}s for strategy {allreduce_strategy}. "
+                f"This might indicate a hang (e.g., TWOSHOT without the C++ fix). Error: {e}"
+            )
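For end users, the override the test exercises reduces to an extra_llm_api_options file that pins the strategy on the detect_sharding transform. A hedged sketch mirroring the keys and CLI arguments used in the test above (the file name is arbitrary):

```python
import yaml

extra_options = {
    "transforms": {
        "detect_sharding": {
            "stage": "sharding",
            # Plain strings are accepted; the field_validator in sharding.py
            # converts them to AllReduceStrategy enum members.
            "allreduce_strategy": "ONESHOT",
        },
    },
}
with open("extra_llm_api_options.yaml", "w") as f:
    yaml.dump(extra_options, f)

# Then, mirroring the test's CLI invocation:
#   trtllm-bench --model meta-llama/Llama-3.1-8B throughput --backend _autodeploy \
#       --dataset <dataset.txt> --extra_llm_api_options extra_llm_api_options.yaml --tp 2
```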