
P2p cuda lowering #5259

Merged
samnordmann merged 5 commits into main from p2p_cuda_lowering on Oct 8, 2025

Conversation

@samnordmann
Collaborator

@samnordmann samnordmann commented Sep 29, 2025

Adds a lowering path that generates a p2p ring pipeline backed by our recent CUDA IPC backend. The performance looks great and even beats Transformer Engine for large matrix sizes. For example, for TP columnwise (i.e. AG+Matmul) with m=32, k=16k, n=8k, the throughput (in TFLOPs) of the different implementations reads as follows:

  • Fuser default, with NCCL backend: 560 TFLOPs. This matches the performance of a baseline PyTorch eager implementation
  • Fuser with p2p pipeline and CUDA IPC backend: 678 TFLOPs
  • Transformer Engine: 660 TFLOPs
Screenshot 2025-09-29 at 16 29 42

This was measured using [DDLB](https://github.com/samnordmann/ddlb) and [this Fuser branch](https://github.com/NVIDIA/Fuser/tree/lower_to_cuda_ipc_p2p_rebased), on a single 8xH100 DGX node.

This PR is dependent on:

- #4466. Without the Allocation Cache, a rank might change the allocated buffer across iterations. Besides being a performance issue, this can create a hang if the IPC cache is not hit uniformly across ranks. A better long-term solution would be to use PyTorch's recent symmetric allocator.
- (for performance only) #5325
The test written in the PR expresses a matmul:

C = matmul(A, B),
where
- A [DIDx(d), M/d, K]
- B [K, N]
- C [Stream(d), M/d, N]

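The sharding above can be mimicked with a short reference sketch (plain NumPy with sizes matching the test; the variable names are illustrative, not from the PR):

```python
import numpy as np

# Illustrative sizes matching the test: d = 8 devices, M/d = 128, K = N = 1024.
d, m_per_d, k, n = 8, 128, 1024, 1024
rng = np.random.default_rng(0)

A = rng.standard_normal((d, m_per_d, k))  # A [DIDx(d), M/d, K], one shard per device
B = rng.standard_normal((k, n))           # B [K, N], replicated on every device

# C [Stream(d), M/d, N]: each stream iteration produces one M/d-row slice of C.
C = np.stack([A[i] @ B for i in range(d)])

# Equivalent to the unsharded matmul on the gathered A.
C_ref = (A.reshape(d * m_per_d, k) @ B).reshape(d, m_per_d, n)
assert C.shape == (d, m_per_d, n)
assert np.allclose(C, C_ref)
```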
The generated host program is:

%HostIrContainer { (T0_g___bfloat[ideviceIdx.x0{8}, iS1{128}, iS2{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g___bfloat[iS3{1024}, iS4{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), T2_g___bfloat[iS5{1024}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T3_g___bfloat[istreamIdx6{8}, iS7{128}, iS8{1024}, rS9{1024}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
  T4_g___bfloat[istreamIdx10{8}, iS11{128}, iS12{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T4_g___bfloat[istreamIdx10{8}, iS11{128}, iS12{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=1048576, zero_init=false, resets_to_zero=false)
  T3_g___bfloat[istreamIdx6{8}, iS7{128}, iS8{1024}, rS9{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g___bfloat[istreamIdx6{8}, iS7{128}, iS8{1024}, rS9{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=1048576, zero_init=false, resets_to_zero=false)
  GetCurrentStream into Stream 0
  FOR streamIdx in istreamIdx10{8}:
    SetCurrentStream to Stream ( streamIdx % numberOfStreams )
    Synchronize Stream 0
  FOR streamIdx in istreamIdx10{8}:
    SetCurrentStream to Stream ( streamIdx % numberOfStreams )
    T6_l___bfloat[iS15{128}, iS16{1024}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T4_g___bfloat[istreamIdx10{8}, iS11{128}, iS12{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = istreamIdx10{8}, index = i84 )
    IF Manual ( ( ( 8 + ( rank - streamIdx ) ) % 8 ) == rank ):
      T5_l___bfloat[iS13{128}, iS14{1024}] (DeviceMesh{0 1 2 3 4 5 6 7})
         = HirAliasSelect( T0_g___bfloat[ideviceIdx.x0{8}, iS1{128}, iS2{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = ideviceIdx.x0{8}, index = 0 )
      T6_l___bfloat[iS15{128}, iS16{1024}] (DeviceMesh{0 1 2 3 4 5 6 7})
         = Set( T5_l___bfloat[iS13{128}, iS14{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), cache_op=Streaming )
    ELSE:
      ShareMemHandles(P2PCommunication 37 (type=recv, buffer=T6_l___bfloat[iS15{128}, iS16{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), peer=i84, backend=CUDA), P2PCommunication 38 (type=send, buffer=T0_g___bfloat[ideviceIdx.x0{8}, iS1{128}, iS2{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), peer=i90, backend=CUDA),
      P2PCommunication 38 (type=send, buffer=T0_g___bfloat[ideviceIdx.x0{8}, iS1{128}, iS2{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), peer=i90, backend=CUDA)
      P2PCommunication 37 (type=recv, buffer=T6_l___bfloat[iS15{128}, iS16{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), peer=i84, backend=CUDA)
      Wait Communication 38
      Wait Communication 37
    T7_l___bfloat[iS17{128}, iS18{1024}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T4_g___bfloat[istreamIdx10{8}, iS11{128}, iS12{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = istreamIdx10{8}, index = i107 )
    T8_l___bfloat[iS19{128}, iS20{1024}, rS21{1024}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T3_g___bfloat[istreamIdx6{8}, iS7{128}, iS8{1024}, rS9{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = istreamIdx6{8}, index = i107 )
    T8_l___bfloat[iS19{128}, iS20{1024}, rS21{1024}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = linear(T7_l___bfloat[iS17{128}, iS18{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}),
                T1_g___bfloat[iS3{1024}, iS4{1024}] (DeviceMesh{0 1 2 3 4 5 6 7})      ,
          T2_g___bfloat[iS5{1024}] (DeviceMesh{0 1 2 3 4 5 6 7})      )
    SetCurrentStream to Stream 0
    Synchronize Stream ( streamIdx % numberOfStreams )
} // %HostIrContainer
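Reading off the printed indices, the schedule is a classic ring: at step streamIdx, rank r works on shard (rank + streamIdx) % 8 (i84), receiving it from that shard's owner, while sending its resident shard to peer (8 + rank - streamIdx) % 8 (i90); step 0 degenerates to the local copy in the IF branch. A sketch of that schedule in plain Python (the function name is illustrative, not from the PR):

```python
def ring_schedule(rank: int, d: int):
    """Per-step (shard, recv_peer, send_peer) for one rank, mirroring the
    index computations in the printed host program: shard and recv_peer are
    (rank + s) % d (i84), send_peer is (d + rank - s) % d (i90)."""
    steps = []
    for s in range(d):
        shard = (rank + s) % d
        send_peer = (d + rank - s) % d
        if send_peer == rank:  # only at s == 0: consume our own shard locally
            steps.append((shard, None, None))
        else:
            steps.append((shard, shard, send_peer))  # recv the shard from its owner
    return steps

d = 8
# Every rank touches each of the d shards exactly once across the d steps...
for r in range(d):
    assert sorted(step[0] for step in ring_schedule(r, d)) == list(range(d))
# ...and sends/recvs pair up globally: if rank r sends to p at step s,
# then p receives from r at that same step, so no exchange can deadlock.
for s in range(1, d):
    for r in range(d):
        _, _, send_peer = ring_schedule(r, d)[s]
        assert ring_schedule(send_peer, d)[s][1] == r
```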

@github-actions

github-actions bot commented Sep 29, 2025

Review updated until commit 0d7d4ce

Description

  • Add CUDA IPC backend support for P2P communication

  • Enhance stream parallel type with dynamic index handling

  • Introduce communicator backend dispatch in lowering logic

  • Support CUDA backend in multi-device overlap tests


Changes walkthrough 📝

Relevant files

Enhancement

- `csrc/host_ir/lower.cpp`: Pass parameters to stream parallel lowering
  • Pass HostIrLowerParams to StreamParallelType pass
  • Enable parameterized communication backend in lowering
  • +1/-1
- `csrc/host_ir/pass/stream_parallel_type.cpp`: Add backend-aware P2P communication lowering
  • Modify processForLoopBodies to accept communicator backend
  • Implement ring-style indexing for CUDA IPC backend
  • Add backend-specific P2P communication handling
  • Use ShareMemHandles for CUDA backend instead of coalescing
  • +71/-28
- `python/python_direct/enum.cpp`: Expose CUDA communicator backend in Python
  • Add CommunicatorBackend::kCuda enum value
  • Expose CUDA backend to Python interface
  • +2/-1
- `csrc/host_ir/pass/stream_parallel_type.h`: Parameterize StreamParallelType with lowering params
  • Add HostIrLowerParams member to StreamParallelType
  • Update constructor to accept lowering parameters
  • +8/-0

Tests

- `tests/python/multidevice/test_overlap.py`: Add CUDA backend to overlap tests
  • Extend test to cover CUDA backend
  • Parameterize tests with nccl and cuda backends
  • +3/-1

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Possible Issue

The indexing logic for tensor slicing in the p2p ring pipeline may incorrectly use `for_loop->stop()` as the modulus base when computing `send_peer` and `tensor_index`; the modulus should likely be the number of devices in the communicator group rather than the loop bound. If the loop bound does not match the number of devices, this could lead to incorrect peer indexing and broken communication patterns.

```cpp
auto tensor_index = communicator_backend == CommunicatorBackend::kCuda
    ? mod(add(my_device_id, for_loop->index()), for_loop->stop())
    : for_loop->index();
```
Performance Issue

The use of `mod(add(for_loop->stop(), sub(my_device_id, for_loop->index())), for_loop->stop())` for computing `send_peer` may result in redundant modulo arithmetic. A simpler expression, `mod(my_device_id - for_loop->index(), num_devices)`, could be considered, assuming the correct device count is used.

```cpp
auto send_peer = (communicator_backend == CommunicatorBackend::kCuda)
    ? mod(add(for_loop->stop(), sub(my_device_id, for_loop->index())),
          for_loop->stop())
    : for_loop->index();
```
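One reason the guarded form may be deliberate: in C++, the built-in `%` can return a negative result when the dividend is negative (e.g. `(-3) % 8 == -3`), so adding the modulus first keeps the operand non-negative. In a language like Python, whose `%` always yields a result in `[0, d)`, the two forms coincide, as this quick check shows (assuming the loop bound equals the device count d):

```python
d = 8  # assumption: for_loop->stop() equals the device count
for rank in range(d):
    for step in range(d):
        guarded = (d + (rank - step)) % d  # the form used in the PR
        simple = (rank - step) % d         # Python's % is already non-negative
        assert guarded == simple
```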
Correctness Concern

The insertion of StartCoalescing/EndCoalescing is tied to the kNccl backend, but similar coalescing behavior might be necessary for other backends to avoid ordering issues. The current logic assumes the kCuda backend does not require coalescing, which may not hold under all network topologies or execution orders.

```cpp
if (communicator_backend == CommunicatorBackend::kNccl) {
  // Using Start/EndCoalescing here is important to 1) avoid hangs
  // because of a wrong global order of send/recv and 2) enjoy full
  // bi-directional bandwidth.
```

    @samnordmann
    Collaborator Author

    !test

    1 similar comment
    @samnordmann
    Collaborator Author

    !test

    @samnordmann
    Collaborator Author

    !test

    @wujingyue
    Collaborator

> Transformer Engine: 660 TFLOPs

    Results look great! I couldn't find how you run TransformerEngine; otherwise I could check this by myself. I assume you've turned on their userbuffer overlapping as in https://github.com/NVIDIA/TransformerEngine/blob/66f9b3cbae214d521ac18883fe9a386b8893b179/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py#L50?

    @samnordmann
    Collaborator Author

> > Transformer Engine: 660 TFLOPs
>
> Results look great! I couldn't find how you run TransformerEngine; otherwise I could check this by myself. I assume you've turned on their userbuffer overlapping as in https://github.com/NVIDIA/TransformerEngine/blob/66f9b3cbae214d521ac18883fe9a386b8893b179/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py#L50?

I think it's correct, but I still need to confirm my results and double-check with the TE team that things are run the correct way. I used DDLB to run TE; see the reference here: https://github.com/samnordmann/ddlb/blob/main/ddlb/primitives/TPColumnwise/transformer_engine.py

    @samnordmann
    Collaborator Author

    !test

    @samnordmann samnordmann merged commit 6b60152 into main Oct 8, 2025
    63 of 68 checks passed
    @samnordmann samnordmann deleted the p2p_cuda_lowering branch October 8, 2025 14:12
    samnordmann added a commit that referenced this pull request Oct 8, 2025
    Improve printing of HostIrContainer by printing the index computations
    which are not explicitly part of the `topLevelExprs`.
    Example from #5259
    
    ```
    %HostIrContainer { (T0_g___bfloat[ideviceIdx.x0{8}, iS1{128}, iS2{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g___bfloat[iS3{1024}, iS4{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), T2_g___bfloat[iS5{1024}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T3_g___bfloat[istreamIdx6{8}, iS7{128}, iS8{1024}, rS9{1024}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
      T4_g___bfloat[istreamIdx10{8}, iS11{128}, iS12{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T4_g___bfloat[istreamIdx10{8}, iS11{128}, iS12{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=1048576, zero_init=false, resets_to_zero=false)
      T3_g___bfloat[istreamIdx6{8}, iS7{128}, iS8{1024}, rS9{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g___bfloat[istreamIdx6{8}, iS7{128}, iS8{1024}, rS9{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=1048576, zero_init=false, resets_to_zero=false)
      GetCurrentStream into Stream 0
      FOR streamIdx in istreamIdx10{8}:
        SetCurrentStream to Stream ( streamIdx % numberOfStreams )
        Synchronize Stream 0
      FOR streamIdx in istreamIdx10{8}:
        SetCurrentStream to Stream ( streamIdx % numberOfStreams )
        T6_l___bfloat[iS15{128}, iS16{1024}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = HirAliasSelect( T4_g___bfloat[istreamIdx10{8}, iS11{128}, iS12{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = istreamIdx10{8}, index = i84 )
        IF Manual ( ( ( 8 + ( rank - streamIdx ) ) % 8 ) == rank ):
          T5_l___bfloat[iS13{128}, iS14{1024}] (DeviceMesh{0 1 2 3 4 5 6 7})
             = HirAliasSelect( T0_g___bfloat[ideviceIdx.x0{8}, iS1{128}, iS2{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = ideviceIdx.x0{8}, index = 0 )
          T6_l___bfloat[iS15{128}, iS16{1024}] (DeviceMesh{0 1 2 3 4 5 6 7})
             = Set( T5_l___bfloat[iS13{128}, iS14{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), cache_op=Streaming )
        ELSE:
          ShareMemHandles(P2PCommunication 37 (type=recv, buffer=T6_l___bfloat[iS15{128}, iS16{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), peer=i84, backend=CUDA), P2PCommunication 38 (type=send, buffer=T0_g___bfloat[ideviceIdx.x0{8}, iS1{128}, iS2{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), peer=i90, backend=CUDA),
          P2PCommunication 38 (type=send, buffer=T0_g___bfloat[ideviceIdx.x0{8}, iS1{128}, iS2{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), peer=i90, backend=CUDA)
          P2PCommunication 37 (type=recv, buffer=T6_l___bfloat[iS15{128}, iS16{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), peer=i84, backend=CUDA)
          Wait Communication 38
          Wait Communication 37
        T7_l___bfloat[iS17{128}, iS18{1024}, rS19{1024}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = HirAliasSelect( T3_g___bfloat[istreamIdx6{8}, iS7{128}, iS8{1024}, rS9{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = istreamIdx6{8}, index = i84 )
        T7_l___bfloat[iS17{128}, iS18{1024}, rS19{1024}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = linear(T6_l___bfloat[iS15{128}, iS16{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}),
                    T1_g___bfloat[iS3{1024}, iS4{1024}] (DeviceMesh{0 1 2 3 4 5 6 7})      ,
              T2_g___bfloat[iS5{1024}] (DeviceMesh{0 1 2 3 4 5 6 7})      )
        SetCurrentStream to Stream 0
        Synchronize Stream ( streamIdx % numberOfStreams )
    } // %HostIrContainer
    
    Index definitions:
      i111 = streamIdx % numberOfStreams;
      i90 = i88 % 8;
      i32 = i30 * 1024;
      i30 = 8 * 128;
      i86 = rank - streamIdx;
      i82 = rank + streamIdx;
      i74 = 8 * 128;
      i76 = i74 * 1024;
      i84 = i82 % 8;
      i88 = 8 + i86;
    
    ```
tbqh pushed a commit that referenced this pull request Nov 12, 2025
tbqh pushed a commit that referenced this pull request Nov 12, 2025