Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 88 additions & 25 deletions benchmarks/kernels/benchmark_device_communicators.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
from torch.distributed import ProcessGroup

from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce
from vllm.distributed.device_communicators.flashinfer_all_reduce import (
FlashInferAllReduce,
)
from vllm.distributed.device_communicators.pynccl import (
PyNcclCommunicator,
register_nccl_symmetric_ops,
Expand All @@ -44,7 +47,7 @@
logger = init_logger(__name__)

# Default sequence lengths to benchmark
DEFAULT_SEQUENCE_LENGTHS = [128, 512, 1024, 2048, 4096, 8192]
DEFAULT_SEQUENCE_LENGTHS = [16, 64, 128, 512, 1024, 2048, 4096, 8192]

# Fixed hidden size and dtype for all benchmarks
HIDDEN_SIZE = 8192
Expand Down Expand Up @@ -81,6 +84,7 @@ def __init__(
self.symm_mem_comm = None
self.symm_mem_comm_multimem = None
self.symm_mem_comm_two_shot = None
self.fi_ar_comm = None

self._init_communicators()

Expand Down Expand Up @@ -161,6 +165,22 @@ def _init_communicators(self):
)
self.symm_mem_comm_two_shot = None

try:
self.fi_ar_comm = FlashInferAllReduce(
group=self.cpu_group,
device=self.device,
)
if not self.fi_ar_comm.disabled:
logger.info("Rank %s: FlashInferAllReduce initialized", self.rank)
else:
logger.info("Rank %s: FlashInferAllReduce disabled", self.rank)
self.fi_ar_comm = None
except Exception as e:
logger.warning(
"Rank %s: Failed to initialize FlashInferAllReduce: %s", self.rank, e
)
self.fi_ar_comm = None

def benchmark_allreduce(
self, sequence_length: int, num_warmup: int, num_trials: int
) -> dict[str, float]:
Expand All @@ -180,7 +200,8 @@ def benchmark_allreduce(
lambda t, c=comm: c.custom_all_reduce(t),
lambda t, c=comm: c.should_custom_ar(t),
comm.capture(),
"1stage", # env variable value
{"VLLM_CUSTOM_ALLREDUCE_ALGO": "1stage"},
None, # no destroy function
)
)
# CustomAllreduce two-shot
Expand All @@ -190,7 +211,8 @@ def benchmark_allreduce(
lambda t, c=comm: c.custom_all_reduce(t),
lambda t, c=comm: c.should_custom_ar(t),
comm.capture(),
"2stage", # env variable value
{"VLLM_CUSTOM_ALLREDUCE_ALGO": "2stage"},
None, # no destroy function
)
)

Expand All @@ -202,7 +224,8 @@ def benchmark_allreduce(
lambda t, c=comm: c.all_reduce(t),
lambda t: True, # Always available if initialized
nullcontext(),
None, # no env variable needed
{}, # no env variable needed
None, # no destroy function
)
)
communicators.append(
Expand All @@ -211,7 +234,8 @@ def benchmark_allreduce(
lambda t: torch.ops.vllm.all_reduce_symmetric_with_copy(t),
lambda t: True, # Always available if initialized
nullcontext(),
None, # no env variable needed
{}, # no env variable needed
None, # no destroy function
)
)

Expand All @@ -223,7 +247,8 @@ def benchmark_allreduce(
lambda t, c=comm: c.all_reduce(t),
lambda t, c=comm: c.should_use_symm_mem(t),
nullcontext(),
None, # no env variable needed
{}, # no env variable needed
None, # no destroy function
)
)

Expand All @@ -235,29 +260,67 @@ def benchmark_allreduce(
lambda t, c=comm: c.all_reduce(t),
lambda t, c=comm: c.should_use_symm_mem(t),
nullcontext(),
None, # no env variable needed
{}, # no env variable needed
None, # no destroy function needed
)
)

# Benchmark each communicator
for name, allreduce_fn, should_use_fn, context, env_var in communicators:
# Set environment variable if needed
if env_var is not None:
os.environ["VLLM_CUSTOM_ALLREDUCE_ALGO"] = env_var
else:
# Clear the environment variable to avoid interference
os.environ.pop("VLLM_CUSTOM_ALLREDUCE_ALGO", None)

latency = self.benchmark_allreduce_single(
sequence_length,
allreduce_fn,
should_use_fn,
context,
num_warmup,
num_trials,
if self.fi_ar_comm is not None:
comm = self.fi_ar_comm
communicators.append(
(
"flashinfer_trtllm",
lambda t, c=comm: c.all_reduce(t),
lambda t, c=comm: c.should_use_fi_ar(t),
nullcontext(),
{"VLLM_FLASHINFER_ALLREDUCE_BACKEND": "trtllm"},
lambda c=comm: c.destroy(),
)
)
if latency is not None:
results[name] = latency
communicators.append(
(
"flashinfer_mnnvl",
lambda t, c=comm: c.all_reduce(t),
lambda t, c=comm: c.should_use_fi_ar(t),
nullcontext(),
{"VLLM_FLASHINFER_ALLREDUCE_BACKEND": "mnnvl"},
lambda c=comm: c.destroy(),
)
)

# Benchmark each communicator
for (
name,
allreduce_fn,
should_use_fn,
context,
env_dict,
destroy_fn,
) in communicators:
# Save original values and apply new environment variables
saved_env = {key: os.environ.get(key) for key in env_dict}
for key, value in env_dict.items():
os.environ[key] = value
try:
latency = self.benchmark_allreduce_single(
sequence_length,
allreduce_fn,
should_use_fn,
context,
num_warmup,
num_trials,
)
if latency is not None:
results[name] = latency
finally:
if destroy_fn is not None:
destroy_fn()
# Restore environment variables to their original state
for key, original_value in saved_env.items():
if original_value is None:
os.environ.pop(key, None)
else:
os.environ[key] = original_value

return results

Expand Down
Loading
Loading