diff --git a/csrc/include/custom_all_reduce.cuh b/csrc/include/custom_all_reduce.cuh index e46a215af2..391c36a0ca 100644 --- a/csrc/include/custom_all_reduce.cuh +++ b/csrc/include/custom_all_reduce.cuh @@ -2039,22 +2039,19 @@ class CustomAllreduce RankData* get_buffer_RD(hipStream_t stream, void* input) { RankData* ptrs; - // During graph capture, always record the buffer unconditionally. - // Skip the input_buffer cache to ensure all ranks record the same - // number of buffers, even if their allocators reuse different addresses. - hipStreamCaptureStatus status; - HIP_CALL(hipStreamIsCapturing(stream, &status)); - if(status == hipStreamCaptureStatusActive) + auto it = input_buffer.find(input); + if(it != input_buffer.end()) { - ptrs = d_rank_data_base_ + graph_unreg_input_buffers_.size(); - graph_unreg_input_buffers_.push_back(input); + ptrs = it->second; } else { - auto it = input_buffer.find(input); - if(it != input_buffer.end()) + hipStreamCaptureStatus status; + HIP_CALL(hipStreamIsCapturing(stream, &status)); + if(status == hipStreamCaptureStatusActive) { - ptrs = it->second; + ptrs = d_rank_data_base_ + graph_unreg_input_buffers_.size(); + graph_unreg_input_buffers_.push_back(input); } else { @@ -2070,23 +2067,21 @@ class CustomAllreduce RankData* get_output_buffer_RD(hipStream_t stream, void* output) { RankData* ptrs; - // During graph capture, always record the buffer unconditionally. - // Skip the output_buffers_ cache to ensure all ranks record the same - // number of buffers, even if their allocators reuse different addresses. - hipStreamCaptureStatus status; - HIP_CALL(hipStreamIsCapturing(stream, &status)); - if(status == hipStreamCaptureStatusActive) + auto it = output_buffers_.find(output); + if(it != output_buffers_.end()) { - ptrs = d_rank_data_base_ + graph_unreg_input_buffers_.size() + - graph_unreg_output_buffers_.size(); - graph_unreg_output_buffers_.push_back(output); + ptrs = it->second; } else { - auto it = output_buffers_.find(output); - if(it != output_buffers_.end()) + hipStreamCaptureStatus status; + HIP_CALL(hipStreamIsCapturing(stream, &status)); + if(status == hipStreamCaptureStatusActive) { - ptrs = it->second; + // For graph mode, collect output addresses + ptrs = d_rank_data_base_ + graph_unreg_input_buffers_.size() + + graph_unreg_output_buffers_.size(); + graph_unreg_output_buffers_.push_back(output); } else {