Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions tests/compile/passes/distributed/test_fusion_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,3 +323,102 @@ def all_reduce_fusion_pass_on_test_model(
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
backend.check_after_ops(model.ops_in_model_after())
del all_reduce_fusion_pass


@multi_gpu_test(num_gpus=4)
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
@pytest.mark.skip(
    reason="Disabled until flashinfer fixes device_idx=tp_rank "
    "in SymmDeviceMemory (wrong GPU in DP+TP configurations)",
)
def test_all_reduce_fusion_pass_dp_tp():
    """Run the AllReduce fusion pass under a DP=2, TP=2 layout (4 GPUs).

    Regression test for https://github.com/vllm-project/vllm/issues/34401
    where workspace creation used the global process group instead of the
    TP-scoped group, causing NCCL errors in DP+TP configurations.
    """
    world_size = 4
    torch.multiprocessing.spawn(
        all_reduce_fusion_pass_on_test_model_dp_tp,
        args=(world_size,),
        nprocs=world_size,
    )


def all_reduce_fusion_pass_on_test_model_dp_tp(
    local_rank: int,
    world_size: int,
):
    """Per-process body spawned by test_all_reduce_fusion_pass_dp_tp.

    Initializes a 4-rank distributed environment split as DP=2 x TP=2,
    compiles a small model through TestBackend with the allreduce+RMSNorm
    fusion pass enabled, and asserts the expected number of pattern matches.

    Args:
        local_rank: Rank of this spawned process (also used as the CUDA
            device index and the global rank).
        world_size: Total number of processes (4 for this test).
    """
    tp_size = 2
    dtype = torch.bfloat16
    hidden_size = 64
    batch_size = 8
    seq_len = 8

    set_random_seed(0)

    # Pin this process to its own GPU before any collective setup.
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)

    # torch.distributed reads these env vars during initialization.
    update_environment_variables(
        {
            "RANK": str(local_rank),
            "LOCAL_RANK": str(local_rank),
            "WORLD_SIZE": str(world_size),
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": "12346",
        }
    )

    init_distributed_environment()

    # Create vllm_config with dp_size=2 BEFORE initialize_model_parallel,
    # because initialize_model_parallel reads data_parallel_size from
    # the current vllm config.
    vllm_config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE, custom_ops=[]
        )
    )
    vllm_config.compilation_config.pass_config = PassConfig(
        fuse_allreduce_rms=True, eliminate_noops=True
    )
    vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
    vllm_config.parallel_config.rank = local_rank
    vllm_config.parallel_config.data_parallel_size = 2

    model_name = "RedHatAI/Llama-3.2-1B-Instruct-FP8"
    vllm_config.model_config = ModelConfig(
        model=model_name, trust_remote_code=True, dtype=dtype, seed=42
    )
    with set_current_vllm_config(vllm_config):
        # With dp=2, tp=2 on 4 ranks, this creates:
        # TP groups: [0,1], [2,3]
        # DP groups: [0,2], [1,3]
        initialize_model_parallel(tensor_model_parallel_size=tp_size)

        all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
        noop_pass = NoOpEliminationPass(vllm_config)
        func_pass = FixFunctionalizationPass(vllm_config)
        cleanup_pass = PostCleanupPass(vllm_config)

        backend = TestBackend(
            noop_pass, all_reduce_fusion_pass, func_pass, cleanup_pass
        )

        token_num = batch_size * seq_len
        model = TestAllReduceRMSNormModel(hidden_size, token_num)

        hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)

        # Compiling through TestBackend runs the pass pipeline and records
        # the FX graphs before/after for the op checks below.
        compiled_model = torch.compile(model, backend=backend)
        compiled_model(hidden_states)

        # The test model is expected to produce exactly 4 fusable
        # allreduce+RMSNorm patterns.
        assert all_reduce_fusion_pass.matched_count == 4, (
            f"{all_reduce_fusion_pass.matched_count=}"
        )
        backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
        backend.check_after_ops(model.ops_in_model_after())
        # Drop the pass explicitly so its workspace is released before
        # process teardown.
        del all_reduce_fusion_pass
44 changes: 42 additions & 2 deletions vllm/compilation/passes/fusion/allreduce_rms_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,41 @@
_flashinfer_comm, "create_allreduce_fusion_workspace"
):
flashinfer_comm = _flashinfer_comm
from flashinfer.comm.mnnvl import TorchDistBackend

class _TPCommBackend(TorchDistBackend):
    """CommBackend scoped to the TP process group.

    Fixes two flashinfer issues:
    1. TorchDistBackend.bcast passes a group-local root as a
       global rank to broadcast_object_list. We use group_src
       instead so the root is resolved within the given group.
    2. IPC socket opIds collide across TP groups because
       random.randint produces identical values under vllm's
       deterministic seeding. We offset the opId by the
       global rank of the group root.
    """

    def __init__(self, group):
        super().__init__(group=group)
        # Global rank of this group's rank-0 member; used below to make
        # broadcast opIds unique per TP group.
        self._global_root = self._dist.get_global_rank(group, 0)

    def bcast(self, data, root=0):
        object_list = [data]
        # group_src treats `root` as a rank within self._group (fix #1).
        self._dist.broadcast_object_list(
            object_list, group_src=root, group=self._group
        )
        result = object_list[0]
        # Offset opId by global root rank so each TP group
        # gets a unique IPC socket path. Only opIds (int)
        # flow through bcast in the TRTLLM backend path (fix #2).
        if isinstance(result, int):
            result += self._global_root
        return result

except ImportError:
pass

logger = init_logger(__name__)

if hasattr(torch.ops._C, "scaled_fp4_quant"):
STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default

Expand Down Expand Up @@ -687,6 +717,15 @@ def __init__(self, config: VllmConfig) -> None:
if self.tp_size <= 1:
logger.warning_once("AllReduce fusion pass is disabled for tp_size <= 1.")
return
if config.parallel_config.data_parallel_size > 1:
# flashinfer uses device_idx=tp_rank in SymmDeviceMemory,
# which maps to the wrong GPU for DP groups > 0.
# See: https://github.com/vllm-project/vllm/issues/34401
logger.warning_once(
"AllReduce fusion pass is disabled for DP+TP due to "
"a flashinfer device assignment limitation."
)
return
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="all_reduce_fusion_pass"
)
Expand Down Expand Up @@ -736,6 +775,7 @@ def __init__(self, config: VllmConfig) -> None:
max_token_num=self.max_token_num,
hidden_dim=self.hidden_dim,
dtype=self.model_dtype,
comm_backend=_TPCommBackend(group=self.group),
)

global _FI_WORKSPACE
Expand Down