benchmarks/kernels/benchmark_fused_collective.py (123 changes: 54 additions & 69 deletions)
@@ -5,7 +5,7 @@
Benchmark for FlashInfer fused collective operations vs standard operations.

This benchmark compares:
-1. FlashInfer's trtllm_allreduce_fusion (fused allreduce + rmsnorm + optional quant)
+1. FlashInfer's allreduce_fusion (fused allreduce + rmsnorm + optional quant)
2. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations

Usage with torchrun:
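
As a reader's aid: a minimal single-rank sketch of the math the two paths share. The allreduce step is elided (it is the identity on one rank), so this only illustrates the residual-add + rmsnorm portion; all names are illustrative and not part of the change.

```python
# Single-rank sketch of the unfused reference path: (allreduce ->) residual
# add -> rmsnorm. The allreduce is elided, so `reduced` is just `x` here.
import torch

def unfused_reference(x, residual, gamma, eps=1e-6):
    reduced = x  # stand-in for tensor_model_parallel_all_reduce(x)
    residual_out = reduced + residual
    variance = residual_out.float().pow(2).mean(-1, keepdim=True)
    norm_out = (residual_out.float() * torch.rsqrt(variance + eps)).to(x.dtype)
    return norm_out * gamma, residual_out

x = torch.randn(4, 8, dtype=torch.bfloat16)
gamma = torch.ones(8, dtype=torch.bfloat16)
norm_out, residual_out = unfused_reference(x, torch.randn_like(x), gamma)
```

The fused path computes the same result in a single kernel, which is what the benchmark measures.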
@@ -24,7 +24,6 @@

from vllm.config.vllm import CompilationConfig, VllmConfig, set_current_vllm_config
from vllm.distributed import (
-    get_tp_group,
tensor_model_parallel_all_reduce,
)
from vllm.distributed.parallel_state import (
@@ -52,11 +51,12 @@
try:
import flashinfer.comm as flashinfer_comm # type: ignore

-    if not hasattr(flashinfer_comm, "trtllm_allreduce_fusion"):
+    if not (
+        hasattr(flashinfer_comm, "allreduce_fusion")
+        and hasattr(flashinfer_comm, "create_allreduce_fusion_workspace")
+    ):
        flashinfer_comm = None
-        logger.warning(
-            "FlashInfer comm module found but missing trtllm_allreduce_fusion"
-        )
+        logger.warning("FlashInfer comm module found but missing allreduce_fusion API")
except ImportError:
flashinfer_comm = None
logger.warning("FlashInfer not found, only benchmarking standard operations")
@@ -75,18 +75,18 @@
}

# Global workspace for FlashInfer fused allreduce
-_FI_WORKSPACE_TENSOR = None
+_FI_WORKSPACE = None


def setup_flashinfer_workspace(
world_size: int,
rank: int,
hidden_dim: int,
max_token_num: int,
-    use_fp32_lamport: bool = False,
+    dtype: torch.dtype,
):
"""Setup FlashInfer workspace for fused allreduce operations."""
-    global _FI_WORKSPACE_TENSOR
+    global _FI_WORKSPACE

if flashinfer_comm is None:
        return None
@@ -96,33 +96,29 @@ def setup_flashinfer_workspace(
        return None

try:
-        # Create IPC workspace
-        ipc_handles, workspace_tensor = (
-            flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
-                tp_rank=rank,
-                tp_size=world_size,
-                max_token_num=max_token_num,
-                hidden_dim=hidden_dim,
-                group=get_tp_group().device_group,
-                use_fp32_lamport=use_fp32_lamport,
-            )
-        )
+        workspace = flashinfer_comm.create_allreduce_fusion_workspace(
+            backend="trtllm",
+            world_size=world_size,
+            rank=rank,
+            max_token_num=max_token_num,
+            hidden_dim=hidden_dim,
+            dtype=dtype,
+        )

-        _FI_WORKSPACE_TENSOR = workspace_tensor
-        return ipc_handles, workspace_tensor
+        _FI_WORKSPACE = workspace
+        return workspace
except Exception as e:
logger.error("Failed to setup FlashInfer workspace: %s", e)
-        return None, None
+        return None


-def cleanup_flashinfer_workspace(ipc_handles):
+def cleanup_flashinfer_workspace(workspace):
"""Cleanup FlashInfer workspace."""
-    if flashinfer_comm is None or ipc_handles is None:
+    if flashinfer_comm is None or workspace is None:
return

try:
-        group = get_tp_group().device_group
-        flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(ipc_handles, group)
+        workspace.destroy()
except Exception as e:
logger.error("Failed to cleanup FlashInfer workspace: %s", e)

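Taken together, the two helpers give the workspace a simple create/use/destroy lifecycle. A hedged usage sketch based only on the calls visible in this diff; the world size, dimensions, and dtype are placeholders:

```python
# Hedged lifecycle sketch; all argument values here are placeholders.
workspace = setup_flashinfer_workspace(
    world_size=2, rank=0, hidden_dim=4096, max_token_num=1024,
    dtype=torch.bfloat16,
)
try:
    if workspace is not None:
        ...  # benchmark loops that call flashinfer_comm.allreduce_fusion(...)
finally:
    cleanup_flashinfer_workspace(workspace)  # no-op on None, else workspace.destroy()
```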
@@ -132,25 +128,15 @@ class FlashInferFusedAllReduceParams:

def __init__(
self,
rank: int,
world_size: int,
-        use_fp32_lamport: bool = False,
max_token_num: int = 1024,
):
self.rank = rank
self.world_size = world_size
-        self.use_fp32_lamport = use_fp32_lamport
self.trigger_completion_at_end = True
self.launch_with_pdl = True
self.fp32_acc = True
self.max_token_num = max_token_num

-    def get_trtllm_fused_allreduce_kwargs(self):
-        return {
-            "world_rank": self.rank,
-            "world_size": self.world_size,
-            "launch_with_pdl": self.launch_with_pdl,
-            "trigger_completion_at_end": self.trigger_completion_at_end,
-            "fp32_acc": self.fp32_acc,
-        }

@@ -165,7 +151,7 @@ def flashinfer_fused_allreduce_rmsnorm(
norm_out: torch.Tensor | None = None,
):
"""FlashInfer fused allreduce + rmsnorm operation."""
-    if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None:
+    if flashinfer_comm is None or _FI_WORKSPACE is None:
raise RuntimeError("FlashInfer not available or workspace not initialized")

if norm_out is None:
@@ -174,18 +160,15 @@
else:
residual_out = input_tensor

-    flashinfer_comm.trtllm_allreduce_fusion(
-        allreduce_in=input_tensor,
-        token_num=input_tensor.shape[0],
+    flashinfer_comm.allreduce_fusion(
+        input=input_tensor,
+        workspace=_FI_WORKSPACE,
+        pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
residual_in=residual,
residual_out=residual_out,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
-        hidden_dim=input_tensor.shape[-1],
-        workspace_ptrs=_FI_WORKSPACE_TENSOR,
-        pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
-        allreduce_out=None,
quant_out=None,
scale_out=None,
layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
@@ -207,7 +190,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp8_quant(
quant_out: torch.Tensor | None = None,
):
"""FlashInfer fused allreduce + rmsnorm + FP8 quantization."""
-    if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None:
+    if flashinfer_comm is None or _FI_WORKSPACE is None:
raise RuntimeError("FlashInfer not available or workspace not initialized")

if norm_out is None:
@@ -216,18 +199,15 @@ def flashinfer_fused_allreduce_rmsnorm_fp8_quant(
else:
residual_out = input_tensor

-    flashinfer_comm.trtllm_allreduce_fusion(
-        allreduce_in=input_tensor,
-        token_num=input_tensor.shape[0],
+    flashinfer_comm.allreduce_fusion(
+        input=input_tensor,
+        workspace=_FI_WORKSPACE,
+        pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant,
residual_in=residual,
residual_out=residual_out,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
-        hidden_dim=input_tensor.shape[-1],
-        workspace_ptrs=_FI_WORKSPACE_TENSOR,
-        pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant,
-        allreduce_out=None,
quant_out=quant_out,
scale_out=None,
layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
@@ -250,7 +230,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant(
norm_out: torch.Tensor | None = None,
):
"""FlashInfer fused allreduce + rmsnorm + FP4 quantization."""
-    if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None:
+    if flashinfer_comm is None or _FI_WORKSPACE is None:
raise RuntimeError("FlashInfer not available or workspace not initialized")

if norm_out is None:
@@ -259,18 +239,15 @@
else:
residual_out = input_tensor

-    flashinfer_comm.trtllm_allreduce_fusion(
-        allreduce_in=input_tensor,
-        token_num=input_tensor.shape[0],
+    flashinfer_comm.allreduce_fusion(
+        input=input_tensor,
+        workspace=_FI_WORKSPACE,
+        pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant,
residual_in=residual,
residual_out=residual_out,
norm_out=norm_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
-        hidden_dim=input_tensor.shape[-1],
-        workspace_ptrs=_FI_WORKSPACE_TENSOR,
-        pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant,
-        allreduce_out=None,
quant_out=quant_out,
scale_out=output_scale,
layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
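
The three wrappers above differ only in the fusion pattern and the quant outputs they forward. A hypothetical consolidation, not part of this PR, shown only to make the shared call shape explicit; every keyword mirrors a call visible in the diff:

```python
# Hypothetical helper consolidating the three wrappers above.
def _fused_allreduce_rmsnorm(input_tensor, residual, rms_gamma, rms_eps,
                             pattern, residual_out, norm_out,
                             quant_out=None, scale_out=None):
    flashinfer_comm.allreduce_fusion(
        input=input_tensor,
        workspace=_FI_WORKSPACE,
        pattern=pattern,
        residual_in=residual,
        residual_out=residual_out,
        norm_out=norm_out,
        rms_gamma=rms_gamma,
        rms_eps=rms_eps,
        quant_out=quant_out,
        scale_out=scale_out,
        layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
    )
```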
@@ -1040,23 +1017,31 @@ def main():
configs = list(itertools.product(args.num_tokens, dtypes, residual_options))

# Setup FlashInfer workspace if available
-    ipc_handles = None
+    workspace = None
allreduce_params = None

if flashinfer_comm is not None:
        # Size the workspace for the widest element size among benchmarked dtypes
+        max_element_size = max(torch.finfo(dt).bits // 8 for dt in dtypes)
+        workspace_dtype = (
+            torch.float32
+            if max_element_size == 4
+            else (torch.bfloat16 if torch.bfloat16 in dtypes else torch.float16)
+        )
max_num_token = _FI_MAX_SIZES.get(world_size) // (
-            args.hidden_dim * world_size * 2
+            args.hidden_dim * max_element_size
)

-        ipc_handles, workspace_tensor = setup_flashinfer_workspace(
-            world_size, rank, args.hidden_dim, max_num_token
+        workspace = setup_flashinfer_workspace(
+            world_size,
+            rank,
+            args.hidden_dim,
+            max_num_token,
+            dtype=workspace_dtype,
)

-        if workspace_tensor is not None:
+        if workspace is not None:
allreduce_params = FlashInferFusedAllReduceParams(
rank=rank,
world_size=world_size,
-                max_token_num=max_num_token,
)

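The token budget divides a fixed per-world-size byte budget by the per-token footprint (hidden_dim times the widest element size in the run). A worked example with an assumed budget, since the _FI_MAX_SIZES values are not shown in this diff:

```python
# Worked sizing example; the 64 MiB budget is assumed for illustration.
import torch

max_workspace_bytes = 64 * 2**20                           # assumed budget
hidden_dim = 8192
max_element_size = torch.finfo(torch.bfloat16).bits // 8   # 2 bytes
max_num_token = max_workspace_bytes // (hidden_dim * max_element_size)
print(max_num_token)  # 4096 tokens fit under the assumed budget
```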
@@ -1119,8 +1104,8 @@

finally:
# Cleanup
-        if ipc_handles is not None:
-            cleanup_flashinfer_workspace(ipc_handles)
+        if workspace is not None:
+            cleanup_flashinfer_workspace(workspace)

dist.barrier()

tests/compile/passes/distributed/test_fusion_all_reduce.py (5 changes: 3 additions & 2 deletions)
@@ -202,9 +202,10 @@ def ops_in_model_before(self):
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
@pytest.mark.skipif(
not find_spec("flashinfer")
or not has_module_attribute("flashinfer.comm", "trtllm_allreduce_fusion"),
or not has_module_attribute("flashinfer.comm", "allreduce_fusion")
or not has_module_attribute("flashinfer.comm", "create_allreduce_fusion_workspace"),
reason="flashinfer is not found or flashinfer "
"is not compiled with trtllm_allreduce_fusion",
"is not compiled with allreduce_fusion",
)
def test_all_reduce_fusion_pass_replace(
test_model: torch.nn.Module,
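
The skipif gate reduces to a capability probe on the installed flashinfer build. A standalone sketch of the same check; has_module_attribute here is a stand-in assumed to mirror vllm's helper:

```python
# Standalone sketch of the capability probe in the skipif above.
import importlib
from importlib.util import find_spec

def has_module_attribute(module_name: str, attr: str) -> bool:
    try:
        return hasattr(importlib.import_module(module_name), attr)
    except ImportError:
        return False

flashinfer_fusion_available = (
    find_spec("flashinfer") is not None
    and has_module_attribute("flashinfer.comm", "allreduce_fusion")
    and has_module_attribute("flashinfer.comm", "create_allreduce_fusion_workspace")
)
```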