From b0fe6eaa47f793b0bf3879c8c2dee9e5fe218104 Mon Sep 17 00:00:00 2001 From: TennyWang1223 Date: Tue, 14 Apr 2026 17:18:39 +0800 Subject: [PATCH] Revert "fix(car): craph capture err (#2638)" This reverts commit 5759ee2943b5326dec23910135381b426ff65196. --- csrc/include/custom_all_reduce.cuh | 41 +++++++++++++----------------- 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/csrc/include/custom_all_reduce.cuh b/csrc/include/custom_all_reduce.cuh index e46a215af2..391c36a0ca 100644 --- a/csrc/include/custom_all_reduce.cuh +++ b/csrc/include/custom_all_reduce.cuh @@ -2039,22 +2039,19 @@ class CustomAllreduce RankData* get_buffer_RD(hipStream_t stream, void* input) { RankData* ptrs; - // During graph capture, always record the buffer unconditionally. - // Skip the input_buffer cache to ensure all ranks record the same - // number of buffers, even if their allocators reuse different addresses. - hipStreamCaptureStatus status; - HIP_CALL(hipStreamIsCapturing(stream, &status)); - if(status == hipStreamCaptureStatusActive) + auto it = input_buffer.find(input); + if(it != input_buffer.end()) { - ptrs = d_rank_data_base_ + graph_unreg_input_buffers_.size(); - graph_unreg_input_buffers_.push_back(input); + ptrs = it->second; } else { - auto it = input_buffer.find(input); - if(it != input_buffer.end()) + hipStreamCaptureStatus status; + HIP_CALL(hipStreamIsCapturing(stream, &status)); + if(status == hipStreamCaptureStatusActive) { - ptrs = it->second; + ptrs = d_rank_data_base_ + graph_unreg_input_buffers_.size(); + graph_unreg_input_buffers_.push_back(input); } else { @@ -2070,23 +2067,21 @@ class CustomAllreduce RankData* get_output_buffer_RD(hipStream_t stream, void* output) { RankData* ptrs; - // During graph capture, always record the buffer unconditionally. - // Skip the output_buffers_ cache to ensure all ranks record the same - // number of buffers, even if their allocators reuse different addresses. - hipStreamCaptureStatus status; - HIP_CALL(hipStreamIsCapturing(stream, &status)); - if(status == hipStreamCaptureStatusActive) + auto it = output_buffers_.find(output); + if(it != output_buffers_.end()) { - ptrs = d_rank_data_base_ + graph_unreg_input_buffers_.size() + - graph_unreg_output_buffers_.size(); - graph_unreg_output_buffers_.push_back(output); + ptrs = it->second; } else { - auto it = output_buffers_.find(output); - if(it != output_buffers_.end()) + hipStreamCaptureStatus status; + HIP_CALL(hipStreamIsCapturing(stream, &status)); + if(status == hipStreamCaptureStatusActive) { - ptrs = it->second; + // For graph mode, collect output addresses + ptrs = d_rank_data_base_ + graph_unreg_input_buffers_.size() + + graph_unreg_output_buffers_.size(); + graph_unreg_output_buffers_.push_back(output); } else {