
Benchmark for distributed matmul with overlap #4326

Merged
samnordmann merged 1 commit into main from first_overlap_benchmark
May 20, 2025
Conversation

@samnordmann
Collaborator

@samnordmann samnordmann commented Apr 27, 2025

Add a benchmark unit test for AllGather+Matmul with and without overlap. The test uses the pytest framework and defines a custom Timer based on CUDA events.

The results show significant performance benefits of pipelining with UCC for large matrices. Results obtained on a DGX node with 8× H100 GPUs:
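The benefit of pipelining can be seen from a simple cost model. The following is a minimal pure-Python sketch of the idea only, not the nvFuser implementation benchmarked here: with `s` pipeline stages, the all-gather for chunk `i+1` can run while chunk `i` is being multiplied, so only the first chunk's communication is exposed. The function names and the assumption of perfect comm/compute overlap are illustrative.

```python
# Toy cost model for pipelined allgather+matmul (illustrative only; this is
# NOT the nvFuser implementation). Assumes communication and compute of
# different chunks overlap perfectly.

def pipelined_time(s, comm_per_chunk, compute_per_chunk):
    """Total time when the allgather of chunk i+1 overlaps the matmul of chunk i."""
    # The first chunk's communication cannot be hidden; each subsequent step
    # costs the max of (next chunk's comm, current chunk's compute).
    total = comm_per_chunk
    for i in range(s):
        next_comm = comm_per_chunk if i < s - 1 else 0.0
        total += max(next_comm, compute_per_chunk)
    return total

def sequential_time(s, comm_per_chunk, compute_per_chunk):
    """Total time without overlap: each chunk's comm then compute, serialized."""
    return s * (comm_per_chunk + compute_per_chunk)

# Example: 8 ms total communication and 8 ms total compute, split into s=8 chunks.
total_comm, total_compute, s = 8.0, 8.0, 8
overlapped = pipelined_time(s, total_comm / s, total_compute / s)   # 9.0 ms
unpipelined = sequential_time(1, total_comm, total_compute)         # 16.0 ms
```

Under this model, large `s` hides almost all communication behind compute, which is consistent with the speedups seen in the table below for the large-`k` cases.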

--------------------------------------------------------------------------------------------------------------------------------------------- benchmark: 16 tests ---------------------------------------------------------------------------------------------------------------------------------------------
Name (time in ms)                                                                                                                                 Min                Max               Mean             StdDev             Median               IQR            Outliers       OPS            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_overlap_allgather_matmul_stream_outermost[dtype=torch.float16-n=1024-k=4096-m=65536-s=1-backend_type=<CommunicatorBackend.nccl: 0>]       2.7351 (1.0)       2.9785 (1.0)       2.8138 (1.0)       0.0803 (6.58)      2.7779 (1.0)      0.1096 (5.56)          2;0  355.3866 (1.0)          10           1
test_overlap_allgather_matmul_stream_outermost[dtype=torch.float16-n=1024-k=4096-m=65536-s=1-backend_type=<CommunicatorBackend.ucc: 1>]        3.0713 (1.12)      3.4526 (1.16)      3.1341 (1.11)      0.1131 (9.28)      3.1061 (1.12)     0.0199 (1.01)          1;1  319.0726 (0.90)         10           1
test_overlap_allgather_matmul_stream_outermost[dtype=torch.float16-n=1024-k=4096-m=65536-s=8-backend_type=<CommunicatorBackend.nccl: 0>]       3.1446 (1.15)      5.8581 (1.97)      3.4676 (1.23)      0.8425 (69.11)     3.1833 (1.15)     0.0784 (3.98)          1;2  288.3866 (0.81)         10           1
test_overlap_allgather_matmul_stream_outermost[dtype=torch.float16-n=1024-k=4096-m=65536-s=8-backend_type=<CommunicatorBackend.ucc: 1>]        3.2604 (1.19)      3.3357 (1.12)      3.2910 (1.17)      0.0235 (1.93)      3.2854 (1.18)     0.0339 (1.72)          2;0  303.8600 (0.86)         10           1
test_overlap_allgather_matmul_stream_outermost[dtype=torch.float32-n=1024-k=4096-m=65536-s=8-backend_type=<CommunicatorBackend.ucc: 1>]        4.1977 (1.53)      4.2335 (1.42)      4.2148 (1.50)      0.0122 (1.0)       4.2161 (1.52)     0.0225 (1.14)          2;0  237.2606 (0.67)         10           1
test_overlap_allgather_matmul_stream_outermost[dtype=torch.float32-n=1024-k=4096-m=65536-s=1-backend_type=<CommunicatorBackend.nccl: 0>]       4.7256 (1.73)      4.7770 (1.60)      4.7476 (1.69)      0.0191 (1.56)      4.7482 (1.71)     0.0383 (1.94)          5;0  210.6325 (0.59)         10           1
test_overlap_allgather_matmul_stream_outermost[dtype=torch.float32-n=1024-k=4096-m=65536-s=1-backend_type=<CommunicatorBackend.ucc: 1>]        4.7307 (1.73)     41.2299 (13.84)     8.3965 (2.98)     11.5365 (946.41)    4.7438 (1.71)     0.0197 (1.0)           1;2  119.0973 (0.34)         10           1
test_overlap_allgather_matmul_stream_outermost[dtype=torch.float32-n=1024-k=4096-m=65536-s=8-backend_type=<CommunicatorBackend.nccl: 0>]       5.1204 (1.87)      5.4672 (1.84)      5.3154 (1.89)      0.1009 (8.28)      5.3345 (1.92)     0.1097 (5.56)          4;0  188.1342 (0.53)         10           1
test_overlap_allgather_matmul_stream_outermost[dtype=torch.float16-n=1024-k=65536-m=65536-s=8-backend_type=<CommunicatorBackend.ucc: 1>]      25.0230 (9.15)     25.8178 (8.67)     25.4180 (9.03)      0.3072 (25.20)    25.4278 (9.15)     0.5594 (28.38)         2;0   39.3422 (0.11)         10           1
test_overlap_allgather_matmul_stream_outermost[dtype=torch.float16-n=1024-k=65536-m=65536-s=8-backend_type=<CommunicatorBackend.nccl: 0>]     33.9819 (12.42)    35.5789 (11.95)    35.2329 (12.52)     0.4613 (37.84)    35.3931 (12.74)    0.2792 (14.16)         1;1   28.3826 (0.08)         10           1
test_overlap_allgather_matmul_stream_outermost[dtype=torch.float16-n=1024-k=65536-m=65536-s=1-backend_type=<CommunicatorBackend.ucc: 1>]      34.2042 (12.51)    34.8641 (11.71)    34.4020 (12.23)     0.1912 (15.69)    34.3466 (12.36)    0.2033 (10.31)         2;1   29.0681 (0.08)         10           1
test_overlap_allgather_matmul_stream_outermost[dtype=torch.float16-n=1024-k=65536-m=65536-s=1-backend_type=<CommunicatorBackend.nccl: 0>]     34.2332 (12.52)    34.5607 (11.60)    34.3961 (12.22)     0.1131 (9.28)     34.3828 (12.38)    0.1795 (9.10)          4;0   29.0731 (0.08)         10           1
test_overlap_allgather_matmul_stream_outermost[dtype=torch.float32-n=1024-k=65536-m=65536-s=8-backend_type=<CommunicatorBackend.ucc: 1>]      47.2971 (17.29)    47.6112 (15.98)    47.4518 (16.86)     0.0900 (7.38)     47.4569 (17.08)    0.1091 (5.53)          3;0   21.0740 (0.06)         10           1
test_overlap_allgather_matmul_stream_outermost[dtype=torch.float32-n=1024-k=65536-m=65536-s=8-backend_type=<CommunicatorBackend.nccl: 0>]     62.3246 (22.79)    68.5419 (23.01)    67.2244 (23.89)     2.1432 (175.82)   68.4183 (24.63)    2.9840 (151.38)        1;0   14.8755 (0.04)         10           1
test_overlap_allgather_matmul_stream_outermost[dtype=torch.float32-n=1024-k=65536-m=65536-s=1-backend_type=<CommunicatorBackend.ucc: 1>]      65.0065 (23.77)    65.7736 (22.08)    65.4415 (23.26)     0.2386 (19.58)    65.4146 (23.55)    0.3241 (16.44)         3;0   15.2808 (0.04)         10           1
test_overlap_allgather_matmul_stream_outermost[dtype=torch.float32-n=1024-k=65536-m=65536-s=1-backend_type=<CommunicatorBackend.nccl: 0>]     65.0892 (23.80)    65.7982 (22.09)    65.3787 (23.23)     0.2185 (17.92)    65.3707 (23.53)    0.2645 (13.42)         3;0   15.2955 (0.04)         10           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

@samnordmann
Collaborator Author

!test

@github-actions

github-actions bot commented Apr 27, 2025

Review updated until commit c798fe2

Description

  • Added benchmark for distributed matrix multiplication with overlap

  • Implemented custom CUDA event-based timer

  • Defined fusion for overlapping all-gather with matrix multiplication

  • Included parameterized tests for different backends and matrix dimensions


Changes walkthrough 📝

Relevant files
Enhancement
benchmark_overlap.py
Add benchmark for distributed matmul with overlap               

benchmarks/python/benchmark_overlap.py

  • Introduced CUDAEventTimer class for accurate GPU timing
  • Added benchmark_cuda_events_pedantic function to use custom timer
  • Defined OverlapAGMatmulStreamOutermost fusion for overlapping
    operations
  • Created MultideviceSettings class for multi-device execution settings
  • Implemented parameterized test
    test_overlap_allgather_matmul_stream_outermost
  • +243/-0 

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Environment Variable

    The environment variable UCC_CL_BASIC_TLS is set to nccl within the test function. This might not be ideal for all environments and could lead to unexpected behavior if the test is run in an environment where UCC is not expected to use NCCL.

    os.environ["UCC_CL_BASIC_TLS"] = "nccl"
    Output Validation

    The output validation is only performed if validate_output is True, which is not the default. This might lead to undetected issues in the benchmarked function.

    if validate_output:
        outputs, _ = fd.execute([inputs])
        out = outputs[0].cpu()
        assert out.dtype == dtype
        assert out.shape == torch.Size([s, d, m // (s * d), n])
        out_ref = torch.nn.functional.linear(x_unsharded, weight.cpu(), bias.cpu())
        torch.testing.assert_close(out, out_ref, rtol=float("inf"), atol=1e-1)
    Device Synchronization

    The CUDAEventTimer class does not handle device synchronization in the __call__ method, which could lead to incorrect timing if the GPU is not properly synchronized.

    def __call__(self):
        """Record timing events and compute elapsed time.
    
        Returns:
            float: Elapsed time in seconds
        """
        if self.is_running:
            self.end_event.record()
            torch.cuda.synchronize()
            elapsed_ms = self.start_event.elapsed_time(self.end_event)
            self.is_running = False
            return elapsed_ms / 1000.0  # Convert ms to seconds
        else:
            self.start_event.record()
            self.is_running = True
            return 0.0
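The toggle design above works because pytest-benchmark computes elapsed time as the difference between two timer calls. A CPU analogue (a hypothetical sketch using `time.perf_counter`, not the PR's `CUDAEventTimer`) shows the protocol: the first call returns a 0.0 baseline, the second returns the elapsed time, so differencing the two calls yields the right result.

```python
# CPU analogue of the toggle-style timer (hypothetical sketch). pytest-benchmark
# evaluates "timer() at end - timer() at start", so returning 0.0 when starting
# and the elapsed seconds when stopping produces the correct difference.
import time

class ToggleTimer:
    def __init__(self):
        self.is_running = False
        self._start = 0.0

    def __call__(self):
        if self.is_running:
            # Stop: return elapsed seconds since the start call.
            self.is_running = False
            return time.perf_counter() - self._start
        # Start: record the baseline and return zero.
        self._start = time.perf_counter()
        self.is_running = True
        return 0.0

timer = ToggleTimer()
t0 = timer()          # 0.0 (start)
time.sleep(0.01)      # stand-in for the benchmarked workload
t1 = timer()          # elapsed seconds (stop)
elapsed = t1 - t0     # how pytest-benchmark differences the two calls
```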

    @samnordmann samnordmann changed the title [WIP] Benchmark for distributed matmul with overlap Benchmark for distributed matmul with overlap Apr 28, 2025
    @samnordmann samnordmann marked this pull request as ready for review April 28, 2025 13:02
    Member

    @nsarka nsarka left a comment

    Can you post the results? The benchmark looks good

    @samnordmann samnordmann force-pushed the first_overlap_benchmark branch 2 times, most recently from 67638f9 to 0de9e90 Compare May 13, 2025 13:27
    @samnordmann
    Collaborator Author

    !test

    @wujingyue wujingyue requested a review from Priya2698 May 13, 2025 16:12
    @wujingyue
    Collaborator

    cc @Priya2698 to make sure this is aligned with how you plan to set up Python benchmarks

    @Priya2698
    Collaborator

    As multidevice benchmarks grow, it would be useful to have them in the dashboard for visualization. See http://nv/eGB for examples.

    We can use pytest-benchmark which has some useful features around storing data, benchmark information (iobytes, configurations), dumping to files, warmup rounds etc. So benchmark_cuda_events can be converted to a Timer class (see https://github.com/NVIDIA/Fuser/blob/main/python/nvfuser/benchmark_utils.py) that essentially functions like a monotonic clock. MultiDeviceBenchmark can be a wrapper around the pytest-benchmark fixture similar to NVFBenchmark.

    This can unify the multidevice benchmarks with the existing utilities. I am not sure though, if pytest-benchmark can interfere with any multidevice calls, such as barrier. I would think it works since this PR already uses pytest.

    @samnordmann
    Collaborator Author

    As multidevice benchmarks grow, it would be useful to have them in the dashboard for visualization. See http://nv/eGB for examples.

    We can use pytest-benchmark which has some useful features around storing data, benchmark information (iobytes, configurations), dumping to files, warmup rounds etc. So benchmark_cuda_events can be converted to a Timer class (see https://github.com/NVIDIA/Fuser/blob/main/python/nvfuser/benchmark_utils.py) that essentially functions like a monotonic clock. MultiDeviceBenchmark can be a wrapper around the pytest-benchmark fixture similar to NVFBenchmark.

    This can unify the multidevice benchmarks with the existing utilities. I am not sure though, if pytest-benchmark can interfere with any multidevice calls, such as barrier. I would think it works since this PR already uses pytest.

    I fully agree with the idea. I started to do as you describe, but then got caught in debugging/exploring some strange performance behavior, so I simplified the structure. Tbh, this PR was not really ready for review :) Let me arrange the PR and ping you again when it is ready

    @samnordmann samnordmann force-pushed the first_overlap_benchmark branch from b1f3e76 to dfe4eaf Compare May 16, 2025 16:05
    @samnordmann
    Collaborator Author

    !test

    @samnordmann samnordmann requested a review from wujingyue May 16, 2025 16:15
    @samnordmann
    Collaborator Author

    @Priya2698 @wujingyue the PR is ready for review. This PR now depends on another PR, #4466, which resolves the strange benchmark results I obtained. The benchmark now uses the pytest structure and measures each iteration separately (as is preferable, as we discussed) through CUDA events.

    @samnordmann samnordmann force-pushed the first_overlap_benchmark branch from e80c834 to c798fe2 Compare May 19, 2025 13:55
    @samnordmann
    Collaborator Author

    I removed the dependency of this PR on the allocation cache #4466
    Somehow I am currently unable to reproduce the allocation issue I was seeing before, and I am now getting good numbers (in the PR description) even without the allocation cache. I would like to merge the benchmark in the meantime and see later if the issue shows up and whether we should add the allocation cache.
    cc @Priya2698 @wujingyue @nsarka

    @samnordmann
    Collaborator Author

    !test

    return None

    # Set our custom CUDA event timer
    benchmark._timer = CUDAEventTimer()
    Collaborator

    Did you run into issues around the precision of the timer?
    If no timer precision is set, pytest-benchmark does its own calibration, which can sometimes produce invalid results:

    # Externally set the precision to avoid timer calibration. Since the timer uses CUDA times,
    # calibration using subsequent timer calls produces invalid results.
    # https://github.com/ionelmc/pytest-benchmark/blob/728752d2976ef53fde7e40beb3e55f09cf4d4736/src/pytest_benchmark/timers.py#L15
    benchmark_fixture._precisions[benchmark_fixture._timer] = precision

    Collaborator Author

    No, I didn't run into any issue with the timer precision. Do you want me to set the precision manually anyway?

    @pytest.mark.parametrize("k", [2**12, 2**16])
    @pytest.mark.parametrize("n", [2**10])
    @pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
    def test_overlap_allgather_matmul_stream_outermost(
    Collaborator

    @Priya2698 Priya2698 May 19, 2025

    Is this similar to

    def test_overlap_allgather_matmul_stream_outermost(
    ?

    If so, we should remove that instance now.

    Collaborator Author

    They are close but slightly different (e.g., the datatype), and I would rather keep them separate because we might not want to test and benchmark the same instances.

    Collaborator

    @Priya2698 Priya2698 May 20, 2025

    They are close but slightly different (e.g., the datatype)

    bfloat16 can be added to the datatypes parameter you already have though

    we might not want to test and benchmark the same instances

    I am not clear, why would this be problematic? That test is also running a benchmark with validation similar to the current test case.

    Collaborator Author

    They are close but slightly different (e.g., the datatype)

    bfloat16 can be added to the datatypes parameter you already have though

    we might not want to test and benchmark the same instances

    I am not clear, why would this be problematic? That test is also running a benchmark with validation similar to the current test case.

    Nothing is problematic. I just think it is useful in general to separate the instance used for validation from the one used for performance. I thought this was the idea behind having two separate folders. We should allow the instances to diverge more in the future. Anyway, nothing crucial for now, so if you think it is important, let me know and I'll remove the test and/or remove the validation in the current benchmark.

    Collaborator

    Got it.
    If you intend to modify the test instance, we can keep it unchanged. In the current state, it is actually benchmarking as well, which seems redundant with the current benchmark addition. Maybe we can make the other one validation only? It can be a separate PR.

    Collaborator Author

    @samnordmann samnordmann May 20, 2025

    I agree we do not need to measure execution time for the "tests". Btw, I wouldn't trust the numbers there, which IIUC represent host wall-clock time. In our case, the numbers from the CI look wrong, so I would consider them not relevant:

    00:00:12 test_overlap_allgather_matmul_stream_outermost[s=1-backend_type=CommunicatorBackend.nccl]       248.8578 (1.0)        686.4527 (1.0)        275.4573 (1.0)       86.5904 (1.0)        254.6739 (1.0)       9.0676 (1.0)           1;3  3,630.3265 (1.0)          25           1
    00:00:12 test_overlap_allgather_matmul_stream_outermost[s=1-backend_type=CommunicatorBackend.ucc]        256.7070 (1.03)       718.1512 (1.05)       284.6673 (1.03)      91.3450 (1.05)       263.6444 (1.04)      9.1677 (1.01)          1;3  3,512.8730 (0.97)         25           1
    00:00:12 test_overlap_allgather_matmul_stream_outermost[s=1-backend_type=CommunicatorBackend.ucc]        262.9953 (1.06)       669.2354 (1.0)        292.9457 (1.06)      80.4101 (1.0)        271.6640 (1.07)     17.0944 (1.43)          1;2  3,413.6025 (0.94)         25           1
    00:00:12 test_overlap_allgather_matmul_stream_outermost[s=8-backend_type=CommunicatorBackend.nccl]     1,227.1553 (4.95)     2,761.1768 (4.13)     1,348.3307 (4.90)     307.3344 (3.82)     1,269.3135 (5.00)     62.8864 (5.27)          2;3    741.6578 (0.20)         25           1
    00:00:12 test_overlap_allgather_matmul_stream_outermost[s=8-backend_type=CommunicatorBackend.ucc]      1,237.6998 (4.99)     1,651.3430 (2.47)     1,275.0503 (4.63)      81.1988 (1.01)     1,257.0452 (4.95)     28.9301 (2.43)          1;2    784.2828 (0.22)         25           1
    

    @Priya2698
    Collaborator

    I removed the dependecy of this PR with respect to the allocation cache #4466 Somehow I am not able right now to reproduce the allocation issue I was seeing before, and I am getting now good numbers (in the PR description) even without the allocation cache. I would like to merge the benchmark in the meantime and see later if the issue shows up and whether we should add the allocation cache. cc @Priya2698 @wujingyue @nsarka

    Got it. How can we monitor if that issue is showing up again?

    @wujingyue wujingyue removed their request for review May 20, 2025 02:39
    @samnordmann
    Collaborator Author

    I removed the dependecy of this PR with respect to the allocation cache #4466 Somehow I am not able right now to reproduce the allocation issue I was seeing before, and I am getting now good numbers (in the PR description) even without the allocation cache. I would like to merge the benchmark in the meantime and see later if the issue shows up and whether we should add the allocation cache. cc @Priya2698 @wujingyue @nsarka

    Got it. How can we monitor if that issue is showing up again?

    By running the benchmark nightly and monitoring that performance remains in an acceptable range. This could be set up in nvFuser's CI, but the plan is also to include this benchmark in my team's benchmark tool and run a nightly consistency check there.

    Collaborator

    @Priya2698 Priya2698 left a comment

    LGTM.

    @samnordmann samnordmann merged commit 48c0931 into main May 20, 2025
    53 checks passed
    @samnordmann samnordmann deleted the first_overlap_benchmark branch May 20, 2025 17:04