Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 243 additions & 0 deletions benchmarks/python/benchmark_overlap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

import os
import pytest
import torch
import nvfuser
from nvfuser import FusionDefinition, CommunicatorBackend
from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype
from .core import BENCHMARK_CONFIG, clear_l2_cache


class CUDAEventTimer:
    """CUDA-event based timer compatible with pytest-benchmark's timer protocol.

    pytest-benchmark calls the timer once before and once after the measured
    call and subtracts the two return values. This timer records a CUDA start
    event on the first call (returning 0.0) and, on the second call, records
    the end event, synchronizes, and returns the GPU-measured elapsed time in
    seconds — more accurate for GPU work than CPU wall-clock timing.
    """

    def __init__(self):
        # enable_timing is required for Event.elapsed_time() to work.
        self.start_event = torch.cuda.Event(enable_timing=True)
        self.end_event = torch.cuda.Event(enable_timing=True)
        self.is_running = False

    def __call__(self):
        """Alternate between starting and stopping a measurement.

        Returns:
            float: 0.0 when starting; elapsed time in seconds when stopping.
        """
        if not self.is_running:
            # First call of the pair: mark the start point on the GPU.
            self.start_event.record()
            self.is_running = True
            return 0.0
        # Second call of the pair: mark the end point and read the delta.
        self.end_event.record()
        torch.cuda.synchronize()
        elapsed_ms = self.start_event.elapsed_time(self.end_event)
        self.is_running = False
        return elapsed_ms / 1000.0  # elapsed_time() reports milliseconds

    def cleanup(self):
        """Close out a dangling measurement, if any, and synchronize CUDA."""
        if self.is_running:
            self.end_event.record()
            torch.cuda.synchronize()
            self.is_running = False


def benchmark_cuda_events_pedantic(
    benchmark, benchmark_fn, inputs, rounds, warmup_rounds, timer_precision=1e-6
):
    """Run ``benchmark_fn`` under pytest-benchmark's pedantic mode with a
    CUDA-event timer.

    Args:
        benchmark: pytest-benchmark fixture
        benchmark_fn: Function to benchmark; invoked as ``benchmark_fn(*inputs)``
        inputs: List of inputs to pass to benchmark_fn
        rounds: Number of measured rounds to run
        warmup_rounds: Number of unmeasured warmup rounds
        timer_precision: Timer resolution in seconds registered with
            pytest-benchmark (CUDA events resolve to well under a
            microsecond, so 1e-6 is a conservative bound)
    """

    def setup():
        # Flush L2 before every round so each round starts from a cold cache.
        clear_l2_cache()
        return inputs, {}

    def wrapped_fn(*args):
        # pedantic() re-unpacks setup()'s returned args, so args[0] is the
        # caller-supplied ``inputs`` element, which is unpacked once more.
        benchmark_fn(*args[0])
        return None

    # Install the custom CUDA-event timer and register its precision
    # explicitly. Without a registered precision, pytest-benchmark calibrates
    # the timer by calling it repeatedly in quick succession; since this timer
    # alternates start/stop CUDA-event records, calibration would produce
    # invalid results.
    # https://github.com/ionelmc/pytest-benchmark/blob/728752d2976ef53fde7e40beb3e55f09cf4d4736/src/pytest_benchmark/timers.py#L15
    timer = CUDAEventTimer()
    benchmark._timer = timer
    benchmark._precisions[timer] = timer_precision

    benchmark.pedantic(
        wrapped_fn,
        setup=setup,
        rounds=rounds,
        warmup_rounds=warmup_rounds,
        iterations=1,
    )


class OverlapAGMatmulStreamOutermost(FusionDefinition):
    """Fusion definition overlapping an all-gather with a linear layer.

    The input is laid out as [s, d, m/(s*d), k]; the device axis (dim 1) is
    sharded across the mesh and the output's outermost dimension is
    stream-parallelized so communication can overlap with the matmul.
    """

    def __init__(self, m, k, n, s, num_devices, communication_backend, dtype):
        super().__init__(
            use_multidevice_executor=True, backend_type=communication_backend
        )
        self.m = m
        self.k = k
        self.n = n
        self.s = s
        self._num_devices = num_devices
        self.dtype = dtype

    def definition(self) -> None:
        nv_dtype = torch_dtype_to_nvfuser_dtype(self.dtype)
        rows_per_shard = self.m // (self.s * self._num_devices)
        # x: [s, d, m/(s*d), k]
        self.x = self.define_tensor(
            shape=[self.s, self._num_devices, rows_per_shard, self.k],
            contiguity=True,
            dtype=nv_dtype,
        )
        # weight: [n, k]
        self.weight = self.define_tensor(
            shape=[self.n, self.k],
            contiguity=True,
            dtype=nv_dtype,
        )
        # bias: [n]
        self.bias = self.define_tensor(
            shape=[self.n], contiguity=True, dtype=nv_dtype
        )
        self.out = self.ops.linear(self.x, self.weight, self.bias)
        self.add_output(self.out)

    def multidevice_schedule(self):
        # Place all tensors on the full mesh, shard x on its device axis,
        # and stream-parallelize the output's outermost dimension.
        mesh = nvfuser.DeviceMesh(range(self._num_devices))
        for tensor in (self.x, self.weight, self.bias, self.out):
            self.sched._set_device_mesh(tensor, mesh)
        self.sched.parallelize(self.x, 1, nvfuser.ParallelType.mesh_x)
        self.sched.parallelize(self.out, 0, nvfuser.ParallelType.stream)


class MultideviceSettings:
    """Thin wrapper around the nvFuser communicator plus sharding helpers."""

    def __init__(self):
        self._communicator = nvfuser.Communicator.instance()
        # Seed once so every rank generates identical unsharded data.
        torch.manual_seed(0)

    @property
    def communicator(self):
        """The nvFuser communicator backing these settings."""
        return self._communicator

    @property
    def size(self):
        """Number of ranks reported by the communicator."""
        return self._communicator.size()

    @property
    def rank(self):
        """This process's global rank."""
        return self._communicator.rank()

    @property
    def local_size(self):
        """Number of ranks local to this node."""
        return self._communicator.local_size()

    @property
    def local_rank(self):
        """This process's rank within its node."""
        return self._communicator.local_rank()

    def shard_tensor(
        self, t: torch.Tensor, dim: int, mesh: nvfuser.DeviceMesh
    ) -> torch.Tensor:
        """Return this rank's shard of ``t`` along ``dim``, moved to its GPU."""
        assert t.is_cpu, (
            "This is not strictly required but it's a general good practice "
            "for unit tests to create unsharded data on CPU to reduce GPU "
            "memory footprint."
        )
        shard = mesh.shard_tensor(t, dim, self.rank)
        return shard.cuda(self.rank)


@pytest.fixture
def multidevice_settings():
    """Provide a fresh :class:`MultideviceSettings` for each test (function-scoped)."""
    return MultideviceSettings()


@pytest.mark.mpi
@pytest.mark.parametrize(
    "backend_type", [CommunicatorBackend.ucc, CommunicatorBackend.nccl]
)
@pytest.mark.parametrize("s", [1, 8])
@pytest.mark.parametrize("m", [2**16])
@pytest.mark.parametrize("k", [2**12, 2**16])
@pytest.mark.parametrize("n", [2**10])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_overlap_allgather_matmul_stream_outermost(
    benchmark,
    multidevice_settings,
    backend_type,
    s,
    m,
    k,
    n,
    dtype,
    validate_output=False,
):
    """Benchmark overlapping all-gather with matmul using stream parallelism.

    Args:
        benchmark: pytest-benchmark fixture
        multidevice_settings: Settings for multi-device execution
        backend_type: Communication backend to use
        s: Number of streams
        m: Matrix dimension m
        k: Matrix dimension k
        n: Matrix dimension n
        dtype: Data type for computation
        validate_output: Whether to validate output against a local reference
    """
    nvfuser.FusionCache.reset()

    d = multidevice_settings.size
    assert m % (s * d) == 0
    # Restrict UCC's basic collective layer to the NCCL transport.
    os.environ["UCC_CL_BASIC_TLS"] = "nccl"
    torch.cuda.set_device(multidevice_settings.local_rank)

    # Create unsharded input on CPU, then shard x along the device axis.
    x_unsharded = torch.testing.make_tensor(
        s, d, m // (s * d), k, dtype=dtype, device="cpu"
    )
    x = multidevice_settings.shard_tensor(
        x_unsharded, 1, nvfuser.DeviceMesh(range(multidevice_settings.size))
    )
    weight = torch.testing.make_tensor(n, k, dtype=dtype, device="cuda")
    bias = torch.testing.make_tensor(n, dtype=dtype, device="cuda")
    inputs = [x, weight, bias]

    # Create fusion definition
    fd = OverlapAGMatmulStreamOutermost(m, k, n, s, d, backend_type, dtype)

    if validate_output:
        # Pass the flat tensor list, consistent with how benchmark_fn calls
        # fd.execute below (previously this passed [inputs], wrapping the
        # tensor list in an extra list).
        outputs, _ = fd.execute(inputs)
        out = outputs[0].cpu()
        assert out.dtype == dtype
        assert out.shape == torch.Size([s, d, m // (s * d), n])
        out_ref = torch.nn.functional.linear(x_unsharded, weight.cpu(), bias.cpu())
        # NOTE(review): rtol=inf makes the relative term vacuous, so only
        # atol (plus shape/dtype/NaN checks) constrains this comparison —
        # confirm the loose tolerance is intentional.
        torch.testing.assert_close(out, out_ref, rtol=float("inf"), atol=1e-1)

    def benchmark_fn(*args):
        outputs, _ = fd.execute(args)
        return outputs[0]

    benchmark_cuda_events_pedantic(
        benchmark,
        benchmark_fn,
        [inputs],
        warmup_rounds=BENCHMARK_CONFIG["warmup_rounds"],
        rounds=BENCHMARK_CONFIG["rounds"],
    )