7 changes: 7 additions & 0 deletions cpp/tensorrt_llm/thop/allreduceOp.cpp
@@ -959,6 +959,13 @@ class AllreduceOp
// MIN_LATENCY.
if (mStrategy != AllReduceStrategyType::AUTO)
{
// Check TWOSHOT constraint: seq_len >= tp_size
if (mStrategy == AllReduceStrategyType::TWOSHOT && seq_len < mGroup.size())
{
TLLM_LOG_WARNING("TWOSHOT strategy requires seq_len >= tp_size (%zu < %zu), falling back to ONESHOT",
seq_len, mGroup.size());
return AllReduceStrategyType::ONESHOT;
}
return mStrategy;
}
else
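For illustration, a minimal Python sketch of the fallback rule this hunk implements; `choose_strategy` and the string labels are placeholders, not names from the codebase:

# Hypothetical sketch of the TWOSHOT -> ONESHOT fallback, mirroring the C++ check above.
def choose_strategy(requested: str, seq_len: int, tp_size: int) -> str:
    # TWOSHOT needs at least one token per rank (seq_len >= tp_size);
    # otherwise fall back to ONESHOT, as the warning above describes.
    if requested == "TWOSHOT" and seq_len < tp_size:
        return "ONESHOT"
    return requested

assert choose_strategy("TWOSHOT", seq_len=2, tp_size=4) == "ONESHOT"
assert choose_strategy("TWOSHOT", seq_len=8, tp_size=4) == "TWOSHOT"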
1 change: 1 addition & 0 deletions tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -80,6 +80,7 @@ transforms:
sharding_source: ['heuristic']
support_partial_config: true
sharding_dims: ['tp', 'ep', 'bmm']
allreduce_strategy: 'AUTO'
requires_shape_prop: true
# TODO: (hg) need to ensure run_shape_prop after sharding.
sharding_transform_executor:
44 changes: 40 additions & 4 deletions tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py
@@ -13,6 +13,36 @@
# warmup causes hangs due to workspace allocation with CPU synchronization
_allreduce_cache = {}

# Global AllReduce Strategy Configuration
# =========================================
# This global variable controls which allreduce implementation is used across
# all distributed operations in AutoDeploy. It's set once at initialization
# time via set_allreduce_strategy() and remains constant during execution.
_global_allreduce_strategy = AllReduceStrategy.AUTO

def set_allreduce_strategy(strategy: AllReduceStrategy):
"""Set the global allreduce strategy for all distributed operations.

This should be called once during initialization, before any distributed
operations are executed. All subsequent allreduce calls will use this strategy.

Note:
This clears the allreduce cache to ensure new operations use the updated strategy.
Call this before any model compilation or CUDA graph capture.
"""
global _global_allreduce_strategy
_global_allreduce_strategy = strategy
# Clear cache when strategy changes to force recreation with new strategy
_allreduce_cache.clear()

def get_allreduce_strategy() -> AllReduceStrategy:
"""Get the current global allreduce strategy.

Returns:
The currently configured AllReduceStrategy enum value.
"""
return _global_allreduce_strategy

def trtllm_allgather(tensor, dim, sizes=None):
rank, world_size = get_rank_world_size()
p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
@@ -22,13 +52,13 @@ def trtllm_allreduce(tensor, op, all_reduce_params=None):
rank, world_size = get_rank_world_size()
assert op == ReduceOp.SUM, "TRT-LLM all reduce only supports SUM op."

- # Cache key includes rank, world_size, and dtype to handle different configurations
- cache_key = (rank, world_size, tensor.dtype)
+ # Cache key includes rank, world_size, dtype, and strategy to handle different configurations
+ cache_key = (rank, world_size, tensor.dtype, _global_allreduce_strategy)
if cache_key not in _allreduce_cache:
p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
- # Use Strategy.AUTO for optimal performance
+ # Use the configured global strategy
_allreduce_cache[cache_key] = AllReduce(
- mapping=p_config, strategy=AllReduceStrategy.NCCL, dtype=tensor.dtype
+ mapping=p_config, strategy=_global_allreduce_strategy, dtype=tensor.dtype
)

torch_op = _allreduce_cache[cache_key]
@@ -59,6 +89,12 @@ def fused_allreduce_residual_rmsnorm_fake(
TRTLLM_OP_AVAILABLE = True
except ImportError:

def set_allreduce_strategy(strategy):
raise ImportError("TRT-LLM is not available.")

def get_allreduce_strategy():
raise ImportError("TRT-LLM is not available.")

def trtllm_allgather(tensor, dim, sizes=None):
raise ImportError("TRT-LLM is not available.")

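A hedged usage sketch of the new strategy hooks; the import paths follow the files touched in this PR, but the call site itself is illustrative rather than taken from the diff:

# Illustrative usage of the global-strategy hooks added in trtllm.py
# (assumes tensorrt_llm is installed, so the real implementations rather than
# the ImportError stubs are bound).
from tensorrt_llm.functional import AllReduceStrategy
from tensorrt_llm._torch.auto_deploy.distributed.trtllm import (
    get_allreduce_strategy,
    set_allreduce_strategy,
)

# Call before model compilation / CUDA graph capture: it clears the allreduce
# cache, and later allreduce calls re-key on the new strategy.
set_allreduce_strategy(AllReduceStrategy.ONESHOT)
assert get_allreduce_strategy() == AllReduceStrategy.ONESHOT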
46 changes: 45 additions & 1 deletion tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
@@ -22,9 +22,10 @@
from typing import DefaultDict, Dict, List, Set, Tuple, Type

import torch
- from pydantic import Field
+ from pydantic import Field, field_validator
from torch.fx import GraphModule, Node

from .....functional import AllReduceStrategy
from ...models.factory import ModelFactory, ShardingConfigSource
from ...shim.interface import CachedSequenceInterface
from ...utils.logger import ad_logger
@@ -149,6 +150,32 @@ class ShardingTransformConfig(TransformConfig):
sharding_dims: List[ShardingDim] = Field(
default_factory=lambda: [ShardingDim.SSM, ShardingDim.TP, ShardingDim.EP, ShardingDim.BMM]
)
allreduce_strategy: AllReduceStrategy = Field(
default=AllReduceStrategy.AUTO,
description="AllReduce strategy for distributed operations. Options: AUTO (automatic selection), "
"NCCL (NCCL-based), ONESHOT (single-phase fusion kernel), TWOSHOT (two-phase fusion kernel), "
"MIN_LATENCY (minimum latency heuristic), LOWPRECISION (low precision allreduce), "
"UB (unified buffer), MNNVL (multi-node NVLINK), NCCL_SYMMETRIC (NCCL symmetric). "
"This is set as a global variable during transform application.",
)

@field_validator("allreduce_strategy", mode="before")
@classmethod
def _validate_allreduce_strategy(cls, v):
"""Convert string names like 'AUTO' or 'ONESHOT' to AllReduceStrategy enum."""
if isinstance(v, AllReduceStrategy):
return v
if isinstance(v, str):
try:
return AllReduceStrategy[v]
except KeyError:
raise ValueError(
f"Invalid allreduce strategy: {v}. "
f"Valid options: {', '.join(s.name for s in AllReduceStrategy)}"
)
if isinstance(v, int):
return AllReduceStrategy(v)
return v


@TransformRegistry.register("detect_sharding")
@@ -186,6 +213,23 @@ def _apply(
local_rank, world_size = shared_config.local_rank, shared_config.world_size
# world_size = 2

# Configure global allreduce strategy from transform config
# This is set once during sharding transform and used by all distributed operations
if hasattr(self.config, "allreduce_strategy"):
try:
from ...distributed.trtllm import TRTLLM_OP_AVAILABLE, set_allreduce_strategy

if TRTLLM_OP_AVAILABLE:
# config.allreduce_strategy is already an AllReduceStrategy enum
set_allreduce_strategy(self.config.allreduce_strategy)
if self.config.allreduce_strategy != AllReduceStrategy.AUTO:
ad_logger.info(
f"Global allreduce strategy configured from transform: "
f"{self.config.allreduce_strategy.name}"
)
except (ImportError, AttributeError) as e:
ad_logger.warning(f"Failed to set allreduce strategy: {e}")

if world_size < 2:
ad_logger.info("Skipping sharding for single device")
return gm, TransformInfo(
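A small sketch of the coercion the new field_validator performs, which is how the YAML default allreduce_strategy: 'AUTO' ends up as an enum on the config. `normalize` below mirrors _validate_allreduce_strategy rather than calling it, and constructing a full ShardingTransformConfig is omitted because its remaining fields are not shown here:

# Mirrors the validator's coercion rules: enum passthrough, name string, or int value.
from tensorrt_llm.functional import AllReduceStrategy

def normalize(v):
    if isinstance(v, AllReduceStrategy):
        return v
    if isinstance(v, str):
        return AllReduceStrategy[v]   # e.g. 'AUTO', 'ONESHOT', 'TWOSHOT'
    if isinstance(v, int):
        return AllReduceStrategy(v)   # raw enum value
    return v

assert normalize("AUTO") is AllReduceStrategy.AUTO
assert normalize(AllReduceStrategy.TWOSHOT) is AllReduceStrategy.TWOSHOT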
37 changes: 34 additions & 3 deletions tensorrt_llm/plugin/plugin.py
@@ -581,13 +581,44 @@ def set_workspace_tensor(self,
@staticmethod
def max_workspace_size_auto(tp_size: int,
support_deterministic=True) -> int:
"""Calculate workspace size for allreduce fusion kernel.

The workspace is used for lamport buffers in the fusion kernel.
Required size calculation:
- Each GPU needs 3 sub-buffers (for triple buffering)
- Each sub-buffer stores: max_num_tokens * hidden_size * dtype_size (bf16=2)
- The lamport allocation multiplies by tp_size, so:
lamport_size = 3 * size * tp_size (per GPU)

Example: Llama 8B (hidden=4096), max_tokens=8192, bf16, TP=4
- Data per sub-buffer: 8192 * 4096 * 2 = 64 MiB
- Total lamport: 3 * 64 MiB * 4 = 768 MiB per GPU
- Required 'size' parameter: 64 MiB (gets multiplied by tp_size in allocation)

Default (67,108,864 = 64 MiB) supports:
- Models up to hidden_size=4096 with max_num_tokens=8192
- Or hidden_size=8192 with max_num_tokens=4096

Override with TRTLLM_ALLREDUCE_FUSION_WORKSPACE_SIZE env var if needed for larger models.
"""
if force_all_reduce_deterministic() and support_deterministic:
workspace_size = os.getenv("FORCE_ALLREDUCE_KERNEL_WORKSPACE_SIZE",
"1000000000")
return int(workspace_size)
- if tp_size <= 2:
-     return 16_000_000
- return 8_000_000

# Allow override via environment variable for edge cases
workspace_size_env = os.getenv("TRTLLM_ALLREDUCE_FUSION_WORKSPACE_SIZE")
if workspace_size_env:
size = int(workspace_size_env)
logger.info(
f"Using custom allreduce fusion workspace size: {size} bytes ({size / (1024**2):.1f} MiB)"
)
return size

# Default: 64 MiB - supports most common model configurations
# Increase via env var if you see CUDA illegal memory access errors with large models
default_size = 67_108_864 # Exactly 64 MiB
return default_size

@staticmethod
def max_workspace_size_lowprecision(tp_size: int) -> int:
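A worked check of the sizing arithmetic in the new docstring, in plain Python with the Llama 8B-like numbers used there (hidden_size=4096, max_num_tokens=8192, bf16, TP=4):

# Workspace sizing from the docstring: one lamport sub-buffer holds
# max_num_tokens * hidden_size elements of the given dtype.
def fusion_subbuffer_bytes(max_num_tokens: int, hidden_size: int, dtype_bytes: int = 2) -> int:
    return max_num_tokens * hidden_size * dtype_bytes

DEFAULT_WORKSPACE = 67_108_864  # the new 64 MiB default in max_workspace_size_auto

size = fusion_subbuffer_bytes(max_num_tokens=8192, hidden_size=4096)
assert size == DEFAULT_WORKSPACE                 # exactly 64 MiB per sub-buffer
total_lamport = 3 * size * 4                     # triple buffering * tp_size
assert total_lamport == 768 * 1024 * 1024        # 768 MiB per GPU, as stated above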