Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 38 additions & 39 deletions vllm/compilation/passes/fusion/allreduce_rms_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,6 @@
destroy_fi_ar_workspace,
get_fi_ar_quant_workspace,
get_fi_ar_workspace,
initialize_fi_ar_quant_workspace,
initialize_fi_ar_workspace,
)

ar_fusion_patterns = flashinfer_comm.AllReduceFusionPattern
Expand Down Expand Up @@ -133,15 +131,23 @@ def call_trtllm_fused_allreduce_norm(

# Select workspace based on pattern: quant patterns use the
# trtllm quant workspace, non-quant patterns use the primary workspace.
if pattern_code in (
is_quant_pattern = pattern_code in (
ar_fusion_patterns.kARResidualRMSNormFP8Quant,
ar_fusion_patterns.kARResidualRMSNormFP4Quant,
):
workspace = get_fi_ar_quant_workspace()
else:
workspace = get_fi_ar_workspace()
)
get_workspace_fn = (
get_fi_ar_quant_workspace if is_quant_pattern else get_fi_ar_workspace
)
workspace = get_workspace_fn(
world_size=world_size,
rank=get_tensor_model_parallel_rank(),
max_token_num=max_token_num,
hidden_dim=hidden_size,
dtype=allreduce_in.dtype,
group=get_tp_group().device_group,
)
assert workspace is not None, (
"Flashinfer workspace must be initialized when using flashinfer"
"Flashinfer allreduce workspace must be initialized when using flashinfer"
)
assert flashinfer_comm is not None
if norm_out is None:
Expand Down Expand Up @@ -753,35 +759,29 @@ def __init__(self, config: VllmConfig) -> None:
scope="global",
)

for workspace_init_fn in [
initialize_fi_ar_workspace,
initialize_fi_ar_quant_workspace,
]:
try:
workspace_init_fn(
world_size=self.tp_size,
rank=rank,
max_token_num=self.max_token_num,
hidden_dim=self.hidden_dim,
dtype=self.model_dtype,
group=self.group,
)
except Exception as e:
if "multicast" in str(e).lower():
logger.warning(
"AllReduce fusion pass is disabled: flashinfer workspace "
"creation failed: %s. This is expected on GPUs without "
"NVSwitch (e.g., NVLink bridge-only or PCIe topologies). "
"Falling back to non-fused allreduce.",
str(e),
)
else:
logger.warning(
"Failed to initialize FlashInfer All Reduce workspace: %s. "
"AllReduce fusion pass will be disabled.",
e,
)
return
workspace_kwargs = dict(
world_size=self.tp_size,
rank=rank,
max_token_num=self.max_token_num,
hidden_dim=self.hidden_dim,
dtype=self.model_dtype,
group=self.group,
)
if get_fi_ar_workspace(**workspace_kwargs) is None:
logger.warning_once(
"Failed to initialize Flashinfer allreduce workspace. "
"Flashinfer allreduce-norm fusion will be disabled."
)
return

self.supports_quant_fusion = (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should warn if this failed as well. Lack of quant fusion means lower perf

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good. will add.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated. let me know if anything is missing.

get_fi_ar_quant_workspace(**workspace_kwargs) is not None
)
if not self.supports_quant_fusion:
logger.warning_once(
"Failed to initialize Flashinfer allreduce workspace. "
"Flashinfer allreduce-norm-quant fusion will be disabled."
)

self.allreduce_params = FlashInferFusedAllReduceParams(
world_size=self.tp_size,
Expand All @@ -793,9 +793,8 @@ def __init__(self, config: VllmConfig) -> None:

@enable_fake_mode
def register_patterns(self) -> None:
supports_quantization = get_fi_ar_quant_workspace() is not None
for epsilon in [1e-5, 1e-6]:
if supports_quantization:
if self.supports_quant_fusion:
AllReduceFusedRMSNormStaticQuantFP8Pattern(
epsilon,
self.model_dtype,
Expand Down
176 changes: 90 additions & 86 deletions vllm/distributed/device_communicators/flashinfer_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,50 +29,27 @@
except ImportError:
pass

# Global workspace for standalone allreduce and non-quant ar+rms fusion
# Workspace for standalone allreduce and non-quant ar+rms fusion
_fi_ar_workspace = None
# Extra workspace for quant fusion patterns (only supported by trtllm backend)
# Only created if primary workspace is not already trtllm
_fi_ar_quant_workspace = None


def get_fi_ar_workspace():
return _fi_ar_workspace


def get_fi_ar_quant_workspace():
return _fi_ar_quant_workspace


def initialize_fi_ar_workspace(
def _create_workspace(
backend: str,
world_size: int,
rank: int,
max_token_num: int,
hidden_dim: int,
dtype: torch.dtype,
group: ProcessGroup,
) -> None:
"""
Initialize the workspace if not already initialized.

Currently, this function is called by either the AllReduceFusionPass
or the FlashInferAllReduce backend for standalone allreduce.
If the fusion pass is enabled via
--compilation-config.pass_config.fuse_allreduce_rms=true,
it will create the workspace first, and the standalone backend
will reuse the workspace. Otherwise, the standalone backend will
create the workspace.
"""
global _fi_ar_workspace
if _fi_ar_workspace is not None:
return

backend = envs.VLLM_FLASHINFER_ALLREDUCE_BACKEND
):
"""Create a flashinfer allreduce workspace, returning None on failure."""
comm_backend = TorchDistBackend(group=group)
rng_state = random.getstate()
try:
random.seed(int.from_bytes(os.urandom(16), byteorder="big"))
_fi_ar_workspace = flashinfer_comm.create_allreduce_fusion_workspace(
workspace = flashinfer_comm.create_allreduce_fusion_workspace(
backend=backend,
world_size=world_size,
rank=rank,
Expand All @@ -81,9 +58,22 @@ def initialize_fi_ar_workspace(
dtype=dtype,
comm_backend=comm_backend,
)
except Exception as e:
if "multicast" in str(e).lower():
logger.warning_once(
"Failed to initialize FlashInfer All Reduce workspace: %s. "
"This is expected on GPUs without NVSwitch (e.g., NVLink "
"bridge-only or PCIe topologies).",
e,
)
else:
logger.warning_once(
"Failed to initialize FlashInfer All Reduce workspace: %s.",
e,
)
return None
finally:
random.setstate(rng_state)
assert _fi_ar_workspace is not None
logger.debug(
"Initialized FlashInfer All Reduce workspace: backend=%s, "
"world_size=%d, rank=%d, max_token_num=%d, hidden_dim=%d, dtype=%s",
Expand All @@ -94,70 +84,84 @@ def initialize_fi_ar_workspace(
hidden_dim,
dtype,
)
return workspace


def get_fi_ar_workspace(
    world_size: int,
    rank: int,
    max_token_num: int,
    hidden_dim: int,
    dtype: torch.dtype,
    group: ProcessGroup,
):
    """
    Return the shared allreduce workspace for non-quant patterns.

    The workspace is created lazily on first call. Callers are the
    AllReduceFusionPass (non-quant patterns) and FlashInferAllReduce for
    standalone allreduce. The backend is selected through the
    VLLM_FLASHINFER_ALLREDUCE_BACKEND env var.
    """
    global _fi_ar_workspace
    if _fi_ar_workspace is None:
        selected_backend = envs.VLLM_FLASHINFER_ALLREDUCE_BACKEND
        quant_ws = _fi_ar_quant_workspace
        # Alias the quant workspace when it was already built with the
        # same backend, rather than allocating a second one.
        if quant_ws is not None and quant_ws.backend == selected_backend:
            _fi_ar_workspace = quant_ws
        else:
            _fi_ar_workspace = _create_workspace(
                selected_backend,
                world_size,
                rank,
                max_token_num,
                hidden_dim,
                dtype,
                group,
            )
    return _fi_ar_workspace


def initialize_fi_ar_quant_workspace(
def get_fi_ar_quant_workspace(
world_size: int,
rank: int,
max_token_num: int,
hidden_dim: int,
dtype: torch.dtype,
group: ProcessGroup,
) -> None:
):
"""
Initialize the workspace used by quantization fusion patterns.
Return the allreduce workspace for quant patterns, initializing if needed.

Currently this always creates a workspace for trtllm backend as only it
supports quantization fusion (FP8/FP4). If the primary workspace
is already trtllm, the quant workspace aliases to it.
Always uses trtllm backend as it is the only one supporting quantization
fusion (FP8/FP4).
"""
global _fi_ar_quant_workspace
if _fi_ar_quant_workspace is not None:
return
return _fi_ar_quant_workspace

# If primary workspace is already trtllm, reuse it
# Reuse the non-quant workspace if it was already created with trtllm
if _fi_ar_workspace is not None and _fi_ar_workspace.backend == "trtllm":
_fi_ar_quant_workspace = _fi_ar_workspace
return
return _fi_ar_quant_workspace

comm_backend = TorchDistBackend(group=group)
_fi_ar_quant_workspace = flashinfer_comm.create_allreduce_fusion_workspace(
backend="trtllm",
world_size=world_size,
rank=rank,
max_token_num=max_token_num,
hidden_dim=hidden_dim,
dtype=dtype,
comm_backend=comm_backend,
)
assert _fi_ar_quant_workspace is not None
logger.debug(
"Initialized FlashInfer All Reduce workspace: backend=trtllm, "
"world_size=%d, rank=%d, max_token_num=%d, hidden_dim=%d, dtype=%s",
world_size,
rank,
max_token_num,
hidden_dim,
dtype,
_fi_ar_quant_workspace = _create_workspace(
"trtllm", world_size, rank, max_token_num, hidden_dim, dtype, group
)
return _fi_ar_quant_workspace


_fi_ar_workspace_lock = threading.Lock()


def destroy_fi_ar_workspace():
global _fi_ar_workspace
global _fi_ar_quant_workspace
global _fi_ar_workspace, _fi_ar_quant_workspace
with _fi_ar_workspace_lock:
if (
_fi_ar_quant_workspace is not None
and _fi_ar_quant_workspace is not _fi_ar_workspace
):
_fi_ar_quant_workspace.destroy()
_fi_ar_quant_workspace = None
is_alias = _fi_ar_workspace is _fi_ar_quant_workspace

if _fi_ar_workspace is not None:
_fi_ar_workspace.destroy()
_fi_ar_workspace = None
if _fi_ar_quant_workspace is not None and not is_alias:
_fi_ar_quant_workspace.destroy()

_fi_ar_workspace = _fi_ar_quant_workspace = None


atexit.register(destroy_fi_ar_workspace)
Expand Down Expand Up @@ -209,29 +213,21 @@ def __init__(

def _ensure_workspace(self, hidden_dim: int, dtype: torch.dtype) -> bool:
"""Ensure the all reduce workspace is initialized."""
if get_fi_ar_workspace() is not None:
return True
if self.max_num_tokens == 0:
element_size = torch.tensor([], dtype=dtype, device="cpu").element_size()
self.max_num_tokens = self.max_workspace_size // (hidden_dim * element_size)
try:
initialize_fi_ar_workspace(
world_size=self.world_size,
rank=self.rank,
max_token_num=self.max_num_tokens,
hidden_dim=hidden_dim,
dtype=dtype,
group=self.group,
)
return True
except Exception as e:
logger.warning(
"Failed to initialize FlashInfer All Reduce workspace: %s. "
"FlashInfer All Reduce will be disabled.",
e,
)
workspace = get_fi_ar_workspace(
world_size=self.world_size,
rank=self.rank,
max_token_num=self.max_num_tokens,
hidden_dim=hidden_dim,
dtype=dtype,
group=self.group,
)
if workspace is None:
self.disabled = True
return False
return True

def should_use_fi_ar(self, input_tensor: torch.Tensor) -> bool:
if self.disabled:
Expand All @@ -257,7 +253,15 @@ def should_use_fi_ar(self, input_tensor: torch.Tensor) -> bool:
return self._ensure_workspace(hidden_dim, input_tensor.dtype)

def all_reduce(self, input_tensor: torch.Tensor) -> torch.Tensor:
workspace = get_fi_ar_workspace()
_, hidden_dim = input_tensor.shape
workspace = get_fi_ar_workspace(
world_size=self.world_size,
rank=self.rank,
max_token_num=self.max_num_tokens,
hidden_dim=hidden_dim,
dtype=input_tensor.dtype,
group=self.group,
)
return flashinfer_comm.allreduce_fusion(
input=input_tensor,
workspace=workspace,
Expand Down
Loading