diff --git a/paddle/common/flags.cc b/paddle/common/flags.cc index 70a9a9858d6bd0..43350f1c4995f9 100644 --- a/paddle/common/flags.cc +++ b/paddle/common/flags.cc @@ -1130,28 +1130,6 @@ PHI_DEFINE_EXPORTED_bool(use_cuda_malloc_async_allocator, false, "Enable CUDAMallocAsyncAllocator"); -/* - * CUDAMallocAsyncAllocator related FLAG - * Name: FLAGS_cuda_malloc_async_pool_memory_throttle_ratio - * Since Version: 3.0 - * Value Range: double, [0.0, 1.0], default=0.8 - * Note:memory_throttle_ratio provides a threshold that determines when to - * initiate synchronization operations to deallocate memory. This mechanism - * helps in ensuring that the system does not exceed its memory capacity while - * also attempting to minimize performance degradation caused by frequent memory - * synchronization. - * - * Please see Note [cuda_malloc_async_pool_memory_throttle_ratio] - */ -PHI_DEFINE_EXPORTED_double( - cuda_malloc_async_pool_memory_throttle_ratio, - 0.8, - "memory_throttle_ratio provides a threshold that determines when to " - "initiate synchronization operations to deallocate memory. " - "This mechanism helps in ensuring that the system does not exceed its " - "memory capacity while also attempting to minimize performance degradation " - "caused by frequent memory synchronization."); - /* * CUDA Graph / Allocator related FLAG * Name: FLAGS_auto_free_cudagraph_allocations_on_launch diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index e663f71174dac9..001c5b87b19e70 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -200,10 +200,8 @@ class AllocatorFacadePrivate { : #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) default_stream_safe_cuda_allocators_(), - cuda_allocators_(), -#endif -#ifdef PADDLE_WITH_CUDA default_cuda_malloc_async_allocators_(), + cuda_allocators_(), #endif allocators_() { strategy_ = GetAllocatorStrategy(); @@ -435,7 +433,7 @@ class AllocatorFacadePrivate { /* unique_lock_guard */ { std::unique_lock lock_guard( cuda_allocator_mutex_); - InitCUDAAllocator(place, stream); + InitStreamSafeCUDAAllocator(place, stream); return cuda_allocators_[place][stream]; } } @@ -445,11 +443,9 @@ class AllocatorFacadePrivate { if (auto iter = default_stream_safe_cuda_allocators_.find(place); iter != default_stream_safe_cuda_allocators_.end()) return iter->second; -#ifdef PADDLE_WITH_CUDA if (auto iter = default_cuda_malloc_async_allocators_.find(place); iter != default_cuda_malloc_async_allocators_.end()) return iter->second; -#endif PADDLE_THROW(platform::errors::NotFound( "No StreamSafeCUDAAllocator found for the place, %s", place)); } @@ -458,12 +454,10 @@ class AllocatorFacadePrivate { if (auto allocator = std::dynamic_pointer_cast( GetDefaultStreamSafeCUDAAllocator(place))) { return allocator->GetDefaultStream(); -#ifdef PADDLE_WITH_CUDA } else if (auto allocator = std::dynamic_pointer_cast( GetDefaultStreamSafeCUDAAllocator(place))) { return allocator->GetDefaultStream(); -#endif } else { PADDLE_THROW(platform::errors::NotFound( "No StreamSafeCUDAAllocator or CUDAMallocAsyncAllocator found for " @@ -490,7 +484,6 @@ class AllocatorFacadePrivate { VLOG(8) << "Set default stream to " << stream << " for StreamSafeCUDAAllocator(" << allocator.get() << ") in " << place; -#ifdef PADDLE_WITH_CUDA } else if (auto allocator = std::dynamic_pointer_cast( GetDefaultStreamSafeCUDAAllocator(place))) { @@ -508,7 +501,6 @@ class AllocatorFacadePrivate { 
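Illustrative sketch (not part of the patch): with the CUDA-only guards dropped, the facade resolves the default allocator for a place by trying a `std::dynamic_pointer_cast` against each concrete allocator type in turn, as the `GetDefaultStream`/`SetDefaultStream` hunks above show. A minimal stand-alone version of that dispatch pattern, using stand-in class names rather than Paddle's:

```cpp
#include <iostream>
#include <memory>
#include <stdexcept>

struct Allocator {  // stand-in for the common base class
  virtual ~Allocator() = default;
};

struct StreamSafeAlloc : Allocator {
  int DefaultStream() const { return 1; }
};

struct MallocAsyncAlloc : Allocator {
  int DefaultStream() const { return 2; }
};

// Try each concrete type in turn; throw if none matches.
int GetDefaultStream(const std::shared_ptr<Allocator>& base) {
  if (auto a = std::dynamic_pointer_cast<StreamSafeAlloc>(base)) {
    return a->DefaultStream();
  } else if (auto a = std::dynamic_pointer_cast<MallocAsyncAlloc>(base)) {
    return a->DefaultStream();
  }
  throw std::runtime_error("unknown allocator type");
}

int main() {
  std::shared_ptr<Allocator> a = std::make_shared<MallocAsyncAlloc>();
  std::cout << GetDefaultStream(a) << "\n";  // prints 2
}
```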
VLOG(8) << "Set default stream to " << stream << " for CUDAMallocAsyncAllocator(" << allocator.get() << ") in " << place; -#endif } else { PADDLE_THROW(platform::errors::NotFound( "No StreamSafeCUDAAllocator or CUDAMallocAsyncAllocator found for " @@ -519,15 +511,13 @@ class AllocatorFacadePrivate { void RecordStream(std::shared_ptr allocation, gpuStream_t stream) { - if (auto stream_safe_cuda_allocation = - std::dynamic_pointer_cast(allocation)) { - stream_safe_cuda_allocation->RecordStream(stream); -#ifdef PADDLE_WITH_CUDA - } else if (auto cuda_malloc_async_allocation = - std::dynamic_pointer_cast( - allocation)) { + if (auto cuda_malloc_async_allocation = + std::dynamic_pointer_cast(allocation)) { cuda_malloc_async_allocation->RecordStream(stream); -#endif + } else if (auto stream_safe_cuda_allocation = + std::dynamic_pointer_cast( + allocation)) { + stream_safe_cuda_allocation->RecordStream(stream); } else { VLOG(6) << "RecordStream for a non-StreamSafeCUDAAllocation"; } @@ -535,15 +525,13 @@ class AllocatorFacadePrivate { void EraseStream(std::shared_ptr allocation, gpuStream_t stream) { - if (auto stream_safe_cuda_allocation = - std::dynamic_pointer_cast(allocation)) { - stream_safe_cuda_allocation->EraseStream(stream); -#ifdef PADDLE_WITH_CUDA - } else if (auto cuda_malloc_async_allocation = - std::dynamic_pointer_cast( - allocation)) { + if (auto cuda_malloc_async_allocation = + std::dynamic_pointer_cast(allocation)) { cuda_malloc_async_allocation->EraseStream(stream); -#endif + } else if (auto stream_safe_cuda_allocation = + std::dynamic_pointer_cast( + allocation)) { + stream_safe_cuda_allocation->EraseStream(stream); } else { VLOG(6) << "EraseStream for a non-StreamSafeCUDAAllocation"; } @@ -551,18 +539,17 @@ class AllocatorFacadePrivate { gpuStream_t GetStream( const std::shared_ptr& allocation) const { - if (const std::shared_ptr - stream_safe_cuda_allocation = - std::dynamic_pointer_cast( + if (const std::shared_ptr + cuda_malloc_async_allocation = + std::dynamic_pointer_cast( allocation)) { - return stream_safe_cuda_allocation->GetOwningStream(); -#ifdef PADDLE_WITH_CUDA - } else if (const std::shared_ptr - cuda_malloc_async_allocation = - std::dynamic_pointer_cast( - allocation)) { return cuda_malloc_async_allocation->GetOwningStream(); -#endif + + } else if (const std::shared_ptr + stream_safe_cuda_allocation = + std::dynamic_pointer_cast( + allocation)) { + return stream_safe_cuda_allocation->GetOwningStream(); } VLOG(6) << "GetStream for a non-StreamSafeCUDAAllocation"; @@ -878,7 +865,7 @@ class AllocatorFacadePrivate { return std::make_shared(p); } - void InitCUDAAllocator(phi::GPUPlace p, gpuStream_t stream) { + void InitStreamSafeCUDAAllocator(phi::GPUPlace p, gpuStream_t stream) { PADDLE_ENFORCE_EQ( strategy_, AllocatorStrategy::kAutoGrowth, @@ -910,14 +897,9 @@ class AllocatorFacadePrivate { } void InitCUDAMallocAsyncAllocator(phi::GPUPlace p, gpuStream_t stream) { -#ifdef PADDLE_WITH_CUDA std::shared_ptr& allocator = cuda_allocators_[p][stream]; cuda_allocators_[p][stream] = std::make_shared(allocator, p, stream); -#else - PADDLE_THROW(platform::errors::Unavailable( - "CUDAMallocAsyncAllocator is not enabled")); -#endif } void InitAutoGrowthCUDAAllocator(phi::GPUPlace p, gpuStream_t stream) { @@ -1187,7 +1169,6 @@ class AllocatorFacadePrivate { } void WrapCUDAMallocAsyncAllocatorForDefault() { -#ifdef PADDLE_WITH_CUDA for (auto& pair : allocators_) { auto& place = pair.first; if (phi::is_gpu_place(place)) { @@ -1207,10 +1188,6 @@ class AllocatorFacadePrivate { << 
", allocator address = " << pair.second.get(); } } -#else - PADDLE_THROW(platform::errors::Unavailable( - "CUDAMallocAsyncAllocator is not enabled")); -#endif } void WrapCUDARetryAllocator(phi::GPUPlace p, @@ -1570,15 +1547,12 @@ class AllocatorFacadePrivate { // a standalone CUDA allocator to support multi-stream GC in new executor std::map> default_stream_safe_cuda_allocators_; + std::map> + default_cuda_malloc_async_allocators_; CUDAAllocatorMap cuda_allocators_; std::shared_timed_mutex cuda_allocator_mutex_; #endif -#if defined(PADDLE_WITH_CUDA) - std::map> - default_cuda_malloc_async_allocators_; -#endif - #ifdef PADDLE_WITH_XPU // a standalone XPU allocator to support multi-stream GC in new executor std::map> @@ -1835,9 +1809,6 @@ void AllocatorFacade::PrepareMemoryPoolForCUDAGraph(int64_t id) { FLAGS_allocator_strategy)); auto& allocator = cuda_graph_map_[id]; auto& ref_cnt = cuda_graph_ref_cnt_[id]; - ++ref_cnt; - - if (FLAGS_use_cuda_malloc_async_allocator) return; if (allocator.get() == nullptr) { allocator = std::make_unique( /*allow_free_idle_chunk=*/false); @@ -1845,6 +1816,7 @@ void AllocatorFacade::PrepareMemoryPoolForCUDAGraph(int64_t id) { } else { VLOG(10) << "Use created memory pool for CUDA Graph with memory ID " << id; } + ++ref_cnt; } void AllocatorFacade::RemoveMemoryPoolOfCUDAGraph(int64_t id) { diff --git a/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.cc b/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.cc index 9517361bf84cee..2b74a261ce2bdc 100644 --- a/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.cc +++ b/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.cc @@ -13,18 +13,13 @@ // limitations under the License. #include "paddle/fluid/memory/allocation/cuda_malloc_async_allocator.h" -#include #include -#include -#include "paddle/common/flags.h" -#include "paddle/common/macros.h" #include "paddle/fluid/memory/allocation/allocator.h" #include "paddle/fluid/memory/allocation/stream_safe_cuda_allocator.h" #ifdef PADDLE_WITH_CUDA #include #include -#include "paddle/phi/backends/gpu/cuda/cuda_graph.h" #endif #include @@ -32,109 +27,90 @@ #include "paddle/fluid/platform/cuda_device_guard.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/enforce.h" - -#include "paddle/utils/optional.h" - -#ifdef PADDLE_WITH_CUDA - -/* - * Note: [cuda_malloc_async_pool_memory_throttle_ratio] - * The primary purpose of the memory_throttle_ratio is to provide a - * threshold that determines when to initiate synchronization operations to - * deallocate memory. This mechanism helps in ensuring that the system does - * not exceed its memory capacity while also attempting to minimize performance - * degradation caused by frequent memory synchronization. - * - * ``` - * utilization = (allocated_size + pending_release_size) / total_memory_size - * if(utilization > memory_throttle_ratio) - * sync(free_stream, malloc_stream) - * ``` - * - * When the utilization exceeds the memory_throttle_ratio, we - * initiate a stream synchronization operation before malloc. - * - * During synchronization, all memory deallocation requests in the free queue - * are processed, effectively lowering the memory utilization before - * any new memory allocation operations are going to proceed. - * - * [Impact on Performance and Memory Usage] - * - * - Lower memory_throttle_ratio Values - * the synchronization operation will be triggered more frequently. 
- * This can lead to better memory utilization but might result in decreased - * performance due to the increased number of synchronization operations. - * - * - Higher memory_throttle_ratio Values - * Conversely, setting a higher value allows for more memory to be allocated - * before triggering synchronization, which can enhance performance by reducing - * the number of sync operations. However, this increases the risk of reaching - * an OOM condition since more memory can be allocated without - * immediate deallocation. - */ -COMMON_DECLARE_double(cuda_malloc_async_pool_memory_throttle_ratio); +#if defined(PADDLE_WITH_CUDA) +#include "paddle/phi/backends/gpu/cuda/cuda_graph.h" +#elif defined(PADDLE_WITH_HIP) +#include "paddle/phi/backends/gpu/rocm/hip_graph.h" +#endif namespace paddle::memory::allocation { thread_local std::once_flag CUDAMallocAsyncAllocation::once_flag_; -inline void sync_streams(gpuStream_t to_record, gpuStream_t to_wait) { - cudaEvent_t event = nullptr; - PADDLE_ENFORCE_GPU_SUCCESS( - cudaEventCreateWithFlags(&event, cudaEventDisableTiming)); - PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(event, to_record)); - PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamWaitEvent(to_wait, event)); - PADDLE_ENFORCE_GPU_SUCCESS(cudaEventDestroy(event)); +void CUDAMallocAsyncAllocation::RecordGraphCapturingStreams() { + for (gpuStream_t stream : graph_capturing_stream_set_) { + RecordStreamWithNoGraphCapturing(stream); + } + graph_capturing_stream_set_.clear(); } -// CUDAMallocAsyncAllocation +void CUDAMallocAsyncAllocation::RecordStreamWithNoGraphCapturing( + gpuStream_t stream) { + if (event_map_.find(stream) == event_map_.end()) { + gpuEvent_t event; + PADDLE_ENFORCE_GPU_SUCCESS( + gpuEventCreateWithFlags(&event, gpuEventDisableTiming)); + PADDLE_ENFORCE_GPU_SUCCESS(gpuEventRecord(event, stream)); + event_map_[stream] = event; + } else { + PADDLE_ENFORCE_GPU_SUCCESS(gpuEventRecord(event_map_[stream], stream)); + } +} void CUDAMallocAsyncAllocation::RecordStream(gpuStream_t stream) { std::call_once(once_flag_, [this] { phi::backends::gpu::SetDeviceId(place_.device); }); - std::lock_guard lock_guard(recorded_streams_lock_); - if (malloc_stream_ == stream) { - // Called record_stream on tensor whose original malloc_stream matches the - // recorded stream. This should have no effect. 
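The `sync_streams` helper deleted in this hunk orders one stream behind another with a temporary event instead of blocking the host. A self-contained version of that pattern (error checks omitted for brevity):

```cpp
#include <cuda_runtime.h>

// Make `to_wait` wait until all work currently enqueued on `to_record` has
// finished, without blocking the host (the pattern of the removed
// sync_streams() helper).
void SyncStreams(cudaStream_t to_record, cudaStream_t to_wait) {
  cudaEvent_t event = nullptr;
  cudaEventCreateWithFlags(&event, cudaEventDisableTiming);
  cudaEventRecord(event, to_record);
  cudaStreamWaitEvent(to_wait, event, 0);
  cudaEventDestroy(event);  // safe: the recorded dependency outlives the handle
}

int main() {
  cudaStream_t producer, consumer;
  cudaStreamCreate(&producer);
  cudaStreamCreate(&consumer);
  SyncStreams(producer, consumer);  // consumer now runs after producer's work
  cudaStreamSynchronize(consumer);
  cudaStreamDestroy(producer);
  cudaStreamDestroy(consumer);
}
```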
+ if (UNLIKELY(phi::backends::gpu::CUDAGraph::IsThisThreadCapturing())) { + // Disallow recording when graph is capturing + graph_capturing_stream_set_.insert(stream); return; + } else { + RecordStreamWithNoGraphCapturing(stream); + // Record the stream after graph is captured + RecordGraphCapturingStreams(); } - recorded_streams_.insert(stream); } void CUDAMallocAsyncAllocation::EraseStream(gpuStream_t stream) { - std::lock_guard lock_guard(recorded_streams_lock_); - recorded_streams_.erase(stream); + std::lock_guard lock_guard(event_map_lock_); + event_map_.erase(stream); } -size_t CUDAMallocAsyncAllocation::Free() { - if (recorded_streams_.empty()) { - platform::RecordedGpuFreeAsync( - ptr(), size(), place_.device, malloc_stream_); - - if (UNLIKELY(phi::backends::gpu::CUDAGraph::IsThisThreadCapturing())) { - phi::backends::gpu::CUDAGraph::AddJoiningStreamDuringCapturing( - malloc_stream_); - } - return size(); - } else { - sync_streams(malloc_stream_, free_stream_); +void CUDAMallocAsyncAllocation::Free(int dev_id) { + platform::RecordedGpuFreeAsync(ptr(), size(), place_.device, malloc_stream_); +} - for (const auto& recorded_stream : recorded_streams_) { - sync_streams(recorded_stream, free_stream_); - } +// if synchronize, we sync the event so the block could be fully released. +bool CUDAMallocAsyncAllocation::CanBeFreed(bool synchronize) { + if (UNLIKELY(phi::backends::gpu::CUDAGraph::IsThisThreadCapturing())) { + return graph_capturing_stream_set_.empty() && event_map_.empty(); + } + // When try to free a block, we record the stream that should be record during + // capturing. + RecordGraphCapturingStreams(); - platform::RecordedGpuFreeAsync(ptr(), size(), place_.device, free_stream_); + std::call_once(once_flag_, + [this] { phi::backends::gpu::SetDeviceId(place_.device); }); - if (UNLIKELY(phi::backends::gpu::CUDAGraph::IsThisThreadCapturing())) { - phi::backends::gpu::CUDAGraph::AddJoiningStreamDuringCapturing( - free_stream_); + for (auto it = event_map_.begin(); it != event_map_.end();) { + gpuEvent_t& event = it->second; + if (synchronize) { + PADDLE_ENFORCE_GPU_SUCCESS(gpuEventSynchronize(event)); + } else { + gpuError_t err = gpuEventQuery(event); + if (err == gpuErrorNotReady) { + VLOG(9) << "Event " << event << " for " << ptr() << " is not completed"; + return false; + } + PADDLE_ENFORCE_GPU_SUCCESS(err); } - return 0; + PADDLE_ENFORCE_GPU_SUCCESS(gpuEventDestroy(event)); + VLOG(8) << "Destroy event " << event; + it = event_map_.erase(it); } + return true; } -// CUDAMallocAsyncAllocator - CUDAMallocAsyncAllocator::CUDAMallocAsyncAllocator( std::shared_ptr underlying_allocator, const phi::GPUPlace& place, @@ -142,18 +118,40 @@ CUDAMallocAsyncAllocator::CUDAMallocAsyncAllocator( : underlying_allocator_(std::move(underlying_allocator)), place_(place), default_stream_(default_stream), - current_allocated_size_(0), - pending_release_size_(0), - memory_throttle_ratio_( - FLAGS_cuda_malloc_async_pool_memory_throttle_ratio) { - // CUDA operations are not allowed here. The cuInit function must be called - // after a new fork, and since this constructor is typically initialized - // before cuInit, we should avoid calling any CUDA API here. 
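`CanBeFreed`, added above, keeps one event per recording stream and only lets a block go once every event has completed; `synchronize=true` forces completion on the release path, otherwise the events are merely polled. A stripped-down sketch of that loop over a plain `std::map` (the wrapping allocation class is Paddle's, the free function here is not):

```cpp
#include <cuda_runtime.h>
#include <map>

// One recorded event per stream that used the allocation.  The block may be
// freed only when every event has completed.
bool CanBeFreed(std::map<cudaStream_t, cudaEvent_t>& events, bool synchronize) {
  for (auto it = events.begin(); it != events.end();) {
    cudaEvent_t event = it->second;
    if (synchronize) {
      cudaEventSynchronize(event);                 // block until done
    } else if (cudaEventQuery(event) == cudaErrorNotReady) {
      return false;                                // still in use somewhere
    }
    cudaEventDestroy(event);
    it = events.erase(it);
  }
  return true;
}

int main() {
  cudaStream_t s; cudaStreamCreate(&s);
  cudaEvent_t e; cudaEventCreateWithFlags(&e, cudaEventDisableTiming);
  cudaEventRecord(e, s);
  std::map<cudaStream_t, cudaEvent_t> events{{s, e}};
  while (!CanBeFreed(events, /*synchronize=*/false)) {}  // poll until ready
  cudaStreamDestroy(s);
}
```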
- phi::backends::gpu::CUDAGraph::AddPreCaptureCallback([&]() { - VLOG(0) << "[Before capture callback] " << (this) << " " - << std::this_thread::get_id(); - this->ClearFreeStream(true); - }); + memory_stream_(nullptr) { + PADDLE_ENFORCE_GPU_SUCCESS( + gpuStreamCreateWithPriority(&memory_stream_, gpuStreamNonBlocking, 0)); +} + +bool CUDAMallocAsyncAllocator::IsAllocThreadSafe() const { return true; } + +void CUDAMallocAsyncAllocator::ProcessUnfreedAllocations(bool synchronize) { + if (unfreed_allocations_.empty()) { + return; + } + + std::lock_guard lock_guard(unfreed_allocation_lock_); + for (auto it = unfreed_allocations_.begin(); + it != unfreed_allocations_.end();) { + CUDAMallocAsyncAllocation* allocation = (*it); + if (allocation->CanBeFreed(synchronize)) { + allocation->Free(place_.device); + delete allocation; + it = unfreed_allocations_.erase(it); + } else { + ++it; + } + } +} + +void CUDAMallocAsyncAllocator::TryFree(CUDAMallocAsyncAllocation* allocation) { + if (allocation->CanBeFreed()) { + allocation->Free(place_.device); + delete allocation; + } else { + std::lock_guard lock_guard(unfreed_allocation_lock_); + unfreed_allocations_.emplace_back(allocation); + } } uint64_t CUDAMallocAsyncAllocator::ReleaseImpl(const phi::Place& place) { @@ -164,85 +162,20 @@ uint64_t CUDAMallocAsyncAllocator::ReleaseImpl(const phi::Place& place) { uint64_t released_size = 0; // we synchronize the event so all the block could be release. + ProcessUnfreedAllocations(true); if (underlying_allocator_) released_size += underlying_allocator_->Release(place_); VLOG(8) << "Release " << released_size << " bytes memory from all streams"; return released_size; } -void CUDAMallocAsyncAllocator::ClearFreeStream(bool sync) { - LazyInitializeCudaFreeStream(); - - if (sync) { - VLOG(0) << "[CUDAMallocAsyncAllocator] " << (this) - << " synchronize the free stream to ensure all unrelesed blocks " - << "are freed"; - PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(free_stream_)); - } else { - sync_streams(free_stream_, default_stream_); - } - current_allocated_size_ -= pending_release_size_; - pending_release_size_ = 0; -} - -void CUDAMallocAsyncAllocator::MallocThrottling() { - if (UNLIKELY(phi::backends::gpu::CUDAGraph::IsThisThreadCapturing())) { - // we disable MallocThrottling when capturing - return; - } - double allocated = - static_cast(current_allocated_size_ + pending_release_size_); - double utilization = allocated / static_cast(max_size_); - - if (utilization > memory_throttle_ratio_) { - VLOG(10) << "utilization_ratio " << utilization - << " current_allocated_size " - << string::HumanReadableSize(current_allocated_size_) - << " pending_release_size " - << string::HumanReadableSize(pending_release_size_); - CUDAMallocAsyncAllocator::ClearFreeStream(); - } -} - -void CUDAMallocAsyncAllocator::FreeAllocation( - CUDAMallocAsyncAllocation* allocation) { - auto current_released_size = allocation->Free(); - current_allocated_size_ -= current_released_size; - // The amount of pending release size (the space that has been queued to - // free_stream, that are going to be freed in the future) - pending_release_size_ += (allocation->size() - current_released_size); -} - -/* - * There are four distinct scenarios involving `cudaMalloc`, `cudaFree`, and - * `cudaGraph`: - * - * 1. When both `cudaMalloc` and `cudaFree` occur within a graph. - * 2. When `cudaMalloc` happens within a graph, but `cudaFree` occurs outside - * the graph. - * 3. 
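`TryFree`/`ProcessUnfreedAllocations`, introduced here, implement a simple deferred-free list: blocks whose events are still pending are parked and re-checked before later allocations, and drained forcibly from `ReleaseImpl`. A minimal sketch of that list management with a stand-in `Block` type in place of `CUDAMallocAsyncAllocation`:

```cpp
#include <list>
#include <mutex>

// Stand-in for the allocation: something that can report whether all of its
// recorded streams have caught up.
struct Block {
  bool CanBeFreed(bool synchronize) { return synchronize || ready; }
  void Free() {}
  bool ready = false;
};

class DeferredFreeList {
 public:
  // Release immediately if possible, otherwise park the block.
  void TryFree(Block* b) {
    if (b->CanBeFreed(false)) { b->Free(); delete b; return; }
    std::lock_guard<std::mutex> guard(lock_);
    pending_.push_back(b);
  }
  // Called before each allocation (synchronize=false) and from the release
  // path (synchronize=true) to drain whatever has since become free.
  void Process(bool synchronize) {
    std::lock_guard<std::mutex> guard(lock_);
    for (auto it = pending_.begin(); it != pending_.end();) {
      if ((*it)->CanBeFreed(synchronize)) {
        (*it)->Free(); delete *it; it = pending_.erase(it);
      } else {
        ++it;
      }
    }
  }
 private:
  std::list<Block*> pending_;
  std::mutex lock_;
};

int main() {
  DeferredFreeList l;
  l.TryFree(new Block{});  // events not ready yet: parked
  l.Process(true);         // forced drain, as ReleaseImpl does
}
```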
When `cudaMalloc` takes place outside a graph, but `cudaFree` happens - * within a graph. - * 4. When both `cudaMalloc` and `cudaFree` are executed outside any graph. - * - * For cases (1.) and (4.), the usage aligns with the typical pattern of - * `cudaMalloc`/`cudaFree`. - * - * In case (1.), `FreeImpl` removes the allocation from - * `graph_owned_allocations_`, followed by `FreeAllocation`. In case (2.), the - * callback within `AllocateImpl` would free the allocation after the graph is - * destroyed. In case (3.), `FreeImpl` releases the allocation after the CUDA - * graph has completed its capture. Finally, in case (4.), `FreeImpl` would call - * `FreeAllocation`, and the allocation would be freed. - */ - void CUDAMallocAsyncAllocator::FreeImpl(phi::Allocation* phi_allocation) { auto* allocation = dynamic_cast(phi_allocation); - std::lock_guard lock_guard(graph_owned_allocations_lock_); + // VLOG(0) << "Free " << allocation->ptr(); // During graph capturing, only free the memory blocks owned by the graph; // others are cached. - if (UNLIKELY(phi::backends::gpu::CUDAGraph::IsThisThreadCapturing())) { - // Handles scenario (3.) + if (phi::backends::gpu::CUDAGraph::IsThisThreadCapturing()) { if (graph_owned_allocations_.find(allocation) == graph_owned_allocations_.end()) { // If the block is not owned by the graph, cache it for release after @@ -251,102 +184,37 @@ void CUDAMallocAsyncAllocator::FreeImpl(phi::Allocation* phi_allocation) { [=]() { // Release this block after capturing VLOG(0) << "[PostCaptureCallback] Releasing ptr = " - << allocation->ptr() << " size = " - << string::HumanReadableSize(allocation->size()); - FreeAllocation(allocation); + << allocation->ptr() << " size = " << allocation->size(); + TryFree(allocation); }); return; } - // Handles scenario (1.) - graph_owned_allocations_.erase(allocation); - } else { - // Handles scenario (2.) - if (graph_owned_allocations_.find(allocation) != - graph_owned_allocations_.end()) { - auto graph = graph_owned_allocations_[allocation]; - VLOG(0) << "[Rescheduled cudaFreeAsync] Allocation ptr = " - << allocation->ptr() - << " is allocated in a graph but freed outside the graph." - << " The allocation is rescheduled to be freed after the " - << "destruction of graph " << graph; - graph_owned_allocations_.erase(allocation); - - // No need to free the allocation - return; - } } - // Handles scenario (1.) and (4.) - FreeAllocation(allocation); -} - -void CUDAMallocAsyncAllocator::LazyInitializeCudaFreeStream() { - std::call_once(once_flag_, [this] { - size_t avail, total, actual_avail, actual_total; - platform::RecordedGpuMemGetInfo( - &avail, &total, &actual_avail, &actual_total, place_.device); - max_size_ = total; - - VLOG(0) << "[CUDAMallocAsyncAllocator] " << (this) << " place " << place_ - << " max_size " << string::HumanReadableSize(max_size_) - << " memory_throttle_ratio " << memory_throttle_ratio_ - << " tid = " << std::this_thread::get_id(); - - PADDLE_ENFORCE_GPU_SUCCESS( - cudaStreamCreateWithPriority(&free_stream_, cudaStreamNonBlocking, 0)); - cudaDeviceGetDefaultMemPool(&mempool_, place_.device); - - platform::SetDeviceId(place_.device); - }); + // If not capturing or if the block is graph-owned, free it immediately. 
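`FreeImpl` above postpones frees of blocks that are not graph-owned while a capture is active by registering a post-capture callback. A small sketch of the same idea, using the public `cudaStreamIsCapturing` query in place of Paddle's thread-level `IsThisThreadCapturing` helper; the helper names below are hypothetical:

```cpp
#include <cuda_runtime.h>
#include <functional>
#include <vector>

std::vector<std::function<void()>> post_capture_callbacks;

// Free `ptr` now unless `stream` is currently capturing a CUDA graph.  A
// cudaFreeAsync issued during capture would become a node of the graph, so
// frees of blocks that pre-date the capture are postponed instead.
void FreeOrDefer(void* ptr, cudaStream_t stream) {
  cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
  cudaStreamIsCapturing(stream, &status);
  if (status == cudaStreamCaptureStatusActive) {
    post_capture_callbacks.push_back([ptr, stream] { cudaFreeAsync(ptr, stream); });
    return;
  }
  cudaFreeAsync(ptr, stream);
}

// To be run right after cudaStreamEndCapture().
void FlushPostCaptureCallbacks() {
  for (auto& cb : post_capture_callbacks) cb();
  post_capture_callbacks.clear();
}

int main() {
  cudaStream_t s; cudaStreamCreate(&s);
  void* p = nullptr; cudaMallocAsync(&p, 256, s);
  FreeOrDefer(p, s);           // not capturing here, so freed immediately
  cudaStreamSynchronize(s);
  cudaStreamDestroy(s);
}
```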
+ if (phi::backends::gpu::CUDAGraph::IsThisThreadCapturing()) { + graph_owned_allocations_.erase(allocation); + } + TryFree(allocation); } phi::Allocation* CUDAMallocAsyncAllocator::AllocateImpl(size_t size) { - LazyInitializeCudaFreeStream(); - - MallocThrottling(); + std::call_once(once_flag_, [this] { platform::SetDeviceId(place_.device); }); + ProcessUnfreedAllocations(); void* ptr; auto result = platform::RecordedGpuMallocAsync( &ptr, size, place_.device, default_stream_); if (LIKELY(result == gpuSuccess)) { auto* allocation = new CUDAMallocAsyncAllocation( - ptr, size, phi::Place(place_), default_stream_, free_stream_); - VLOG(10) << "Allocate " << allocation->ptr() << " with allocator " - << (this); + ptr, size, phi::Place(place_), default_stream_); // If capturing, associate allocation with the current graph. if (UNLIKELY(phi::backends::gpu::CUDAGraph::IsThisThreadCapturing())) { - std::lock_guard lock_guard(graph_owned_allocations_lock_); - auto capturing_graph = phi::backends::gpu::CUDAGraph::CapturingID(); - graph_owned_allocations_[allocation] = capturing_graph; - - // Handles scenario (2.) - phi::backends::gpu::CUDAGraph::AddPostResetCallbackDuringCapturing( - [=](paddle::optional graph) { - std::lock_guard lock_guard_free( - graph_owned_allocations_lock_); - - // Returns if the allocation is freed during capture. - if (graph_owned_allocations_.find(allocation) == - graph_owned_allocations_.end()) - return; - - bool replayed = graph.get().IsReplayed(); - if (replayed) { - VLOG(0) << "[Rescheduled cudaFreeAsync] Graph " << capturing_graph - << " is destructed. Allocation = " << allocation->ptr() - << " is freed."; - FreeAllocation(allocation); - } else { - VLOG(0) << "[Rescheduled cudaFreeAsync] Graph " << capturing_graph - << " is destructed without any replay. Allocation = " - << allocation->ptr() - << " is not initialized and would not be freed."; - } - }); + // auto graph_id = phi::backends::gpu::CUDAGraph::CapturingPoolID(); + graph_owned_allocations_.insert(allocation); } - current_allocated_size_ += size; return allocation; } @@ -392,5 +260,3 @@ void CUDAMallocAsyncAllocator::SetDefaultStream(gpuStream_t stream) { } } // namespace paddle::memory::allocation - -#endif diff --git a/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.h b/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.h index d0119fce2ef6cc..cdf717b0e416ac 100644 --- a/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.h +++ b/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.h @@ -18,19 +18,12 @@ #include "paddle/fluid/memory/allocation/allocator.h" #include "paddle/fluid/memory/allocation/spin_lock.h" -#include "paddle/fluid/platform/device/gpu/gpu_types.h" #include "paddle/fluid/platform/place.h" -#include "paddle/phi/backends/gpu/cuda/cuda_graph.h" namespace paddle { namespace memory { namespace allocation { -class CUDAMallocAsyncAllocation; -class CUDAMallocAsyncAllocator; - -#ifdef PADDLE_WITH_CUDA - // TODO(eee4017): It may be beneficial to introduce an abstract class named // `StreamAllocator` in future developments. 
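`AllocateImpl` drains the unfreed list and then allocates from the stream-ordered pool on the default stream. A simplified, hypothetical sketch of that allocation path in which the recovery step is reduced to a single stream synchronize plus one retry (the real code goes through `RecordedGpuMallocAsync` and its own retry machinery):

```cpp
#include <cuda_runtime.h>
#include <cstdio>

// Allocate from the stream-ordered pool on `stream`.  On OOM, let every
// pending cudaFreeAsync on the stream complete and retry once.
void* MallocAsyncWithRetry(size_t size, cudaStream_t stream) {
  void* ptr = nullptr;
  cudaError_t result = cudaMallocAsync(&ptr, size, stream);
  if (result == cudaErrorMemoryAllocation) {
    cudaGetLastError();             // clear the error state
    cudaStreamSynchronize(stream);  // queued frees return memory to the pool
    result = cudaMallocAsync(&ptr, size, stream);
  }
  return result == cudaSuccess ? ptr : nullptr;
}

int main() {
  cudaStream_t s; cudaStreamCreate(&s);
  void* p = MallocAsyncWithRetry(1 << 20, s);
  std::printf("got %p\n", p);
  if (p) cudaFreeAsync(p, s);
  cudaStreamSynchronize(s);
  cudaStreamDestroy(s);
}
```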
This class would serve as a central // entity for methods specifically related to stream management, such as @@ -46,25 +39,32 @@ class CUDAMallocAsyncAllocation : public Allocation { CUDAMallocAsyncAllocation(void* ptr, size_t size, phi::Place place, - gpuStream_t malloc_stream, - gpuStream_t free_stream) + gpuStream_t stream) : Allocation(ptr, size, place), - malloc_stream_(malloc_stream), - free_stream_(free_stream) {} + malloc_stream_(stream), + used_in_another_stream(false) {} gpuStream_t GetOwningStream() const { return malloc_stream_; } + // TODO(eee4017): The current implementation of RecordStream is + // similar to that in StreamSafeCUDAAllocator. This approach might lead to + // host execution blocking and redundant EventQuery checks. Considering + // cudaMallocFree, stream-ordered semantics could be leveraged for more + // efficient device-side release. void RecordStream(gpuStream_t stream); + void RecordGraphCapturingStreams(); + void RecordStreamWithNoGraphCapturing(gpuStream_t stream); void EraseStream(gpuStream_t stream); - size_t Free(); + bool CanBeFreed(bool synchronize = false); + void Free(int dev_id); private: static thread_local std::once_flag once_flag_; gpuStream_t malloc_stream_; - gpuStream_t free_stream_; - - SpinLock recorded_streams_lock_; - std::unordered_set recorded_streams_; + bool used_in_another_stream; + std::set graph_capturing_stream_set_; + SpinLock event_map_lock_; + std::map event_map_; }; // The `CUDAMallocAsyncAllocator` class extends `Allocator` and is specialized @@ -77,15 +77,9 @@ class CUDAMallocAsyncAllocator : public Allocator { const phi::GPUPlace& place, gpuStream_t default_stream); - bool IsAllocThreadSafe() const override { return true; } + bool IsAllocThreadSafe() const override; gpuStream_t GetDefaultStream() const; void SetDefaultStream(gpuStream_t stream); - void ClearFreeStream(bool sync = false); - - ~CUDAMallocAsyncAllocator() { - VLOG(0) << "Async allocator is freed " << (this) - << " tid = " << std::this_thread::get_id(); - } protected: void FreeImpl(phi::Allocation* allocation) override; @@ -93,50 +87,21 @@ class CUDAMallocAsyncAllocator : public Allocator { uint64_t ReleaseImpl(const phi::Place& place) override; private: - void LazyInitializeCudaFreeStream(); - void MallocThrottling(); - void FreeAllocation(CUDAMallocAsyncAllocation* allocation); + void ProcessUnfreedAllocations(bool synchronize = false); + void TryFree(CUDAMallocAsyncAllocation* allocation); std::shared_ptr underlying_allocator_; - phi::GPUPlace place_; // Specifies the CUDA device context. - - cudaMemPool_t mempool_; + phi::GPUPlace place_; // Specifies the CUDA device context. gpuStream_t default_stream_; // Default stream for memory operations. - - // we create a `free stream` for each allocator (each device should have a - // unique allocator) if an allocation is recorded on other stream than default - // stream, we release the allocation on `free stream` - gpuStream_t free_stream_; - - size_t current_allocated_size_; - size_t pending_release_size_; - size_t max_size_; - - double memory_throttle_ratio_; - + // TODO(eee4017): We may use a single stream to malloc/free to prevent host + // blocking + gpuStream_t memory_stream_; std::once_flag once_flag_; - - /* - * Life cycle management of graph_owned_allocations_: - * - * Each element within `graph_owned_allocations_` is initialized at - * `AllocateImpl`. However, there are two distinct ways of deconstruction. - * - * (A.) Deallocating occurs within `FreeImpl`. 
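The TODO kept in the header notes that the event-polling `RecordStream` can block the host with redundant `EventQuery` checks, and that stream-ordered semantics could release blocks on the device side instead; that is roughly what the removed `free_stream_` path did. A hedged sketch of such a stream-ordered release, shown only as the alternative the TODO alludes to, not as what this patch adopts:

```cpp
#include <cuda_runtime.h>
#include <set>

// Release `ptr` without blocking the host: make a dedicated free stream wait
// on every stream that used the block, then enqueue cudaFreeAsync there.
void StreamOrderedRelease(void* ptr, cudaStream_t malloc_stream,
                          const std::set<cudaStream_t>& recorded_streams,
                          cudaStream_t free_stream) {
  auto wait_on = [&](cudaStream_t used) {
    cudaEvent_t e; cudaEventCreateWithFlags(&e, cudaEventDisableTiming);
    cudaEventRecord(e, used);
    cudaStreamWaitEvent(free_stream, e, 0);
    cudaEventDestroy(e);
  };
  wait_on(malloc_stream);
  for (cudaStream_t s : recorded_streams) wait_on(s);
  cudaFreeAsync(ptr, free_stream);  // runs only after all users are done
}

int main() {
  cudaStream_t ms, fs; cudaStreamCreate(&ms); cudaStreamCreate(&fs);
  void* p = nullptr; cudaMallocAsync(&p, 256, ms);
  StreamOrderedRelease(p, ms, {}, fs);
  cudaStreamSynchronize(fs);
  cudaStreamDestroy(ms); cudaStreamDestroy(fs);
}
```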
- * This implies that the allocation is initialized and disposed of during a - * graph capture, as in scenario (1.) - * - * (B.) Deallocation takes place in the callback after the graph is - * destructed. Meaning, the allocation is initialized during a graph capture - * but disposed of outside that context, as in scenario (2.) - */ - std::unordered_map - graph_owned_allocations_; - SpinLock graph_owned_allocations_lock_; + std::unordered_set graph_owned_allocations_; + std::list unfreed_allocations_; + SpinLock unfreed_allocation_lock_; }; -#endif - } // namespace allocation } // namespace memory } // namespace paddle diff --git a/paddle/fluid/platform/cuda_graph_with_memory_pool.cc b/paddle/fluid/platform/cuda_graph_with_memory_pool.cc index e137b60b15944b..aa245055d5e4f5 100644 --- a/paddle/fluid/platform/cuda_graph_with_memory_pool.cc +++ b/paddle/fluid/platform/cuda_graph_with_memory_pool.cc @@ -143,11 +143,10 @@ void BeginCUDAGraphCapture(phi::GPUPlace place, << " wait for cuda graph dev_ctx: " << dev_ctx; } } - AddPostResetCallbackIfCapturingCUDAGraph( - [=](paddle::optional graph) { - memory::allocation::AllocatorFacade::Instance() - .RemoveMemoryPoolOfCUDAGraph(pool_id); - }); + AddPostResetCallbackIfCapturingCUDAGraph([pool_id] { + memory::allocation::AllocatorFacade::Instance().RemoveMemoryPoolOfCUDAGraph( + pool_id); + }); } std::unique_ptr EndCUDAGraphCapture() { diff --git a/paddle/fluid/platform/device/gpu/gpu_info.cc b/paddle/fluid/platform/device/gpu/gpu_info.cc index a0e8ef5c0979c1..d84b2fa411d569 100644 --- a/paddle/fluid/platform/device/gpu/gpu_info.cc +++ b/paddle/fluid/platform/device/gpu/gpu_info.cc @@ -240,7 +240,6 @@ class RecordedGpuMallocHelper { size, platform::TracerMemEventType::ReservedAllocate); #ifdef PADDLE_WITH_TESTING - std::lock_guard lock_guard(gpu_ptrs_mutex); gpu_ptrs.insert(*ptr); #endif @@ -309,7 +308,6 @@ class RecordedGpuMallocHelper { size, platform::TracerMemEventType::ReservedAllocate); #ifdef PADDLE_WITH_TESTING - std::lock_guard lock_guard(gpu_ptrs_mutex); gpu_ptrs.insert(*ptr); #endif @@ -360,7 +358,6 @@ class RecordedGpuMallocHelper { // hipErrorDeinitialized } #ifdef PADDLE_WITH_TESTING - std::lock_guard lock_guard(gpu_ptrs_mutex); gpu_ptrs.erase(ptr); #endif } @@ -396,7 +393,6 @@ class RecordedGpuMallocHelper { // hipErrorDeinitialized } #ifdef PADDLE_WITH_TESTING - std::lock_guard lock_guard(gpu_ptrs_mutex); gpu_ptrs.erase(ptr); #endif @@ -407,7 +403,6 @@ class RecordedGpuMallocHelper { } void *GetBasePtr(void *ptr) { #ifdef PADDLE_WITH_TESTING - std::lock_guard lock_guard(gpu_ptrs_mutex); auto it = gpu_ptrs.upper_bound(ptr); if (it == gpu_ptrs.begin()) { return nullptr; @@ -439,7 +434,7 @@ class RecordedGpuMallocHelper { } if (NeedRecord()) { - std::lock_guard lock_guard(*mtx_); + std::lock_guard guard(*mtx_); *avail = std::min(*actual_avail, limit_size_ - cur_size_.load()); *total = std::min(*actual_total, limit_size_); return *total < *actual_total; @@ -518,11 +513,8 @@ class RecordedGpuMallocHelper { mutable std::unique_ptr mtx_; static std::once_flag once_flag_; - - // just for testing - std::set gpu_ptrs; - std::mutex gpu_ptrs_mutex; -}; // NOLINT + std::set gpu_ptrs; // just for testing +}; // NOLINT std::once_flag RecordedGpuMallocHelper::once_flag_; diff --git a/paddle/phi/backends/gpu/cuda/cuda_graph.cc b/paddle/phi/backends/gpu/cuda/cuda_graph.cc index 71181263c26a51..ced9c22816c637 100644 --- a/paddle/phi/backends/gpu/cuda/cuda_graph.cc +++ b/paddle/phi/backends/gpu/cuda/cuda_graph.cc @@ -15,8 +15,6 @@ #include 
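`GetBasePtr` in gpu_info.cc, whose test-only `gpu_ptrs_mutex` guards this patch drops, maps an interior pointer back to the base of its recorded allocation with `std::set::upper_bound`. The lookup itself, in isolation (host memory stands in for device pointers):

```cpp
#include <cassert>
#include <iterator>
#include <set>

// gpu_ptrs holds the base address of every recorded allocation (testing only).
// Take the first element greater than `ptr`, then step back one: the same
// upper_bound trick GetBasePtr uses.
void* GetBasePtr(const std::set<void*>& gpu_ptrs, void* ptr) {
  auto it = gpu_ptrs.upper_bound(ptr);
  if (it == gpu_ptrs.begin()) return nullptr;  // ptr is below every base
  return *std::prev(it);
}

int main() {
  char buf[64];
  std::set<void*> bases{buf};
  assert(GetBasePtr(bases, buf + 10) == buf);  // interior pointer -> base
  assert(GetBasePtr(bases, buf) == buf);       // upper_bound is strict
}
```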
"paddle/phi/backends/gpu/cuda/cuda_graph.h" #include "paddle/common/flags.h" -#ifdef PADDLE_WITH_CUDA - #if CUDA_VERSION < 11000 cudaError_t cudaGetFuncBySymbol(cudaFunction_t *functionPtr, const void *symbolPtr) { @@ -31,7 +29,6 @@ namespace phi::backends::gpu { std::unique_ptr CUDAGraph::capturing_graph_{nullptr}; paddle::optional CUDAGraph::capturing_thread_id_{paddle::none}; -std::vector> CUDAGraph::cudagraph_pre_capture_callbacks_; static std::vector ToposortCUDAGraph(cudaGraph_t graph) { size_t num_nodes; @@ -114,14 +111,13 @@ void CUDAGraph::Reset() { for (auto iter = cudagraph_post_reset_callbacks_.rbegin(); iter != cudagraph_post_reset_callbacks_.rend(); ++iter) { - (*iter)(*this); + (*iter)(); } cudagraph_post_reset_callbacks_.clear(); is_reset_ = true; } void CUDAGraph::Replay() { - is_replayed_ = true; #if CUDA_VERSION >= 10010 PADDLE_ENFORCE_EQ(is_reset_, false, @@ -156,11 +152,6 @@ void CUDAGraph::BeginSegmentCapture() { "you cannot begin segmented capturing in the thread " "which is not the one that starts the capturing.")); } - - for (auto &hook : cudagraph_pre_capture_callbacks_) { - hook(); - } - PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamBeginCapture( capturing_graph_->stream_, capturing_graph_->capture_mode_)); PADDLE_ENFORCE_EQ( @@ -199,16 +190,6 @@ void CUDAGraph::BeginCapture(phi::GPUPlace place, #endif } -inline void sync_streams(gpuStream_t to_record, gpuStream_t to_wait) { - if (to_record == to_wait) return; - cudaEvent_t event = nullptr; - PADDLE_ENFORCE_GPU_SUCCESS( - cudaEventCreateWithFlags(&event, cudaEventDisableTiming)); - PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(event, to_record)); - PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamWaitEvent(to_wait, event)); - PADDLE_ENFORCE_GPU_SUCCESS(cudaEventDestroy(event)); -} - void CUDAGraph::EndSegmentCapture() { ThrowErrorIfNotSupportCUDAGraph(); #if CUDA_VERSION >= 10010 @@ -216,14 +197,6 @@ void CUDAGraph::EndSegmentCapture() { IsCapturing(), true, phi::errors::PermissionDenied("No CUDA Graph is capturing.")); - - for (const auto &stream : capturing_graph_->streams_to_join_) { - VLOG(10) << "Joining steam when the capture is going to end stream =" - << stream; - sync_streams(stream, capturing_graph_->stream_); - } - capturing_graph_->streams_to_join_.clear(); - cudaGraph_t graph; PADDLE_ENFORCE_GPU_SUCCESS( cudaStreamEndCapture(capturing_graph_->stream_, &graph)); @@ -246,6 +219,8 @@ void CUDAGraph::EndSegmentCapture() { capturing_graph_->cudagraph_pre_replay_callbacks_.emplace_back( CUDAGraphNodeLauncher::Instance().GetParameterSettersForExecGraph(graph)); + // if forward graph is registered, this graph is a backward graph + // we check whether there is remain blocks that is unreleased by this cudaGraphExec_t exec_graph; if (FLAGS_use_cuda_malloc_async_allocator && FLAGS_auto_free_cudagraph_allocations_on_launch) { @@ -403,5 +378,3 @@ CUDAGraphNodeLauncher::GetParameterSettersForExecGraph(cudaGraph_t graph) { #endif } // namespace phi::backends::gpu - -#endif diff --git a/paddle/phi/backends/gpu/cuda/cuda_graph.h b/paddle/phi/backends/gpu/cuda/cuda_graph.h index f794cf0fa65366..dfc981850ca130 100644 --- a/paddle/phi/backends/gpu/cuda/cuda_graph.h +++ b/paddle/phi/backends/gpu/cuda/cuda_graph.h @@ -39,8 +39,6 @@ #include "paddle/phi/core/enforce.h" #include "paddle/utils/optional.h" -#ifdef PADDLE_WITH_CUDA - #if CUDA_VERSION < 11000 // For CUDA versions less than 11.0, use a dummy type for cudaFunction_t. 
using cudaFunction_t = void *; @@ -250,29 +248,18 @@ class CUDAGraph { void Reset(); - void AddPostResetCallback( - std::function)> callback) { + void AddPostResetCallback(std::function callback) { std::lock_guard guard(mtx_); cudagraph_post_reset_callbacks_.push_back(std::move(callback)); } - static void AddPreCaptureCallback(std::function callback) { - cudagraph_pre_capture_callbacks_.push_back(std::move(callback)); - } - void AddPostCaptureCallback(std::function callback) { std::lock_guard guard(mtx_); cudagraph_post_capture_callbacks_.push_back(std::move(callback)); } - void AddJoiningStream(cudaStream_t stream) { - streams_to_join_.insert(stream); - } - void PrintToDotFiles(const std::string &dirname, unsigned int flags); - bool IsReplayed() const { return is_replayed_; } - static void BeginCapture(phi::GPUPlace place, cudaStream_t stream, gpuStreamCaptureMode mode); @@ -281,12 +268,8 @@ class CUDAGraph { static void BeginSegmentCapture(); static void EndSegmentCapture(); - static void AddJoiningStreamDuringCapturing(cudaStream_t stream) { - capturing_graph_->AddJoiningStream(stream); - } - static void AddPostResetCallbackDuringCapturing( - std::function)> callback) { + std::function callback) { capturing_graph_->AddPostResetCallback(std::move(callback)); } @@ -348,20 +331,14 @@ class CUDAGraph { CUDAGraphID id_; int64_t pool_id_{kInvalidPoolID}; bool is_reset_{false}; - bool is_replayed_{false}; std::mutex mtx_; std::vector set_seed_funcs_; - std::unordered_set streams_to_join_; - // Holds callbacks that are triggered after the CUDA graph is reset. These // callbacks are used for operations that need to be performed following the // reset of a CUDA graph. - std::vector)>> - cudagraph_post_reset_callbacks_; - - static std::vector> cudagraph_pre_capture_callbacks_; + std::vector> cudagraph_post_reset_callbacks_; // Contains callbacks that are invoked after the CUDA graph has been captured. 
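With this patch the post-reset hooks are plain no-argument `std::function` objects stored in a vector and, as the `Reset()` hunk earlier shows, invoked in reverse registration order. A tiny stand-alone registry illustrating that LIFO flush (class name is illustrative, not Paddle's):

```cpp
#include <functional>
#include <iostream>
#include <mutex>
#include <vector>

class PostResetCallbacks {
 public:
  void Add(std::function<void()> cb) {
    std::lock_guard<std::mutex> guard(mtx_);
    callbacks_.push_back(std::move(cb));
  }
  // Walk the callbacks back to front, so clean-up registered later (for later
  // allocations) runs before clean-up registered earlier.
  void RunAndClear() {
    for (auto it = callbacks_.rbegin(); it != callbacks_.rend(); ++it) (*it)();
    callbacks_.clear();
  }

 private:
  std::vector<std::function<void()>> callbacks_;
  std::mutex mtx_;
};

int main() {
  PostResetCallbacks cbs;
  cbs.Add([] { std::cout << "first registered\n"; });
  cbs.Add([] { std::cout << "second registered\n"; });
  cbs.RunAndClear();  // prints "second registered" then "first registered"
}
```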
// These callbacks are crucial for managing memory allocations related to the @@ -423,5 +400,3 @@ class CUDAGraphCaptureModeGuard { } // namespace gpu } // namespace backends } // namespace phi - -#endif diff --git a/paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h b/paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h index 2d72cc6a35d0ef..2d5810fbe1c9b6 100644 --- a/paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h +++ b/paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h @@ -49,7 +49,7 @@ inline void AddPostResetCallbackIfCapturingCUDAGraph(Callback &&callback) { std::forward(callback)); } #endif - callback({}); + callback(); } template @@ -62,9 +62,7 @@ inline T *RestoreHostMemIfCapturingCUDAGraph(T *host_mem, size_t size) { void *new_host_mem = new uint8_t[nbytes]; std::memcpy(new_host_mem, host_mem, nbytes); AddPostResetCallbackIfCapturingCUDAGraph( - [=](paddle::optional graph) { - delete[] reinterpret_cast(new_host_mem); - }); + [new_host_mem] { delete[] reinterpret_cast(new_host_mem); }); return reinterpret_cast(new_host_mem); } #endif diff --git a/paddle/phi/backends/gpu/rocm/hip_graph.cc b/paddle/phi/backends/gpu/rocm/hip_graph.cc index 8c255c4dbdc279..781cb41ae69833 100644 --- a/paddle/phi/backends/gpu/rocm/hip_graph.cc +++ b/paddle/phi/backends/gpu/rocm/hip_graph.cc @@ -19,8 +19,6 @@ COMMON_DECLARE_bool(use_cuda_malloc_async_allocator); COMMON_DECLARE_bool(auto_free_cudagraph_allocations_on_launch); -#ifdef PADDLE_WITH_HIP - namespace phi { namespace backends { namespace gpu { @@ -108,7 +106,7 @@ void CUDAGraph::Reset() { for (auto iter = cudagraph_post_reset_callbacks_.rbegin(); iter != cudagraph_post_reset_callbacks_.rend(); ++iter) { - (*iter)(*this); + (*iter)(); } cudagraph_post_reset_callbacks_.clear(); is_reset_ = true; @@ -365,5 +363,3 @@ CUDAGraphNodeLauncher::GetParameterSettersForExecGraph(hipGraph_t graph) { } // namespace gpu } // namespace backends } // namespace phi - -#endif diff --git a/paddle/phi/backends/gpu/rocm/hip_graph.h b/paddle/phi/backends/gpu/rocm/hip_graph.h index 46d57598614729..cb922752272543 100644 --- a/paddle/phi/backends/gpu/rocm/hip_graph.h +++ b/paddle/phi/backends/gpu/rocm/hip_graph.h @@ -37,8 +37,6 @@ #include "paddle/phi/core/enforce.h" #include "paddle/utils/optional.h" -#ifdef PADDLE_WITH_HIP - namespace phi { namespace backends { namespace gpu { @@ -241,8 +239,7 @@ class CUDAGraph { void Reset(); - void AddPostResetCallback( - std::function)> callback) { + void AddPostResetCallback(std::function callback) { std::lock_guard guard(mtx_); cudagraph_post_reset_callbacks_.push_back(std::move(callback)); } @@ -263,7 +260,7 @@ class CUDAGraph { static void EndSegmentCapture(); static void AddPostResetCallbackDuringCapturing( - std::function)> callback) { + std::function callback) { capturing_graph_->AddPostResetCallback(std::move(callback)); } @@ -332,8 +329,7 @@ class CUDAGraph { // Holds callbacks that are triggered after the CUDA graph is reset. These // callbacks are used for operations that need to be performed following the // reset of a CUDA graph. - std::vector)>> - cudagraph_post_reset_callbacks_; + std::vector> cudagraph_post_reset_callbacks_; // Contains callbacks that are invoked after the CUDA graph has been captured. // These callbacks are crucial for managing memory allocations related to the @@ -395,5 +391,3 @@ class CUDAGraphCaptureModeGuard { } // namespace gpu } // namespace backends } // namespace phi - -#endif
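`RestoreHostMemIfCapturingCUDAGraph` in the memory-pool header copies host parameters to the heap during capture and releases the copy through one of these post-reset callbacks, because a replayed graph may read the host pointer long after the caller's buffer is gone. A self-contained sketch of that pattern with a hypothetical `RestoreHostMemIfCapturing` helper and a plain callback vector standing in for the graph's registry:

```cpp
#include <cstdint>
#include <cstring>
#include <functional>
#include <vector>

std::vector<std::function<void()>> post_reset_callbacks;  // flushed on graph reset

// If a graph is being captured, host parameters must outlive the capture, so
// copy them to the heap and schedule the delete for after the graph is reset;
// otherwise the original pointer can be used as-is.
template <typename T>
T* RestoreHostMemIfCapturing(T* host_mem, size_t count, bool capturing) {
  if (!capturing || host_mem == nullptr || count == 0) return host_mem;
  size_t nbytes = count * sizeof(T);
  void* copy = new uint8_t[nbytes];
  std::memcpy(copy, host_mem, nbytes);
  post_reset_callbacks.push_back(
      [copy] { delete[] reinterpret_cast<uint8_t*>(copy); });
  return reinterpret_cast<T*>(copy);
}

int main() {
  int params[3] = {1, 2, 3};
  int* stable = RestoreHostMemIfCapturing(params, 3, /*capturing=*/true);
  // ... record `stable` into graph nodes here ...
  for (auto& cb : post_reset_callbacks) cb();  // graph reset: heap copy released
  post_reset_callbacks.clear();
  (void)stable;
}
```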