Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 18 additions & 23 deletions csrc/include/custom_all_reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -2039,22 +2039,19 @@ class CustomAllreduce
RankData* get_buffer_RD(hipStream_t stream, void* input)
{
RankData* ptrs;
// During graph capture, always record the buffer unconditionally.
// Skip the input_buffer cache to ensure all ranks record the same
// number of buffers, even if their allocators reuse different addresses.
hipStreamCaptureStatus status;
HIP_CALL(hipStreamIsCapturing(stream, &status));
if(status == hipStreamCaptureStatusActive)
auto it = input_buffer.find(input);
if(it != input_buffer.end())
{
ptrs = d_rank_data_base_ + graph_unreg_input_buffers_.size();
graph_unreg_input_buffers_.push_back(input);
ptrs = it->second;
}
else
{
auto it = input_buffer.find(input);
if(it != input_buffer.end())
hipStreamCaptureStatus status;
HIP_CALL(hipStreamIsCapturing(stream, &status));
if(status == hipStreamCaptureStatusActive)
{
ptrs = it->second;
ptrs = d_rank_data_base_ + graph_unreg_input_buffers_.size();
graph_unreg_input_buffers_.push_back(input);
}
else
{
Expand All @@ -2070,23 +2067,21 @@ class CustomAllreduce
RankData* get_output_buffer_RD(hipStream_t stream, void* output)
{
RankData* ptrs;
// During graph capture, always record the buffer unconditionally.
// Skip the output_buffers_ cache to ensure all ranks record the same
// number of buffers, even if their allocators reuse different addresses.
hipStreamCaptureStatus status;
HIP_CALL(hipStreamIsCapturing(stream, &status));
if(status == hipStreamCaptureStatusActive)
auto it = output_buffers_.find(output);
if(it != output_buffers_.end())
{
ptrs = d_rank_data_base_ + graph_unreg_input_buffers_.size() +
graph_unreg_output_buffers_.size();
graph_unreg_output_buffers_.push_back(output);
ptrs = it->second;
}
else
{
auto it = output_buffers_.find(output);
if(it != output_buffers_.end())
hipStreamCaptureStatus status;
HIP_CALL(hipStreamIsCapturing(stream, &status));
if(status == hipStreamCaptureStatusActive)
{
ptrs = it->second;
// For graph mode, collect output addresses
ptrs = d_rank_data_base_ + graph_unreg_input_buffers_.size() +
graph_unreg_output_buffers_.size();
graph_unreg_output_buffers_.push_back(output);
}
else
{
Expand Down
Loading