[HostIR] Print index definitions in HostIrContainer::print#5327
[HostIR] Print index definitions in HostIrContainer::print#5327samnordmann merged 2 commits intomainfrom
HostIrContainer::print#5327Conversation
|
Review updated until commit 8f4be43 Description
Changes walkthrough 📝
PR Reviewer Guide 🔍Here are some key observations to aid the review process:
|
|
!test |
Can you remind me why they aren't part of topLevelExprs? Analogously, imagine you write a C++ loop
|
I am not sure to understand your comment correctly. Index computations are not part of topLevelExprs just because they don't need to. When we call The example you wrote looks good to me, but I am not sure to understand what you suggest by it. The example I provided in the PR description explains well the use case for the present patch. |
That's fair enough and LGTM. I recall that for MultiDeviceExecutor we also have to find and When Hanlin worked on host IR JIT, we realized that finding what indices to invalidate at "run" time creates problems for host latency. So, for the FusionExecutorCache integration, I let host IR lowering find these index calculations at "compile" time and put them in the scope of the for loop. This is done on a separate code path so doesn't affect MultiDeviceExecutor. I didn't get a chance to check with you -- hence my question earlier. |
Ok, now I understand. I like the idea of what was done for host IR JIT. Do you have a pointer to the PR? It would make sense to do the same in MultiDeviceExecutor. |
|
!test |
Improve printing of HostIrContainer by printing the index computations which are not explicitly part of the `topLevelExprs`. Example from #5259 ``` %HostIrContainer { (T0_g___bfloat[ideviceIdx.x0{8}, iS1{128}, iS2{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g___bfloat[iS3{1024}, iS4{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), T2_g___bfloat[iS5{1024}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T3_g___bfloat[istreamIdx6{8}, iS7{128}, iS8{1024}, rS9{1024}] (DeviceMesh{0 1 2 3 4 5 6 7})) : T4_g___bfloat[istreamIdx10{8}, iS11{128}, iS12{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T4_g___bfloat[istreamIdx10{8}, iS11{128}, iS12{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=1048576, zero_init=false, resets_to_zero=false) T3_g___bfloat[istreamIdx6{8}, iS7{128}, iS8{1024}, rS9{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g___bfloat[istreamIdx6{8}, iS7{128}, iS8{1024}, rS9{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=1048576, zero_init=false, resets_to_zero=false) GetCurrentStream into Stream 0 FOR streamIdx in istreamIdx10{8}: SetCurrentStream to Stream ( streamIdx % numberOfStreams ) Synchronize Stream 0 FOR streamIdx in istreamIdx10{8}: SetCurrentStream to Stream ( streamIdx % numberOfStreams ) T6_l___bfloat[iS15{128}, iS16{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T4_g___bfloat[istreamIdx10{8}, iS11{128}, iS12{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = istreamIdx10{8}, index = i84 ) IF Manual ( ( ( 8 + ( rank - streamIdx ) ) % 8 ) == rank ): T5_l___bfloat[iS13{128}, iS14{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T0_g___bfloat[ideviceIdx.x0{8}, iS1{128}, iS2{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = ideviceIdx.x0{8}, index = 0 ) T6_l___bfloat[iS15{128}, iS16{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}) = Set( T5_l___bfloat[iS13{128}, iS14{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), cache_op=Streaming ) ELSE: ShareMemHandles(P2PCommunication 37 (type=recv, buffer=T6_l___bfloat[iS15{128}, iS16{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), peer=i84, backend=CUDA), P2PCommunication 38 (type=send, buffer=T0_g___bfloat[ideviceIdx.x0{8}, iS1{128}, iS2{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), peer=i90, backend=CUDA), P2PCommunication 38 (type=send, buffer=T0_g___bfloat[ideviceIdx.x0{8}, iS1{128}, iS2{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), peer=i90, backend=CUDA) P2PCommunication 37 (type=recv, buffer=T6_l___bfloat[iS15{128}, iS16{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), peer=i84, backend=CUDA) Wait Communication 38 Wait Communication 37 T7_l___bfloat[iS17{128}, iS18{1024}, rS19{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T3_g___bfloat[istreamIdx6{8}, iS7{128}, iS8{1024}, rS9{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = istreamIdx6{8}, index = i84 ) T7_l___bfloat[iS17{128}, iS18{1024}, rS19{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}) = linear(T6_l___bfloat[iS15{128}, iS16{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g___bfloat[iS3{1024}, iS4{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}) , T2_g___bfloat[iS5{1024}] (DeviceMesh{0 1 2 3 4 5 6 7}) ) SetCurrentStream to Stream 0 Synchronize Stream ( streamIdx % numberOfStreams ) } // %HostIrContainer Index definitions: i111 = streamIdx % numberOfStreams; i90 = i88 % 8; i32 = i30 * 1024; i30 = 8 * 128; i86 = rank - streamIdx; i82 = rank + streamIdx; i74 = 8 * 128; i76 = i74 * 1024; i84 = i82 % 8; i88 = 8 + i86; ```
Improve printing of HostIrContainer by printing the index computations which are not explicitly part of the
topLevelExprs.Example from #5259