From 7d6f78a115c5ad9aa9b64dc8b3892bb7a2edd487 Mon Sep 17 00:00:00 2001 From: Frank Lin Date: Wed, 12 Jun 2024 16:25:38 +0000 Subject: [PATCH 1/5] Async Pool and Memory Throttling --- paddle/common/flags.cc | 22 ++ .../memory/allocation/allocator_facade.cc | 8 +- .../allocation/cuda_malloc_async_allocator.cc | 341 ++++++++++++------ .../allocation/cuda_malloc_async_allocator.h | 80 ++-- .../platform/cuda_graph_with_memory_pool.cc | 9 +- paddle/fluid/platform/device/gpu/gpu_info.cc | 14 +- paddle/phi/backends/gpu/cuda/cuda_graph.cc | 29 +- paddle/phi/backends/gpu/cuda/cuda_graph.h | 27 +- .../gpu/cuda/cuda_graph_with_memory_pool.h | 6 +- 9 files changed, 388 insertions(+), 148 deletions(-) diff --git a/paddle/common/flags.cc b/paddle/common/flags.cc index 253c1a266e2ddb..647bd1573c3237 100644 --- a/paddle/common/flags.cc +++ b/paddle/common/flags.cc @@ -1130,6 +1130,28 @@ PHI_DEFINE_EXPORTED_bool(use_cuda_malloc_async_allocator, false, "Enable CUDAMallocAsyncAllocator"); +/* + * CUDAMallocAsyncAllocator related FLAG + * Name: FLAGS_cuda_malloc_async_pool_memory_throttle_ratio + * Since Version: 2.7 + * Value Range: double, [0.0, 1.0], default=0.8 + * Note:memory_throttle_ratio provides a threshold that determines when to + * initiate synchronization operations to deallocate memory. This mechanism + * helps in ensuring that the system does not exceed its memory capacity while + * also attempting to minimize performance degradation caused by frequent memory + * synchronization. + * + * Please see Note [cuda_malloc_async_pool_memory_throttle_ratio] + */ +PHI_DEFINE_EXPORTED_double( + cuda_malloc_async_pool_memory_throttle_ratio, + 0.8, + "memory_throttle_ratio provides a threshold that determines when to " + "initiate synchronization operations to deallocate memory. 
" + "This mechanism helps in ensuring that the system does not exceed its " + "memory capacity while also attempting to minimize performance degradation " + "caused by frequent memory synchronization."); + /* * CUDA Graph / Allocator related FLAG * Name: FLAGS_auto_free_cudagraph_allocations_on_launch diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index 6dfec39cc43911..6a5df0ba9ca344 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -433,7 +433,7 @@ class AllocatorFacadePrivate { /* unique_lock_guard */ { std::unique_lock lock_guard( cuda_allocator_mutex_); - InitStreamSafeCUDAAllocator(place, stream); + InitCUDAAllocator(place, stream); return cuda_allocators_[place][stream]; } } @@ -865,7 +865,7 @@ class AllocatorFacadePrivate { return std::make_shared(p); } - void InitStreamSafeCUDAAllocator(phi::GPUPlace p, gpuStream_t stream) { + void InitCUDAAllocator(phi::GPUPlace p, gpuStream_t stream) { PADDLE_ENFORCE_EQ( strategy_, AllocatorStrategy::kAutoGrowth, @@ -1812,6 +1812,9 @@ void AllocatorFacade::PrepareMemoryPoolForCUDAGraph(int64_t id) { FLAGS_allocator_strategy)); auto& allocator = cuda_graph_map_[id]; auto& ref_cnt = cuda_graph_ref_cnt_[id]; + ++ref_cnt; + + if (FLAGS_use_cuda_malloc_async_allocator) return; if (allocator.get() == nullptr) { allocator = std::make_unique( /*allow_free_idle_chunk=*/false); @@ -1819,7 +1822,6 @@ void AllocatorFacade::PrepareMemoryPoolForCUDAGraph(int64_t id) { } else { VLOG(10) << "Use created memory pool for CUDA Graph with memory ID " << id; } - ++ref_cnt; } void AllocatorFacade::RemoveMemoryPoolOfCUDAGraph(int64_t id) { diff --git a/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.cc b/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.cc index b4fc316df3de89..c227b08e9ba502 100644 --- a/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.cc +++ b/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.cc @@ -13,7 +13,10 @@ // limitations under the License. #include "paddle/fluid/memory/allocation/cuda_malloc_async_allocator.h" +#include #include +#include +#include "paddle/common/macros.h" #include "paddle/fluid/memory/allocation/allocator.h" #include "paddle/fluid/memory/allocation/stream_safe_cuda_allocator.h" @@ -33,84 +36,107 @@ #include "paddle/phi/backends/gpu/rocm/hip_graph.h" #endif +#include "paddle/utils/optional.h" + +PHI_DECLARE_double(cuda_malloc_async_pool_memory_throttle_ratio); + namespace paddle::memory::allocation { +/* + * Note: [cuda_malloc_async_pool_memory_throttle_ratio] + * The primary purpose of the memory_throttle_ratio is to provide a + * threshold that determines when to initiate synchronization operations to + * deallocate memory. This mechanism helps in ensuring that the system does + * not exceed its memory capacity while also attempting to minimize performance + * degradation caused by frequent memory synchronization. + * + * ``` + * utilization = (allocated_size + pending_release_size) / total_memory_size + * if(utilization > memory_throttle_ratio) + * sync(free_stream, malloc_stream) + * ``` + * + * When the utilization exceeds the memory_throttle_ratio, we + * initiate a stream synchronization operation before malloc. + * + * During synchronization, all memory deallocation requests in the free queue + * are processed, effectively lowering the memory utilization before + * any new memory allocation operations are going to proceed. 
+ * + * [Impact on Performance and Memory Usage] + * + * - Lower memory_throttle_ratio Values + * the synchronization operation will be triggered more frequently. + * This can lead to better memory utilization but might result in decreased + * performance due to the increased number of synchronization operations. + * + * - Higher memory_throttle_ratio Values + * Conversely, setting a higher value allows for more memory to be allocated + * before triggering synchronization, which can enhance performance by reducing + * the number of sync operations. However, this increases the risk of reaching + * an OOM condition since more memory can be allocated without + * immediate deallocation. + */ + thread_local std::once_flag CUDAMallocAsyncAllocation::once_flag_; -void CUDAMallocAsyncAllocation::RecordGraphCapturingStreams() { - for (gpuStream_t stream : graph_capturing_stream_set_) { - RecordStreamWithNoGraphCapturing(stream); - } - graph_capturing_stream_set_.clear(); +inline void sync_streams(gpuStream_t to_record, gpuStream_t to_wait) { + cudaEvent_t event = nullptr; + PADDLE_ENFORCE_GPU_SUCCESS( + cudaEventCreateWithFlags(&event, cudaEventDisableTiming)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(event, to_record)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamWaitEvent(to_wait, event)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaEventDestroy(event)); } -void CUDAMallocAsyncAllocation::RecordStreamWithNoGraphCapturing( - gpuStream_t stream) { - if (event_map_.find(stream) == event_map_.end()) { - gpuEvent_t event; - PADDLE_ENFORCE_GPU_SUCCESS( - gpuEventCreateWithFlags(&event, gpuEventDisableTiming)); - PADDLE_ENFORCE_GPU_SUCCESS(gpuEventRecord(event, stream)); - event_map_[stream] = event; - } else { - PADDLE_ENFORCE_GPU_SUCCESS(gpuEventRecord(event_map_[stream], stream)); - } -} +// CUDAMallocAsyncAllocation void CUDAMallocAsyncAllocation::RecordStream(gpuStream_t stream) { std::call_once(once_flag_, [this] { phi::backends::gpu::SetDeviceId(place_.device); }); - if (UNLIKELY(phi::backends::gpu::CUDAGraph::IsThisThreadCapturing())) { - // Disallow recording when graph is capturing - graph_capturing_stream_set_.insert(stream); + std::lock_guard lock_guard(recorded_streams_lock_); + if (malloc_stream_ == stream) { + // Called record_stream on tensor whose original malloc_stream matches the + // recorded stream. This should have no effect. return; - } else { - RecordStreamWithNoGraphCapturing(stream); - // Record the stream after graph is captured - RecordGraphCapturingStreams(); } + recorded_streams_.insert(stream); } void CUDAMallocAsyncAllocation::EraseStream(gpuStream_t stream) { - std::lock_guard lock_guard(event_map_lock_); - event_map_.erase(stream); + std::lock_guard lock_guard(recorded_streams_lock_); + recorded_streams_.erase(stream); } -void CUDAMallocAsyncAllocation::Free(int dev_id) { - platform::RecordedGpuFreeAsync(ptr(), size(), place_.device, malloc_stream_); -} +size_t CUDAMallocAsyncAllocation::Free() { + if (recorded_streams_.empty()) { + platform::RecordedGpuFreeAsync( + ptr(), size(), place_.device, malloc_stream_); -// if synchronize, we sync the event so the block could be fully released. -bool CUDAMallocAsyncAllocation::CanBeFreed(bool synchronize) { - if (UNLIKELY(phi::backends::gpu::CUDAGraph::IsThisThreadCapturing())) { - return graph_capturing_stream_set_.empty() && event_map_.empty(); - } - // When try to free a block, we record the stream that should be record during - // capturing. 
- RecordGraphCapturingStreams(); + if (UNLIKELY(phi::backends::gpu::CUDAGraph::IsThisThreadCapturing())) { + phi::backends::gpu::CUDAGraph::AddJoiningStreamDuringCapturing( + malloc_stream_); + } + return size(); + } else { + sync_streams(malloc_stream_, free_stream_); - std::call_once(once_flag_, - [this] { phi::backends::gpu::SetDeviceId(place_.device); }); + for (const auto& recorded_stream : recorded_streams_) { + sync_streams(recorded_stream, free_stream_); + } - for (auto it = event_map_.begin(); it != event_map_.end();) { - gpuEvent_t& event = it->second; - if (synchronize) { - PADDLE_ENFORCE_GPU_SUCCESS(gpuEventSynchronize(event)); - } else { - gpuError_t err = gpuEventQuery(event); - if (err == gpuErrorNotReady) { - VLOG(9) << "Event " << event << " for " << ptr() << " is not completed"; - return false; - } - PADDLE_ENFORCE_GPU_SUCCESS(err); + platform::RecordedGpuFreeAsync(ptr(), size(), place_.device, free_stream_); + + if (UNLIKELY(phi::backends::gpu::CUDAGraph::IsThisThreadCapturing())) { + phi::backends::gpu::CUDAGraph::AddJoiningStreamDuringCapturing( + free_stream_); } - PADDLE_ENFORCE_GPU_SUCCESS(gpuEventDestroy(event)); - VLOG(8) << "Destroy event " << event; - it = event_map_.erase(it); + return 0; } - return true; } +// CUDAMallocAsyncAllocator + CUDAMallocAsyncAllocator::CUDAMallocAsyncAllocator( std::shared_ptr underlying_allocator, const phi::GPUPlace& place, @@ -118,40 +144,18 @@ CUDAMallocAsyncAllocator::CUDAMallocAsyncAllocator( : underlying_allocator_(std::move(underlying_allocator)), place_(place), default_stream_(default_stream), - memory_stream_(nullptr) { - PADDLE_ENFORCE_GPU_SUCCESS( - gpuStreamCreateWithPriority(&memory_stream_, gpuStreamNonBlocking, 0)); -} - -bool CUDAMallocAsyncAllocator::IsAllocThreadSafe() const { return true; } - -void CUDAMallocAsyncAllocator::ProcessUnfreedAllocations(bool synchronize) { - if (unfreed_allocations_.empty()) { - return; - } - - std::lock_guard lock_guard(unfreed_allocation_lock_); - for (auto it = unfreed_allocations_.begin(); - it != unfreed_allocations_.end();) { - CUDAMallocAsyncAllocation* allocation = (*it); - if (allocation->CanBeFreed(synchronize)) { - allocation->Free(place_.device); - delete allocation; - it = unfreed_allocations_.erase(it); - } else { - ++it; - } - } -} - -void CUDAMallocAsyncAllocator::TryFree(CUDAMallocAsyncAllocation* allocation) { - if (allocation->CanBeFreed()) { - allocation->Free(place_.device); - delete allocation; - } else { - std::lock_guard lock_guard(unfreed_allocation_lock_); - unfreed_allocations_.emplace_back(allocation); - } + current_allocated_size_(0), + pending_release_size_(0), + memory_throttle_ratio_( + FLAGS_cuda_malloc_async_pool_memory_throttle_ratio) { + // CUDA operations are not allowed here. The cuInit function must be called + // after a new fork, and since this constructor is typically initialized + // before cuInit, we should avoid calling any CUDA API here. + phi::backends::gpu::CUDAGraph::AddPreCaptureCallback([&]() { + VLOG(0) << "[Before capture callback] " << (this) << " " + << std::this_thread::get_id(); + this->ClearFreeStream(true); + }); } uint64_t CUDAMallocAsyncAllocator::ReleaseImpl(const platform::Place& place) { @@ -162,20 +166,85 @@ uint64_t CUDAMallocAsyncAllocator::ReleaseImpl(const platform::Place& place) { uint64_t released_size = 0; // we synchronize the event so all the block could be release. 
- ProcessUnfreedAllocations(true); if (underlying_allocator_) released_size += underlying_allocator_->Release(place_); VLOG(8) << "Release " << released_size << " bytes memory from all streams"; return released_size; } +void CUDAMallocAsyncAllocator::ClearFreeStream(bool sync) { + LazyInitializeCudaFreeStream(); + + if (sync) { + VLOG(0) << "[CUDAMallocAsyncAllocator] " << (this) + << " synchronize the free stream to ensure all unreleased blocks " + << "are freed"; + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(free_stream_)); + } else { + sync_streams(free_stream_, default_stream_); + } + current_allocated_size_ -= pending_release_size_; + pending_release_size_ = 0; +} + +void CUDAMallocAsyncAllocator::MallocThrottling() { + if (UNLIKELY(phi::backends::gpu::CUDAGraph::IsThisThreadCapturing())) { + // we disable MallocThrottling when capturing + return; + } + double allocated = + static_cast(current_allocated_size_ + pending_release_size_); + double utilization = allocated / static_cast(max_size_); + + if (utilization > memory_throttle_ratio_) { + VLOG(10) << "utilization_ratio " << utilization + << " current_allocated_size " + << string::HumanReadableSize(current_allocated_size_) + << " pending_release_size " + << string::HumanReadableSize(pending_release_size_); + CUDAMallocAsyncAllocator::ClearFreeStream(); + } +} + +void CUDAMallocAsyncAllocator::FreeAllocation( + CUDAMallocAsyncAllocation* allocation) { + auto current_released_size = allocation->Free(); + current_allocated_size_ -= current_released_size; + // The amount of pending release size (the space that has been queued to + // free_stream, that is going to be freed in the future) + pending_release_size_ += (allocation->size() - current_released_size); +} + +/* + * There are four distinct scenarios involving `cudaMalloc`, `cudaFree`, and + * `cudaGraph`: + * + * 1. When both `cudaMalloc` and `cudaFree` occur within a graph. + * 2. When `cudaMalloc` happens within a graph, but `cudaFree` occurs outside + * the graph. + * 3. When `cudaMalloc` takes place outside a graph, but `cudaFree` happens + * within a graph. + * 4. When both `cudaMalloc` and `cudaFree` are executed outside any graph. + * + * For cases (1.) and (4.), the usage aligns with the typical pattern of + * `cudaMalloc`/`cudaFree`. + * + * In case (1.), `FreeImpl` removes the allocation from + * `graph_owned_allocations_`, followed by `FreeAllocation`. In case (2.), the + * callback within `AllocateImpl` would free the allocation after the graph is + * destroyed. In case (3.), `FreeImpl` releases the allocation after the CUDA + * graph has completed its capture. Finally, in case (4.), `FreeImpl` would call + * `FreeAllocation`, and the allocation would be freed. + */ + void CUDAMallocAsyncAllocator::FreeImpl(phi::Allocation* phi_allocation) { auto* allocation = dynamic_cast(phi_allocation); + std::lock_guard lock_guard(graph_owned_allocations_lock_); - // VLOG(0) << "Free " << allocation->ptr(); // During graph capturing, only free the memory blocks owned by the graph; // others are cached. - if (phi::backends::gpu::CUDAGraph::IsThisThreadCapturing()) { + if (UNLIKELY(phi::backends::gpu::CUDAGraph::IsThisThreadCapturing())) { + // Handles scenario (3.)
if (graph_owned_allocations_.find(allocation) == graph_owned_allocations_.end()) { // If the block is not owned by the graph, cache it for release after @@ -184,37 +253,101 @@ void CUDAMallocAsyncAllocator::FreeImpl(phi::Allocation* phi_allocation) { [=]() { // Release this block after capturing VLOG(0) << "[PostCaptureCallback] Releasing ptr = " - << allocation->ptr() << " size = " << allocation->size(); - TryFree(allocation); + << allocation->ptr() << " size = " + << string::HumanReadableSize(allocation->size()); + FreeAllocation(allocation); }); return; } - } - - // If not capturing or if the block is graph-owned, free it immediately. - if (phi::backends::gpu::CUDAGraph::IsThisThreadCapturing()) { + // Handles scenario (1.) graph_owned_allocations_.erase(allocation); + } else { + // Handles scenario (2.) + if (graph_owned_allocations_.find(allocation) != + graph_owned_allocations_.end()) { + auto graph = graph_owned_allocations_[allocation]; + VLOG(0) << "[Rescheduled cudaFreeAsync] Allocation ptr = " + << allocation->ptr() + << " is allocated in a graph but freed outside the graph." + << " The allocation is rescheduled to be freed after the " + << "destruction of graph " << graph; + graph_owned_allocations_.erase(allocation); + + // No need to free the allocation + return; + } } - TryFree(allocation); + + // Handles scenario (1.) and (4.) + FreeAllocation(allocation); +} + +void CUDAMallocAsyncAllocator::LazyInitializeCudaFreeStream() { + std::call_once(once_flag_, [this] { + size_t avail, total, actual_avail, actual_total; + platform::RecordedGpuMemGetInfo( + &avail, &total, &actual_avail, &actual_total, place_.device); + max_size_ = total; + + VLOG(0) << "[CUDAMallocAsyncAllocator] " << (this) << " place " << place_ + << " max_size " << string::HumanReadableSize(max_size_) + << " memory_throttle_ratio " << memory_throttle_ratio_ + << " tid = " << std::this_thread::get_id(); + + PADDLE_ENFORCE_GPU_SUCCESS( + cudaStreamCreateWithPriority(&free_stream_, cudaStreamNonBlocking, 0)); + cudaDeviceGetDefaultMemPool(&mempool_, place_.device); + + platform::SetDeviceId(place_.device); + }); } phi::Allocation* CUDAMallocAsyncAllocator::AllocateImpl(size_t size) { - std::call_once(once_flag_, [this] { platform::SetDeviceId(place_.device); }); - ProcessUnfreedAllocations(); + LazyInitializeCudaFreeStream(); + + MallocThrottling(); void* ptr; auto result = platform::RecordedGpuMallocAsync( &ptr, size, place_.device, default_stream_); if (LIKELY(result == gpuSuccess)) { auto* allocation = new CUDAMallocAsyncAllocation( - ptr, size, platform::Place(place_), default_stream_); - + ptr, size, platform::Place(place_), default_stream_, free_stream_); + VLOG(10) << "Allocate " << allocation->ptr() << " with allocator " + << (this); // If capturing, associate allocation with the current graph. if (UNLIKELY(phi::backends::gpu::CUDAGraph::IsThisThreadCapturing())) { - // auto graph_id = phi::backends::gpu::CUDAGraph::CapturingPoolID(); - graph_owned_allocations_.insert(allocation); + std::lock_guard lock_guard(graph_owned_allocations_lock_); + auto capturing_graph = phi::backends::gpu::CUDAGraph::CapturingID(); + graph_owned_allocations_[allocation] = capturing_graph; + + // Handles scenario (2.) + phi::backends::gpu::CUDAGraph::AddPostResetCallbackDuringCapturing( + [=](paddle::optional graph) { + std::lock_guard lock_guard_free( + graph_owned_allocations_lock_); + + // Returns if the allocation is freed during capture. 
+ if (graph_owned_allocations_.find(allocation) == + graph_owned_allocations_.end()) + return; + + bool replayed = graph.get().IsReplayed(); + if (replayed) { + VLOG(0) << "[Rescheduled cudaFreeAsync] Graph " << capturing_graph + << " is destructed. Allocation = " << allocation->ptr() + << " is freed."; + FreeAllocation(allocation); + } else { + VLOG(0) << "[Rescheduled cudaFreeAsync] Graph " << capturing_graph + << " is destructed without any replay. Allocation = " + << allocation->ptr() + << " is not initialized and would not be freed."; + } + }); } + current_allocated_size_ += size; return allocation; } diff --git a/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.h b/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.h index 25ae560a8bac98..b4606f061169d0 100644 --- a/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.h +++ b/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.h @@ -18,7 +18,9 @@ #include "paddle/fluid/memory/allocation/allocator.h" #include "paddle/fluid/memory/allocation/spin_lock.h" +#include "paddle/fluid/platform/device/gpu/gpu_types.h" #include "paddle/fluid/platform/place.h" +#include "paddle/phi/backends/gpu/cuda/cuda_graph.h" namespace paddle { namespace memory { @@ -39,32 +41,25 @@ class CUDAMallocAsyncAllocation : public Allocation { CUDAMallocAsyncAllocation(void* ptr, size_t size, phi::Place place, - gpuStream_t stream) + gpuStream_t malloc_stream, + gpuStream_t free_stream) : Allocation(ptr, size, place), - malloc_stream_(stream), - used_in_another_stream(false) {} + malloc_stream_(malloc_stream), + free_stream_(free_stream) {} gpuStream_t GetOwningStream() const { return malloc_stream_; } - // TODO(eee4017): The current implementation of RecordStream is - // similar to that in StreamSafeCUDAAllocator. This approach might lead to - // host execution blocking and redundant EventQuery checks. Considering - // cudaMallocFree, stream-ordered semantics could be leveraged for more - // efficient device-side release. 
void RecordStream(gpuStream_t stream); - void RecordGraphCapturingStreams(); - void RecordStreamWithNoGraphCapturing(gpuStream_t stream); void EraseStream(gpuStream_t stream); - bool CanBeFreed(bool synchronize = false); - void Free(int dev_id); + size_t Free(); private: static thread_local std::once_flag once_flag_; gpuStream_t malloc_stream_; - bool used_in_another_stream; - std::set graph_capturing_stream_set_; - SpinLock event_map_lock_; - std::map event_map_; + gpuStream_t free_stream_; + + SpinLock recorded_streams_lock_; + std::unordered_set recorded_streams_; }; // The `CUDAMallocAsyncAllocator` class extends `Allocator` and is specialized @@ -77,9 +72,15 @@ class CUDAMallocAsyncAllocator : public Allocator { const phi::GPUPlace& place, gpuStream_t default_stream); - bool IsAllocThreadSafe() const override; + bool IsAllocThreadSafe() const override { return true; } gpuStream_t GetDefaultStream() const; void SetDefaultStream(gpuStream_t stream); + void ClearFreeStream(bool sync = false); + + ~CUDAMallocAsyncAllocator() { + VLOG(0) << "Async allocator is freed " << (this) + << " tid = " << std::this_thread::get_id(); + } protected: void FreeImpl(phi::Allocation* allocation) override; @@ -87,19 +88,46 @@ class CUDAMallocAsyncAllocator : public Allocator { uint64_t ReleaseImpl(const platform::Place& place) override; private: - void ProcessUnfreedAllocations(bool synchronize = false); - void TryFree(CUDAMallocAsyncAllocation* allocation); + void LazyInitializeCudaFreeStream(); + void MallocThrottling(); + void FreeAllocation(CUDAMallocAsyncAllocation* allocation); std::shared_ptr underlying_allocator_; - phi::GPUPlace place_; // Specifies the CUDA device context. + phi::GPUPlace place_; // Specifies the CUDA device context. + + cudaMemPool_t mempool_; gpuStream_t default_stream_; // Default stream for memory operations. - // TODO(eee4017): We may use a single stream to malloc/free to prevent host - // blocking - gpuStream_t memory_stream_; + + // we create a `free stream` for each allocator (each device should have a + // unique allocator) if an allocation is recorded on other stream than default + // stream, we release the allocation on `free stream` + gpuStream_t free_stream_; + + size_t current_allocated_size_; + size_t pending_release_size_; + size_t max_size_; + + double memory_throttle_ratio_; + std::once_flag once_flag_; - std::unordered_set graph_owned_allocations_; - std::list unfreed_allocations_; - SpinLock unfreed_allocation_lock_; + + /* + * Life cycle management of graph_owned_allocations_: + * + * Each element within `graph_owned_allocations_` is initialized at + * `AllocateImpl`. However, there are two distinct ways of deconstruction. + * + * (A.) Deallocating occurs within `FreeImpl`. + * This implies that the allocation is initialized and disposed of during a + * graph capture, as in scenario (1.) + * + * (B.) Deallocation takes place in the callback after the graph is + * destructed. Meaning, the allocation is initialized during a graph capture + * but disposed of outside that context, as in scenario (2.) 
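+ *
+ * Note that allocations created outside of any graph capture (scenarios (3.)
+ * and (4.)) are never inserted into this map; only allocations made while a
+ * capture is active are tracked in `graph_owned_allocations_`.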
+ */ + std::unordered_map + graph_owned_allocations_; + SpinLock graph_owned_allocations_lock_; }; } // namespace allocation diff --git a/paddle/fluid/platform/cuda_graph_with_memory_pool.cc b/paddle/fluid/platform/cuda_graph_with_memory_pool.cc index aa245055d5e4f5..e137b60b15944b 100644 --- a/paddle/fluid/platform/cuda_graph_with_memory_pool.cc +++ b/paddle/fluid/platform/cuda_graph_with_memory_pool.cc @@ -143,10 +143,11 @@ void BeginCUDAGraphCapture(phi::GPUPlace place, << " wait for cuda graph dev_ctx: " << dev_ctx; } } - AddPostResetCallbackIfCapturingCUDAGraph([pool_id] { - memory::allocation::AllocatorFacade::Instance().RemoveMemoryPoolOfCUDAGraph( - pool_id); - }); + AddPostResetCallbackIfCapturingCUDAGraph( + [=](paddle::optional graph) { + memory::allocation::AllocatorFacade::Instance() + .RemoveMemoryPoolOfCUDAGraph(pool_id); + }); } std::unique_ptr EndCUDAGraphCapture() { diff --git a/paddle/fluid/platform/device/gpu/gpu_info.cc b/paddle/fluid/platform/device/gpu/gpu_info.cc index d84b2fa411d569..a0e8ef5c0979c1 100644 --- a/paddle/fluid/platform/device/gpu/gpu_info.cc +++ b/paddle/fluid/platform/device/gpu/gpu_info.cc @@ -240,6 +240,7 @@ class RecordedGpuMallocHelper { size, platform::TracerMemEventType::ReservedAllocate); #ifdef PADDLE_WITH_TESTING + std::lock_guard lock_guard(gpu_ptrs_mutex); gpu_ptrs.insert(*ptr); #endif @@ -308,6 +309,7 @@ class RecordedGpuMallocHelper { size, platform::TracerMemEventType::ReservedAllocate); #ifdef PADDLE_WITH_TESTING + std::lock_guard lock_guard(gpu_ptrs_mutex); gpu_ptrs.insert(*ptr); #endif @@ -358,6 +360,7 @@ class RecordedGpuMallocHelper { // hipErrorDeinitialized } #ifdef PADDLE_WITH_TESTING + std::lock_guard lock_guard(gpu_ptrs_mutex); gpu_ptrs.erase(ptr); #endif } @@ -393,6 +396,7 @@ class RecordedGpuMallocHelper { // hipErrorDeinitialized } #ifdef PADDLE_WITH_TESTING + std::lock_guard lock_guard(gpu_ptrs_mutex); gpu_ptrs.erase(ptr); #endif @@ -403,6 +407,7 @@ class RecordedGpuMallocHelper { } void *GetBasePtr(void *ptr) { #ifdef PADDLE_WITH_TESTING + std::lock_guard lock_guard(gpu_ptrs_mutex); auto it = gpu_ptrs.upper_bound(ptr); if (it == gpu_ptrs.begin()) { return nullptr; @@ -434,7 +439,7 @@ class RecordedGpuMallocHelper { } if (NeedRecord()) { - std::lock_guard guard(*mtx_); + std::lock_guard lock_guard(*mtx_); *avail = std::min(*actual_avail, limit_size_ - cur_size_.load()); *total = std::min(*actual_total, limit_size_); return *total < *actual_total; @@ -513,8 +518,11 @@ class RecordedGpuMallocHelper { mutable std::unique_ptr mtx_; static std::once_flag once_flag_; - std::set gpu_ptrs; // just for testing -}; // NOLINT + + // just for testing + std::set gpu_ptrs; + std::mutex gpu_ptrs_mutex; +}; // NOLINT std::once_flag RecordedGpuMallocHelper::once_flag_; diff --git a/paddle/phi/backends/gpu/cuda/cuda_graph.cc b/paddle/phi/backends/gpu/cuda/cuda_graph.cc index ced9c22816c637..eb4bf6d5e5004e 100644 --- a/paddle/phi/backends/gpu/cuda/cuda_graph.cc +++ b/paddle/phi/backends/gpu/cuda/cuda_graph.cc @@ -29,6 +29,7 @@ namespace phi::backends::gpu { std::unique_ptr CUDAGraph::capturing_graph_{nullptr}; paddle::optional CUDAGraph::capturing_thread_id_{paddle::none}; +std::vector> CUDAGraph::cudagraph_pre_capture_callbacks_; static std::vector ToposortCUDAGraph(cudaGraph_t graph) { size_t num_nodes; @@ -111,13 +112,14 @@ void CUDAGraph::Reset() { for (auto iter = cudagraph_post_reset_callbacks_.rbegin(); iter != cudagraph_post_reset_callbacks_.rend(); ++iter) { - (*iter)(); + (*iter)(*this); } cudagraph_post_reset_callbacks_.clear(); 
is_reset_ = true; } void CUDAGraph::Replay() { + is_replayed_ = true; #if CUDA_VERSION >= 10010 PADDLE_ENFORCE_EQ(is_reset_, false, @@ -152,6 +154,11 @@ void CUDAGraph::BeginSegmentCapture() { "you cannot begin segmented capturing in the thread " "which is not the one that starts the capturing.")); } + + for (auto &hook : cudagraph_pre_capture_callbacks_) { + hook(); + } + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamBeginCapture( capturing_graph_->stream_, capturing_graph_->capture_mode_)); PADDLE_ENFORCE_EQ( @@ -190,6 +197,16 @@ void CUDAGraph::BeginCapture(phi::GPUPlace place, #endif } +inline void sync_streams(gpuStream_t to_record, gpuStream_t to_wait) { + if (to_record == to_wait) return; + cudaEvent_t event = nullptr; + PADDLE_ENFORCE_GPU_SUCCESS( + cudaEventCreateWithFlags(&event, cudaEventDisableTiming)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(event, to_record)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamWaitEvent(to_wait, event)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaEventDestroy(event)); +} + void CUDAGraph::EndSegmentCapture() { ThrowErrorIfNotSupportCUDAGraph(); #if CUDA_VERSION >= 10010 @@ -197,6 +214,14 @@ void CUDAGraph::EndSegmentCapture() { IsCapturing(), true, phi::errors::PermissionDenied("No CUDA Graph is capturing.")); + + for (const auto &stream : capturing_graph_->streams_to_join_) { + VLOG(10) << "Joining stream when the capture is going to end stream =" + << stream; + sync_streams(stream, capturing_graph_->stream_); + } + capturing_graph_->streams_to_join_.clear(); + cudaGraph_t graph; PADDLE_ENFORCE_GPU_SUCCESS( cudaStreamEndCapture(capturing_graph_->stream_, &graph)); @@ -219,8 +244,6 @@ void CUDAGraph::EndSegmentCapture() { capturing_graph_->cudagraph_pre_replay_callbacks_.emplace_back( CUDAGraphNodeLauncher::Instance().GetParameterSettersForExecGraph(graph)); - // if forward graph is registered, this graph is a backward graph - // we check whether there is remain blocks that is unreleased by this cudaGraphExec_t exec_graph; if (FLAGS_use_cuda_malloc_async_allocator && FLAGS_auto_free_cudagraph_allocations_on_launch) { diff --git a/paddle/phi/backends/gpu/cuda/cuda_graph.h b/paddle/phi/backends/gpu/cuda/cuda_graph.h index dfc981850ca130..64a3344d867601 100644 --- a/paddle/phi/backends/gpu/cuda/cuda_graph.h +++ b/paddle/phi/backends/gpu/cuda/cuda_graph.h @@ -248,18 +248,29 @@ class CUDAGraph { void Reset(); - void AddPostResetCallback(std::function callback) { + void AddPostResetCallback( + std::function)> callback) { std::lock_guard guard(mtx_); cudagraph_post_reset_callbacks_.push_back(std::move(callback)); } + static void AddPreCaptureCallback(std::function callback) { + cudagraph_pre_capture_callbacks_.push_back(std::move(callback)); + } + void AddPostCaptureCallback(std::function callback) { std::lock_guard guard(mtx_); cudagraph_post_capture_callbacks_.push_back(std::move(callback)); } + void AddJoiningStream(cudaStream_t stream) { + streams_to_join_.insert(stream); + } + void PrintToDotFiles(const std::string &dirname, unsigned int flags); + bool IsReplayed() const { return is_replayed_; } + static void BeginCapture(phi::GPUPlace place, cudaStream_t stream, gpuStreamCaptureMode mode); @@ -268,8 +279,12 @@ class CUDAGraph { static void BeginSegmentCapture(); static void EndSegmentCapture(); + static void AddJoiningStreamDuringCapturing(cudaStream_t stream) { + capturing_graph_->AddJoiningStream(stream); + } + static void AddPostResetCallbackDuringCapturing( - std::function callback) { + std::function)> callback) {
capturing_graph_->AddPostResetCallback(std::move(callback)); } @@ -331,14 +346,20 @@ class CUDAGraph { CUDAGraphID id_; int64_t pool_id_{kInvalidPoolID}; bool is_reset_{false}; + bool is_replayed_{false}; std::mutex mtx_; std::vector set_seed_funcs_; + std::unordered_set streams_to_join_; + // Holds callbacks that are triggered after the CUDA graph is reset. These // callbacks are used for operations that need to be performed following the // reset of a CUDA graph. - std::vector> cudagraph_post_reset_callbacks_; + std::vector)>> + cudagraph_post_reset_callbacks_; + + static std::vector> cudagraph_pre_capture_callbacks_; // Contains callbacks that are invoked after the CUDA graph has been captured. // These callbacks are crucial for managing memory allocations related to the diff --git a/paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h b/paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h index 2d5810fbe1c9b6..2d72cc6a35d0ef 100644 --- a/paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h +++ b/paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h @@ -49,7 +49,7 @@ inline void AddPostResetCallbackIfCapturingCUDAGraph(Callback &&callback) { std::forward(callback)); } #endif - callback(); + callback({}); } template @@ -62,7 +62,9 @@ inline T *RestoreHostMemIfCapturingCUDAGraph(T *host_mem, size_t size) { void *new_host_mem = new uint8_t[nbytes]; std::memcpy(new_host_mem, host_mem, nbytes); AddPostResetCallbackIfCapturingCUDAGraph( - [new_host_mem] { delete[] reinterpret_cast(new_host_mem); }); + [=](paddle::optional graph) { + delete[] reinterpret_cast(new_host_mem); + }); return reinterpret_cast(new_host_mem); } #endif From 360e3828f880fa8edf2917ba8f4b6e8382480771 Mon Sep 17 00:00:00 2001 From: Frank Lin Date: Wed, 26 Jun 2024 15:01:22 +0000 Subject: [PATCH 2/5] fix rocm build --- paddle/phi/backends/gpu/rocm/hip_graph.cc | 2 +- paddle/phi/backends/gpu/rocm/hip_graph.h | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/paddle/phi/backends/gpu/rocm/hip_graph.cc b/paddle/phi/backends/gpu/rocm/hip_graph.cc index 781cb41ae69833..b0937eda084034 100644 --- a/paddle/phi/backends/gpu/rocm/hip_graph.cc +++ b/paddle/phi/backends/gpu/rocm/hip_graph.cc @@ -106,7 +106,7 @@ void CUDAGraph::Reset() { for (auto iter = cudagraph_post_reset_callbacks_.rbegin(); iter != cudagraph_post_reset_callbacks_.rend(); ++iter) { - (*iter)(); + (*iter)(*this); } cudagraph_post_reset_callbacks_.clear(); is_reset_ = true; diff --git a/paddle/phi/backends/gpu/rocm/hip_graph.h b/paddle/phi/backends/gpu/rocm/hip_graph.h index cb922752272543..a7d4d21cf4cfc7 100644 --- a/paddle/phi/backends/gpu/rocm/hip_graph.h +++ b/paddle/phi/backends/gpu/rocm/hip_graph.h @@ -239,7 +239,8 @@ class CUDAGraph { void Reset(); - void AddPostResetCallback(std::function callback) { + void AddPostResetCallback( + std::function)> callback) { std::lock_guard guard(mtx_); cudagraph_post_reset_callbacks_.push_back(std::move(callback)); } @@ -260,7 +261,7 @@ class CUDAGraph { static void EndSegmentCapture(); static void AddPostResetCallbackDuringCapturing( - std::function callback) { + std::function)> callback) { capturing_graph_->AddPostResetCallback(std::move(callback)); } @@ -329,7 +330,8 @@ class CUDAGraph { // Holds callbacks that are triggered after the CUDA graph is reset. These // callbacks are used for operations that need to be performed following the // reset of a CUDA graph. 
- std::vector> cudagraph_post_reset_callbacks_; + std::vector)>> + cudagraph_post_reset_callbacks_; // Contains callbacks that are invoked after the CUDA graph has been captured. // These callbacks are crucial for managing memory allocations related to the From 2500026368b6b5d76195393801070d29bca43dbd Mon Sep 17 00:00:00 2001 From: Frank Lin Date: Mon, 1 Jul 2024 05:54:52 +0000 Subject: [PATCH 3/5] fix flag --- .../memory/allocation/cuda_malloc_async_allocator.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.cc b/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.cc index c227b08e9ba502..d99cf28bc025d2 100644 --- a/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.cc +++ b/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.cc @@ -16,6 +16,7 @@ #include #include #include +#include "paddle/common/flags.h" #include "paddle/common/macros.h" #include "paddle/fluid/memory/allocation/allocator.h" #include "paddle/fluid/memory/allocation/stream_safe_cuda_allocator.h" @@ -38,10 +39,6 @@ #include "paddle/utils/optional.h" -PHI_DECLARE_double(cuda_malloc_async_pool_memory_throttle_ratio); - -namespace paddle::memory::allocation { - /* * Note: [cuda_malloc_async_pool_memory_throttle_ratio] * The primary purpose of the memory_throttle_ratio is to provide a @@ -77,6 +74,9 @@ namespace paddle::memory::allocation { * an OOM condition since more memory can be allocated without * immediate deallocation. */ +PHI_DECLARE_double(cuda_malloc_async_pool_memory_throttle_ratio); + +namespace paddle::memory::allocation { thread_local std::once_flag CUDAMallocAsyncAllocation::once_flag_; From b3d984020edef69644182db4d0d7b2f1f896887c Mon Sep 17 00:00:00 2001 From: Frank Lin Date: Mon, 1 Jul 2024 09:27:22 +0000 Subject: [PATCH 4/5] fix rocm build --- .../memory/allocation/allocator_facade.cc | 74 +++++++++++++------ .../allocation/cuda_malloc_async_allocator.cc | 12 +-- .../allocation/cuda_malloc_async_allocator.h | 7 ++ paddle/phi/backends/gpu/cuda/cuda_graph.cc | 4 + paddle/phi/backends/gpu/cuda/cuda_graph.h | 4 + paddle/phi/backends/gpu/rocm/hip_graph.cc | 4 + paddle/phi/backends/gpu/rocm/hip_graph.h | 4 + 7 files changed, 79 insertions(+), 30 deletions(-) diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index 6a5df0ba9ca344..f8624f17928480 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -200,8 +200,10 @@ class AllocatorFacadePrivate { : #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) default_stream_safe_cuda_allocators_(), - default_cuda_malloc_async_allocators_(), cuda_allocators_(), +#endif +#ifdef PADDLE_WITH_CUDA + default_cuda_malloc_async_allocators_(), #endif allocators_() { strategy_ = GetAllocatorStrategy(); @@ -443,9 +445,11 @@ class AllocatorFacadePrivate { if (auto iter = default_stream_safe_cuda_allocators_.find(place); iter != default_stream_safe_cuda_allocators_.end()) return iter->second; +#ifdef PADDLE_WITH_CUDA if (auto iter = default_cuda_malloc_async_allocators_.find(place); iter != default_cuda_malloc_async_allocators_.end()) return iter->second; +#endif PADDLE_THROW(platform::errors::NotFound( "No StreamSafeCUDAAllocator found for the place, %s", place)); } @@ -454,10 +458,12 @@ class AllocatorFacadePrivate { if (auto allocator = std::dynamic_pointer_cast( GetDefaultStreamSafeCUDAAllocator(place))) { return 
allocator->GetDefaultStream(); +#ifdef PADDLE_WITH_CUDA } else if (auto allocator = std::dynamic_pointer_cast( GetDefaultStreamSafeCUDAAllocator(place))) { return allocator->GetDefaultStream(); +#endif } else { PADDLE_THROW(platform::errors::NotFound( "No StreamSafeCUDAAllocator or CUDAMallocAsyncAllocator found for " @@ -484,6 +490,7 @@ class AllocatorFacadePrivate { VLOG(8) << "Set default stream to " << stream << " for StreamSafeCUDAAllocator(" << allocator.get() << ") in " << place; +#ifdef PADDLE_WITH_CUDA } else if (auto allocator = std::dynamic_pointer_cast( GetDefaultStreamSafeCUDAAllocator(place))) { @@ -501,6 +508,7 @@ class AllocatorFacadePrivate { VLOG(8) << "Set default stream to " << stream << " for CUDAMallocAsyncAllocator(" << allocator.get() << ") in " << place; +#endif } else { PADDLE_THROW(platform::errors::NotFound( "No StreamSafeCUDAAllocator or CUDAMallocAsyncAllocator found for " @@ -511,13 +519,15 @@ class AllocatorFacadePrivate { void RecordStream(std::shared_ptr allocation, gpuStream_t stream) { - if (auto cuda_malloc_async_allocation = - std::dynamic_pointer_cast(allocation)) { - cuda_malloc_async_allocation->RecordStream(stream); - } else if (auto stream_safe_cuda_allocation = - std::dynamic_pointer_cast( - allocation)) { + if (auto stream_safe_cuda_allocation = + std::dynamic_pointer_cast(allocation)) { stream_safe_cuda_allocation->RecordStream(stream); +#ifdef PADDLE_WITH_CUDA + } else if (auto cuda_malloc_async_allocation = + std::dynamic_pointer_cast( + allocation)) { + cuda_malloc_async_allocation->RecordStream(stream); +#endif } else { VLOG(6) << "RecordStream for a non-StreamSafeCUDAAllocation"; } @@ -525,13 +535,15 @@ class AllocatorFacadePrivate { void EraseStream(std::shared_ptr allocation, gpuStream_t stream) { - if (auto cuda_malloc_async_allocation = - std::dynamic_pointer_cast(allocation)) { - cuda_malloc_async_allocation->EraseStream(stream); - } else if (auto stream_safe_cuda_allocation = - std::dynamic_pointer_cast( - allocation)) { + if (auto stream_safe_cuda_allocation = + std::dynamic_pointer_cast(allocation)) { stream_safe_cuda_allocation->EraseStream(stream); +#ifdef PADDLE_WITH_CUDA + } else if (auto cuda_malloc_async_allocation = + std::dynamic_pointer_cast( + allocation)) { + cuda_malloc_async_allocation->EraseStream(stream); +#endif } else { VLOG(6) << "EraseStream for a non-StreamSafeCUDAAllocation"; } @@ -539,17 +551,18 @@ class AllocatorFacadePrivate { gpuStream_t GetStream( const std::shared_ptr& allocation) const { - if (const std::shared_ptr - cuda_malloc_async_allocation = - std::dynamic_pointer_cast( + if (const std::shared_ptr + stream_safe_cuda_allocation = + std::dynamic_pointer_cast( allocation)) { - return cuda_malloc_async_allocation->GetOwningStream(); - - } else if (const std::shared_ptr - stream_safe_cuda_allocation = - std::dynamic_pointer_cast( - allocation)) { return stream_safe_cuda_allocation->GetOwningStream(); +#ifdef PADDLE_WITH_CUDA + } else if (const std::shared_ptr + cuda_malloc_async_allocation = + std::dynamic_pointer_cast( + allocation)) { + return cuda_malloc_async_allocation->GetOwningStream(); +#endif } VLOG(6) << "GetStream for a non-StreamSafeCUDAAllocation"; @@ -897,9 +910,14 @@ class AllocatorFacadePrivate { } void InitCUDAMallocAsyncAllocator(phi::GPUPlace p, gpuStream_t stream) { +#ifdef PADDLE_WITH_CUDA std::shared_ptr& allocator = cuda_allocators_[p][stream]; cuda_allocators_[p][stream] = std::make_shared(allocator, p, stream); +#else + PADDLE_THROW(platform::errors::Unavailable( + 
"CUDAMallocAsyncAllocator is not enabled")); +#endif } void InitAutoGrowthCUDAAllocator(phi::GPUPlace p, gpuStream_t stream) { @@ -1169,6 +1187,7 @@ class AllocatorFacadePrivate { } void WrapCUDAMallocAsyncAllocatorForDefault() { +#ifdef PADDLE_WITH_CUDA for (auto& pair : allocators_) { auto& place = pair.first; if (platform::is_gpu_place(place)) { @@ -1188,6 +1207,10 @@ class AllocatorFacadePrivate { << ", allocator address = " << pair.second.get(); } } +#else + PADDLE_THROW(platform::errors::Unavailable( + "CUDAMallocAsyncAllocator is not enabled")); +#endif } void WrapCUDARetryAllocator(phi::GPUPlace p, @@ -1549,12 +1572,15 @@ class AllocatorFacadePrivate { // a standalone CUDA allocator to support multi-stream GC in new executor std::map> default_stream_safe_cuda_allocators_; - std::map> - default_cuda_malloc_async_allocators_; CUDAAllocatorMap cuda_allocators_; std::shared_timed_mutex cuda_allocator_mutex_; #endif +#if defined(PADDLE_WITH_CUDA) + std::map> + default_cuda_malloc_async_allocators_; +#endif + #ifdef PADDLE_WITH_XPU // a standalone XPU allocator to support multi-stream GC in new executor std::map> diff --git a/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.cc b/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.cc index d99cf28bc025d2..fdcf60f07f1102 100644 --- a/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.cc +++ b/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.cc @@ -24,6 +24,7 @@ #ifdef PADDLE_WITH_CUDA #include #include +#include "paddle/phi/backends/gpu/cuda/cuda_graph.h" #endif #include @@ -31,14 +32,11 @@ #include "paddle/fluid/platform/cuda_device_guard.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/enforce.h" -#if defined(PADDLE_WITH_CUDA) -#include "paddle/phi/backends/gpu/cuda/cuda_graph.h" -#elif defined(PADDLE_WITH_HIP) -#include "paddle/phi/backends/gpu/rocm/hip_graph.h" -#endif #include "paddle/utils/optional.h" +#ifdef PADDLE_WITH_CUDA + /* * Note: [cuda_malloc_async_pool_memory_throttle_ratio] * The primary purpose of the memory_throttle_ratio is to provide a @@ -74,7 +72,7 @@ * an OOM condition since more memory can be allocated without * immediate deallocation. */ -PHI_DECLARE_double(cuda_malloc_async_pool_memory_throttle_ratio); +COMMON_DECLARE_double(cuda_malloc_async_pool_memory_throttle_ratio); namespace paddle::memory::allocation { @@ -393,3 +391,5 @@ void CUDAMallocAsyncAllocator::SetDefaultStream(gpuStream_t stream) { } } // namespace paddle::memory::allocation + +#endif diff --git a/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.h b/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.h index b4606f061169d0..9f87c2a3ac4a16 100644 --- a/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.h +++ b/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.h @@ -26,6 +26,11 @@ namespace paddle { namespace memory { namespace allocation { +class CUDAMallocAsyncAllocation; +class CUDAMallocAsyncAllocator; + +#ifdef PADDLE_WITH_CUDA + // TODO(eee4017): It may be beneficial to introduce an abstract class named // `StreamAllocator` in future developments. 
This class would serve as a central // entity for methods specifically related to stream management, such as @@ -130,6 +135,8 @@ class CUDAMallocAsyncAllocator : public Allocator { SpinLock graph_owned_allocations_lock_; }; +#endif + } // namespace allocation } // namespace memory } // namespace paddle diff --git a/paddle/phi/backends/gpu/cuda/cuda_graph.cc b/paddle/phi/backends/gpu/cuda/cuda_graph.cc index eb4bf6d5e5004e..71181263c26a51 100644 --- a/paddle/phi/backends/gpu/cuda/cuda_graph.cc +++ b/paddle/phi/backends/gpu/cuda/cuda_graph.cc @@ -15,6 +15,8 @@ #include "paddle/phi/backends/gpu/cuda/cuda_graph.h" #include "paddle/common/flags.h" +#ifdef PADDLE_WITH_CUDA + #if CUDA_VERSION < 11000 cudaError_t cudaGetFuncBySymbol(cudaFunction_t *functionPtr, const void *symbolPtr) { @@ -401,3 +403,5 @@ CUDAGraphNodeLauncher::GetParameterSettersForExecGraph(cudaGraph_t graph) { #endif } // namespace phi::backends::gpu + +#endif diff --git a/paddle/phi/backends/gpu/cuda/cuda_graph.h b/paddle/phi/backends/gpu/cuda/cuda_graph.h index 64a3344d867601..f794cf0fa65366 100644 --- a/paddle/phi/backends/gpu/cuda/cuda_graph.h +++ b/paddle/phi/backends/gpu/cuda/cuda_graph.h @@ -39,6 +39,8 @@ #include "paddle/phi/core/enforce.h" #include "paddle/utils/optional.h" +#ifdef PADDLE_WITH_CUDA + #if CUDA_VERSION < 11000 // For CUDA versions less than 11.0, use a dummy type for cudaFunction_t. using cudaFunction_t = void *; @@ -421,3 +423,5 @@ class CUDAGraphCaptureModeGuard { } // namespace gpu } // namespace backends } // namespace phi + +#endif diff --git a/paddle/phi/backends/gpu/rocm/hip_graph.cc b/paddle/phi/backends/gpu/rocm/hip_graph.cc index b0937eda084034..8c255c4dbdc279 100644 --- a/paddle/phi/backends/gpu/rocm/hip_graph.cc +++ b/paddle/phi/backends/gpu/rocm/hip_graph.cc @@ -19,6 +19,8 @@ COMMON_DECLARE_bool(use_cuda_malloc_async_allocator); COMMON_DECLARE_bool(auto_free_cudagraph_allocations_on_launch); +#ifdef PADDLE_WITH_HIP + namespace phi { namespace backends { namespace gpu { @@ -363,3 +365,5 @@ CUDAGraphNodeLauncher::GetParameterSettersForExecGraph(hipGraph_t graph) { } // namespace gpu } // namespace backends } // namespace phi + +#endif diff --git a/paddle/phi/backends/gpu/rocm/hip_graph.h b/paddle/phi/backends/gpu/rocm/hip_graph.h index a7d4d21cf4cfc7..46d57598614729 100644 --- a/paddle/phi/backends/gpu/rocm/hip_graph.h +++ b/paddle/phi/backends/gpu/rocm/hip_graph.h @@ -37,6 +37,8 @@ #include "paddle/phi/core/enforce.h" #include "paddle/utils/optional.h" +#ifdef PADDLE_WITH_HIP + namespace phi { namespace backends { namespace gpu { @@ -393,3 +395,5 @@ class CUDAGraphCaptureModeGuard { } // namespace gpu } // namespace backends } // namespace phi + +#endif From 5f25a646797010688a0ec08162888233076ceec2 Mon Sep 17 00:00:00 2001 From: Frank Lin Date: Fri, 12 Jul 2024 03:28:53 +0000 Subject: [PATCH 5/5] fix flag --- paddle/common/flags.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/common/flags.cc b/paddle/common/flags.cc index 647bd1573c3237..a420ca2962bc3b 100644 --- a/paddle/common/flags.cc +++ b/paddle/common/flags.cc @@ -1133,7 +1133,7 @@ PHI_DEFINE_EXPORTED_bool(use_cuda_malloc_async_allocator, /* * CUDAMallocAsyncAllocator related FLAG * Name: FLAGS_cuda_malloc_async_pool_memory_throttle_ratio - * Since Version: 2.7 + * Since Version: 3.0 * Value Range: double, [0.0, 1.0], default=0.8 * Note:memory_throttle_ratio provides a threshold that determines when to * initiate synchronization operations to deallocate memory. This mechanism