Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions comms/torchcomms/rcclx/HipApi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ hipError_t DefaultHipApi::streamWaitEvent(
return hipStreamWaitEvent(stream, event, flags);
}

hipStream_t DefaultHipApi::getCurrentHIPStreamMasqueradingAsCUDA(
int device_index) {
hipStream_t DefaultHipApi::getCurrentCUDAStream(int device_index) {
#ifdef HIPIFY_V2
return at::cuda::getCurrentCUDAStream(device_index).stream();
#else
Expand Down
7 changes: 4 additions & 3 deletions comms/torchcomms/rcclx/HipApi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,9 @@ class HipApi {
[[nodiscard]] virtual hipError_t streamDestroy(hipStream_t stream) = 0;
[[nodiscard]] virtual hipError_t
streamWaitEvent(hipStream_t stream, hipEvent_t event, unsigned int flags) = 0;
virtual hipStream_t getCurrentHIPStreamMasqueradingAsCUDA(
int device_index) = 0;
// Note: Named getCurrentCUDAStream because hipify transforms
// getCurrentHIPStreamMasqueradingAsCUDA to getCurrentCUDAStream
virtual hipStream_t getCurrentCUDAStream(int device_index) = 0;
[[nodiscard]] virtual hipError_t streamSynchronize(hipStream_t stream) = 0;
[[nodiscard]] virtual hipError_t threadExchangeStreamCaptureMode(
enum hipStreamCaptureMode* mode) = 0;
Expand Down Expand Up @@ -113,7 +114,7 @@ class DefaultHipApi : public HipApi {
hipStream_t stream,
hipEvent_t event,
unsigned int flags) override;
hipStream_t getCurrentHIPStreamMasqueradingAsCUDA(int device_index) override;
hipStream_t getCurrentCUDAStream(int device_index) override;
[[nodiscard]] hipError_t streamSynchronize(hipStream_t stream) override;
[[nodiscard]] hipError_t threadExchangeStreamCaptureMode(
enum hipStreamCaptureMode* mode) override;
Expand Down
3 changes: 1 addition & 2 deletions comms/torchcomms/rcclx/TorchCommRCCLXBootstrap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,7 @@ void TorchCommRCCLXBootstrap::cleanupTCPStore(ncclComm_t nccl_comm) {
// object.
store_.reset();

auto stream =
hip_api_->getCurrentHIPStreamMasqueradingAsCUDA(device_.index());
auto stream = hip_api_->getCurrentCUDAStream(device_.index());
ncclResult_t result = rcclx_api_->allReduce(
barrier_buffer_,
barrier_buffer_,
Expand Down
4 changes: 2 additions & 2 deletions comms/torchcomms/rcclx/TorchCommRCCLXUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ hipStream_t TorchCommRCCLX::getOperationStream(bool async_op) {
if (async_op) {
// Get current PyTorch CUDA stream for this device
hipStream_t current_stream =
hip_api_->getCurrentHIPStreamMasqueradingAsCUDA(device_.index());
hip_api_->getCurrentCUDAStream(device_.index());

// Record event on current stream and wait for it on internal stream
HIP_CHECK(
Expand All @@ -347,7 +347,7 @@ hipStream_t TorchCommRCCLX::getOperationStream(bool async_op) {
return internal_stream_;
} else {
// Use the current PyTorch CUDA stream for synchronous operations
return hip_api_->getCurrentHIPStreamMasqueradingAsCUDA(device_.index());
return hip_api_->getCurrentCUDAStream(device_.index());
}
}

Expand Down
3 changes: 1 addition & 2 deletions comms/torchcomms/rcclx/TorchWorkRCCLX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,7 @@ void TorchWorkRCCLX::wait() {

// Get the current stream using the device from the comm object
hipStream_t current_stream =
comm_->getHipApi()->getCurrentHIPStreamMasqueradingAsCUDA(
comm_->device_.index());
comm_->getHipApi()->getCurrentCUDAStream(comm_->device_.index());

// Add a dependency from the work's stream to the current stream
// This makes the current stream wait for the end_event_ recorded on the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class TorchCommRCCLXWorkQueueTest : public ::testing::Test {
ON_CALL(*hip_mock_, streamSynchronize(_)).WillByDefault(Return(hipSuccess));
ON_CALL(*hip_mock_, streamWaitEvent(_, _, _))
.WillByDefault(Return(hipSuccess));
ON_CALL(*hip_mock_, getCurrentHIPStreamMasqueradingAsCUDA(_))
ON_CALL(*hip_mock_, getCurrentCUDAStream(_))
.WillByDefault(Return(current_stream_));
ON_CALL(*hip_mock_, getStreamPriorityRange(_, _))
.WillByDefault(DoAll(
Expand Down
4 changes: 3 additions & 1 deletion comms/torchcomms/rcclx/tests/unit/cpp/mocks/HipMock.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,11 @@ class HipMock : public HipApi {
streamWaitEvent,
(hipStream_t stream, hipEvent_t event, unsigned int flags),
(override));
// Note: Uses getCurrentCUDAStream because tests go through hipify which
// transforms getCurrentHIPStreamMasqueradingAsCUDA to getCurrentCUDAStream
MOCK_METHOD(
hipStream_t,
getCurrentHIPStreamMasqueradingAsCUDA,
getCurrentCUDAStream,
(int device_index),
(override));
MOCK_METHOD(hipError_t, streamSynchronize, (hipStream_t stream), (override));
Expand Down
Loading