diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index 84f8a3709adf1..de09860fd26d5 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -117,13 +117,13 @@ endif() cc_library(cudnn_workspace_helper SRCS cudnn_workspace_helper.cc DEPS boost) # seperate init from device_context to avoid cycle dependencies -cc_library(init SRCS init.cc DEPS device_context custom_kernel) +cc_library(init SRCS init.cc DEPS device_context custom_kernel context_pool) # memcpy depends on device_context, here add deps individually for # avoiding cycle dependencies cc_library(device_context SRCS device_context.cc DEPS simple_threadpool malloc xxhash ${STREAM_CALLBACK_DEPS} place phi_place eigen3 stringpiece cpu_helper cpu_info framework_proto ${IPU_CTX_DEPS} ${GPU_CTX_DEPS} ${NPU_CTX_DEPS} ${MKLDNN_CTX_DEPS} - ${dgc_deps} dlpack cudnn_workspace_helper ${XPU_CTX_DEPS} ${MLU_CTX_DEPS} eigen3 cpu_context generator context_pool) + ${dgc_deps} dlpack cudnn_workspace_helper ${XPU_CTX_DEPS} ${MLU_CTX_DEPS} eigen3 cpu_context generator) if(WITH_XPU) target_link_libraries(device_context xpu_context) endif() diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 7623db05452b1..18ac979b48ef3 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -36,9 +36,6 @@ limitations under the License. */ #include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler/event_tracing.h" -#include "paddle/fluid/framework/phi_utils.h" -#include "paddle/phi/api/include/context_pool.h" - namespace paddle { namespace memory { @@ -155,9 +152,6 @@ inline void EmplaceDeviceContext( // lazy evaluation. i.e., only create device context at // first `Get` auto* dev_ctx = new DevCtx(p); - // init the phi DeviceContextPool at same time - auto& context_pool = - paddle::experimental::DeviceContextPool::Instance(); if (is_gpu_place(p)) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) auto* cuda_ctx = dynamic_cast(dev_ctx); @@ -189,26 +183,6 @@ inline void EmplaceDeviceContext( memory::allocation::AllocatorFacade::Instance() .GetZeroAllocator(p) .get()); - // insert Context pointer into phi DeviceContextPool - // only get CPU and GPU DeviceContext now, add other DeviceContext type - // later if needed - if (platform::is_cpu_place(p)) { - context_pool.Insert( - static_cast(p), - static_cast< - const typename framework::ConvertToPhiContext::TYPE*>( - dev_ctx)); -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - } else if (platform::is_gpu_place(p)) { - context_pool.Insert( - static_cast(p), - static_cast< - const typename framework::ConvertToPhiContext::TYPE*>( - dev_ctx)); -#endif - } else { - // skip other places now, do nothing - } return PtrType(dev_ctx); })); } diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index e104170ca2495..2c5f24d28c6d6 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -916,6 +916,11 @@ class DeviceContextPool { size_t size() const { return device_contexts_.size(); } + const std::map>>& + device_contexts() const { + return device_contexts_; + } + private: static DeviceContextPool* pool; std::map>> diff --git a/paddle/phi/api/include/context_pool.h b/paddle/phi/api/include/context_pool.h index 2e267f6565413..754833a2ddab3 100644 --- a/paddle/phi/api/include/context_pool.h +++ b/paddle/phi/api/include/context_pool.h @@ -69,10 +69,8 @@ class DeviceContextPool { Get(place)); } - void Insert(const Place& place, const phi::DeviceContext* dev_ctx); - private: - DeviceContextPool() = default; + DeviceContextPool(); paddle::flat_hash_map context_map_; diff --git a/paddle/phi/api/lib/context_pool.cc b/paddle/phi/api/lib/context_pool.cc index 460a6c707db2a..d1408a88d6ff7 100644 --- a/paddle/phi/api/lib/context_pool.cc +++ b/paddle/phi/api/lib/context_pool.cc @@ -14,8 +14,7 @@ limitations under the License. */ #include "paddle/phi/api/include/context_pool.h" -#include "paddle/phi/backends/cpu/cpu_context.h" -#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/enforce.h" namespace paddle { @@ -39,14 +38,27 @@ phi::DeviceContext* DeviceContextPool::GetMutable(const Place& place) { return const_cast(Get(place)); } -void DeviceContextPool::Insert(const Place& place, - const phi::DeviceContext* dev_ctx) { - auto it = context_map_.find(place); - PADDLE_ENFORCE_EQ(it, - context_map_.end(), - phi::errors::AlreadyExists( - "The DeviceContext of %s already exists.", place)); - context_map_[place] = dev_ctx; +DeviceContextPool::DeviceContextPool() { + // We need to make sure that the correct value exists + // whenever we get the DeviceContext from DeviceContextPool + const auto& device_contexts = + paddle::platform::DeviceContextPool::Instance().device_contexts(); + for (const auto& pair : device_contexts) { + // only get CPU and GPU DeviceContext now, add other DeviceContext type + // later if needed + if (platform::is_cpu_place(pair.first) +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + || + platform::is_gpu_place(pair.first)) { +#else + ) { +#endif + const phi::DeviceContext* dev_ctx = pair.second.get().get(); + VLOG(3) << "Init phi DeviceContextPool: insert {" << pair.first << ", " + << dev_ctx << "}"; + context_map_[pair.first] = dev_ctx; + } + } } } // namespace experimental