Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions benchmarks/kernels/benchmark_fused_collective.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,18 +408,18 @@ def run_benchmarks(

rms_eps = 1e-6
results = {}
vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
use_oneshot_options = [False] if no_oneshot else [True, False]

# Create RMSNorm and QuantFP8 layers once for native benchmarks

if "none" in quant_modes:
# Standard AllReduce + RMSNorm
# Re-create VllmFusedAllreduce per config so CustomOp binds the
# correct forward method (native vs custom kernel).
for custom_op in ["-rms_norm", "+rms_norm"]:
with set_current_vllm_config(
VllmConfig(compilation_config=CompilationConfig(custom_ops=[custom_op]))
):
try:
vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
suffix = (
"_custom_rms_norm" if "+" in custom_op else "_native_rms_norm"
)
Expand All @@ -438,6 +438,7 @@ def run_benchmarks(
VllmConfig(compilation_config=CompilationConfig(custom_ops=["-rms_norm"]))
):
try:
vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
standard_allreduce_rmsnorm_native_compiled = torch.compile(
vllm_fused_allreduce.allreduce_rmsnorm,
fullgraph=True,
Expand Down Expand Up @@ -482,7 +483,7 @@ def run_benchmarks(
"_custom_rms_norm" if "+" in rms_norm_custom_op else "_native_rms_norm"
)
for quant_fp8_custom_op in ["-quant_fp8", "+quant_fp8"]:
suffix += (
op_suffix = suffix + (
"_custom_quant_fp8"
if "+" in quant_fp8_custom_op
else "_native_quant_fp8"
Expand All @@ -495,16 +496,17 @@ def run_benchmarks(
)
):
try:
vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
time_ms = benchmark_operation(
vllm_fused_allreduce.allreduce_rmsnorm_fp8_quant,
input_tensor,
residual=residual,
scale_factor=scale_fp8,
)
results[f"standard_allreduce{suffix}"] = time_ms
results[f"standard_allreduce{op_suffix}"] = time_ms
except Exception as e:
logger.error("Standard AllReduce+RMSNorm+FP8 failed: %s", e)
results[f"standard_allreduce{suffix}"] = float("inf")
results[f"standard_allreduce{op_suffix}"] = float("inf")

# Standard AllReduce + RMSNorm + FP8 Quant Native Compiled
with set_current_vllm_config(
Expand All @@ -515,6 +517,7 @@ def run_benchmarks(
)
):
try:
vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
standard_allreduce_rmsnorm_fp8_quant_native_compiled = torch.compile(
vllm_fused_allreduce.allreduce_rmsnorm_fp8_quant,
fullgraph=True,
Expand Down Expand Up @@ -580,6 +583,7 @@ def run_benchmarks(
)
):
try:
vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
time_ms = benchmark_operation(
vllm_fused_allreduce.allreduce_rmsnorm_fp4_quant,
input_tensor,
Expand All @@ -598,6 +602,7 @@ def run_benchmarks(
VllmConfig(compilation_config=CompilationConfig(custom_ops=["-rms_norm"]))
):
try:
vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
standard_allreduce_rmsnorm_fp4_quant_native_compiled = torch.compile(
vllm_fused_allreduce.allreduce_rmsnorm_fp4_quant,
fullgraph=True,
Expand Down