diff --git a/comms/torchcomms/rcclx/HipApi.cpp b/comms/torchcomms/rcclx/HipApi.cpp index 4b41799a8..dcdab59b4 100644 --- a/comms/torchcomms/rcclx/HipApi.cpp +++ b/comms/torchcomms/rcclx/HipApi.cpp @@ -49,8 +49,7 @@ hipError_t DefaultHipApi::streamWaitEvent( return hipStreamWaitEvent(stream, event, flags); } -hipStream_t DefaultHipApi::getCurrentHIPStreamMasqueradingAsCUDA( - int device_index) { +hipStream_t DefaultHipApi::getCurrentCUDAStream(int device_index) { #ifdef HIPIFY_V2 return at::cuda::getCurrentCUDAStream(device_index).stream(); #else diff --git a/comms/torchcomms/rcclx/HipApi.hpp b/comms/torchcomms/rcclx/HipApi.hpp index afda1cfe9..8036c972d 100644 --- a/comms/torchcomms/rcclx/HipApi.hpp +++ b/comms/torchcomms/rcclx/HipApi.hpp @@ -57,8 +57,9 @@ class HipApi { [[nodiscard]] virtual hipError_t streamDestroy(hipStream_t stream) = 0; [[nodiscard]] virtual hipError_t streamWaitEvent(hipStream_t stream, hipEvent_t event, unsigned int flags) = 0; - virtual hipStream_t getCurrentHIPStreamMasqueradingAsCUDA( - int device_index) = 0; + // Note: Named getCurrentCUDAStream because hipify transforms + // getCurrentHIPStreamMasqueradingAsCUDA to getCurrentCUDAStream + virtual hipStream_t getCurrentCUDAStream(int device_index) = 0; [[nodiscard]] virtual hipError_t streamSynchronize(hipStream_t stream) = 0; [[nodiscard]] virtual hipError_t threadExchangeStreamCaptureMode( enum hipStreamCaptureMode* mode) = 0; @@ -113,7 +114,7 @@ class DefaultHipApi : public HipApi { hipStream_t stream, hipEvent_t event, unsigned int flags) override; - hipStream_t getCurrentHIPStreamMasqueradingAsCUDA(int device_index) override; + hipStream_t getCurrentCUDAStream(int device_index) override; [[nodiscard]] hipError_t streamSynchronize(hipStream_t stream) override; [[nodiscard]] hipError_t threadExchangeStreamCaptureMode( enum hipStreamCaptureMode* mode) override; diff --git a/comms/torchcomms/rcclx/TorchCommRCCLXBootstrap.cpp b/comms/torchcomms/rcclx/TorchCommRCCLXBootstrap.cpp index 66ef4c5c2..8477c0777 100644 --- a/comms/torchcomms/rcclx/TorchCommRCCLXBootstrap.cpp +++ b/comms/torchcomms/rcclx/TorchCommRCCLXBootstrap.cpp @@ -170,8 +170,7 @@ void TorchCommRCCLXBootstrap::cleanupTCPStore(ncclComm_t nccl_comm) { // object. store_.reset(); - auto stream = - hip_api_->getCurrentHIPStreamMasqueradingAsCUDA(device_.index()); + auto stream = hip_api_->getCurrentCUDAStream(device_.index()); ncclResult_t result = rcclx_api_->allReduce( barrier_buffer_, barrier_buffer_, diff --git a/comms/torchcomms/rcclx/TorchCommRCCLXUtils.cpp b/comms/torchcomms/rcclx/TorchCommRCCLXUtils.cpp index 45aa52624..69c520595 100644 --- a/comms/torchcomms/rcclx/TorchCommRCCLXUtils.cpp +++ b/comms/torchcomms/rcclx/TorchCommRCCLXUtils.cpp @@ -331,7 +331,7 @@ hipStream_t TorchCommRCCLX::getOperationStream(bool async_op) { if (async_op) { // Get current PyTorch CUDA stream for this device hipStream_t current_stream = - hip_api_->getCurrentHIPStreamMasqueradingAsCUDA(device_.index()); + hip_api_->getCurrentCUDAStream(device_.index()); // Record event on current stream and wait for it on internal stream HIP_CHECK( @@ -347,7 +347,7 @@ hipStream_t TorchCommRCCLX::getOperationStream(bool async_op) { return internal_stream_; } else { // Use the current PyTorch CUDA stream for synchronous operations - return hip_api_->getCurrentHIPStreamMasqueradingAsCUDA(device_.index()); + return hip_api_->getCurrentCUDAStream(device_.index()); } } diff --git a/comms/torchcomms/rcclx/TorchWorkRCCLX.cpp b/comms/torchcomms/rcclx/TorchWorkRCCLX.cpp index 1a481489c..a6f41d6aa 100644 --- a/comms/torchcomms/rcclx/TorchWorkRCCLX.cpp +++ b/comms/torchcomms/rcclx/TorchWorkRCCLX.cpp @@ -156,8 +156,7 @@ void TorchWorkRCCLX::wait() { // Get the current stream using the device from the comm object hipStream_t current_stream = - comm_->getHipApi()->getCurrentHIPStreamMasqueradingAsCUDA( - comm_->device_.index()); + comm_->getHipApi()->getCurrentCUDAStream(comm_->device_.index()); // Add a dependency from the work's stream to the current stream // This makes the current stream wait for the end_event_ recorded on the diff --git a/comms/torchcomms/rcclx/tests/unit/cpp/TorchCommRCCLXWorkQueueTest.cpp b/comms/torchcomms/rcclx/tests/unit/cpp/TorchCommRCCLXWorkQueueTest.cpp index ec933d23b..aa13f88b7 100644 --- a/comms/torchcomms/rcclx/tests/unit/cpp/TorchCommRCCLXWorkQueueTest.cpp +++ b/comms/torchcomms/rcclx/tests/unit/cpp/TorchCommRCCLXWorkQueueTest.cpp @@ -114,7 +114,7 @@ class TorchCommRCCLXWorkQueueTest : public ::testing::Test { ON_CALL(*hip_mock_, streamSynchronize(_)).WillByDefault(Return(hipSuccess)); ON_CALL(*hip_mock_, streamWaitEvent(_, _, _)) .WillByDefault(Return(hipSuccess)); - ON_CALL(*hip_mock_, getCurrentHIPStreamMasqueradingAsCUDA(_)) + ON_CALL(*hip_mock_, getCurrentCUDAStream(_)) .WillByDefault(Return(current_stream_)); ON_CALL(*hip_mock_, getStreamPriorityRange(_, _)) .WillByDefault(DoAll( diff --git a/comms/torchcomms/rcclx/tests/unit/cpp/mocks/HipMock.hpp b/comms/torchcomms/rcclx/tests/unit/cpp/mocks/HipMock.hpp index e0b80baf6..815e55c16 100644 --- a/comms/torchcomms/rcclx/tests/unit/cpp/mocks/HipMock.hpp +++ b/comms/torchcomms/rcclx/tests/unit/cpp/mocks/HipMock.hpp @@ -34,9 +34,11 @@ class HipMock : public HipApi { streamWaitEvent, (hipStream_t stream, hipEvent_t event, unsigned int flags), (override)); + // Note: Uses getCurrentCUDAStream because tests go through hipify which + // transforms getCurrentHIPStreamMasqueradingAsCUDA to getCurrentCUDAStream MOCK_METHOD( hipStream_t, - getCurrentHIPStreamMasqueradingAsCUDA, + getCurrentCUDAStream, (int device_index), (override)); MOCK_METHOD(hipError_t, streamSynchronize, (hipStream_t stream), (override));