Skip to content

Commit

Permalink
refine init impl
Browse files Browse the repository at this point in the history
  • Loading branch information
chenwhql committed Mar 18, 2022
1 parent 7056814 commit 4c4e9aa
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 41 deletions.
4 changes: 2 additions & 2 deletions paddle/fluid/platform/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,13 @@ endif()
cc_library(cudnn_workspace_helper SRCS cudnn_workspace_helper.cc DEPS boost)

# separate init from device_context to avoid cycle dependencies
cc_library(init SRCS init.cc DEPS device_context custom_kernel)
cc_library(init SRCS init.cc DEPS device_context custom_kernel context_pool)

# memcpy depends on device_context; here we add deps individually
# to avoid cycle dependencies
cc_library(device_context SRCS device_context.cc DEPS simple_threadpool malloc xxhash ${STREAM_CALLBACK_DEPS}
place phi_place eigen3 stringpiece cpu_helper cpu_info framework_proto ${IPU_CTX_DEPS} ${GPU_CTX_DEPS} ${NPU_CTX_DEPS} ${MKLDNN_CTX_DEPS}
${dgc_deps} dlpack cudnn_workspace_helper ${XPU_CTX_DEPS} ${MLU_CTX_DEPS} eigen3 cpu_context generator context_pool)
${dgc_deps} dlpack cudnn_workspace_helper ${XPU_CTX_DEPS} ${MLU_CTX_DEPS} eigen3 cpu_context generator)
if(WITH_XPU)
target_link_libraries(device_context xpu_context)
endif()
Expand Down
26 changes: 0 additions & 26 deletions paddle/fluid/platform/device_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,6 @@ limitations under the License. */
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"

#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/phi/api/include/context_pool.h"

namespace paddle {
namespace memory {

Expand Down Expand Up @@ -155,9 +152,6 @@ inline void EmplaceDeviceContext(
// lazy evaluation. i.e., only create device context at
// first `Get`
auto* dev_ctx = new DevCtx(p);
// init the phi DeviceContextPool at same time
auto& context_pool =
paddle::experimental::DeviceContextPool::Instance();
if (is_gpu_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto* cuda_ctx = dynamic_cast<CUDADeviceContext*>(dev_ctx);
Expand Down Expand Up @@ -189,26 +183,6 @@ inline void EmplaceDeviceContext(
memory::allocation::AllocatorFacade::Instance()
.GetZeroAllocator(p)
.get());
// insert Context pointer into phi DeviceContextPool
// only get CPU and GPU DeviceContext now, add other DeviceContext type
// later if needed
if (platform::is_cpu_place(p)) {
context_pool.Insert(
static_cast<platform::Place>(p),
static_cast<
const typename framework::ConvertToPhiContext<DevCtx>::TYPE*>(
dev_ctx));
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
} else if (platform::is_gpu_place(p)) {
context_pool.Insert(
static_cast<platform::Place>(p),
static_cast<
const typename framework::ConvertToPhiContext<DevCtx>::TYPE*>(
dev_ctx));
#endif
} else {
// skip other places now, do nothing
}
return PtrType(dev_ctx);
}));
}
Expand Down
5 changes: 5 additions & 0 deletions paddle/fluid/platform/device_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -916,6 +916,11 @@ class DeviceContextPool {

size_t size() const { return device_contexts_.size(); }

// Read-only view of every registered context, keyed by Place. Values are
// shared_futures because contexts are created lazily; callers must .get()
// (which may block until construction finishes) to obtain the context.
const std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>&
device_contexts() const {
return device_contexts_;
}

private:
static DeviceContextPool* pool;
std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>
Expand Down
4 changes: 1 addition & 3 deletions paddle/phi/api/include/context_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,8 @@ class DeviceContextPool {
Get(place));
}

void Insert(const Place& place, const phi::DeviceContext* dev_ctx);

private:
DeviceContextPool() = default;
DeviceContextPool();
paddle::flat_hash_map<Place, const phi::DeviceContext*, Place::Hash>
context_map_;

Expand Down
32 changes: 22 additions & 10 deletions paddle/phi/api/lib/context_pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ limitations under the License. */

#include "paddle/phi/api/include/context_pool.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/enforce.h"

namespace paddle {
Expand All @@ -39,14 +38,27 @@ phi::DeviceContext* DeviceContextPool::GetMutable(const Place& place) {
return const_cast<phi::DeviceContext*>(Get(place));
}

// Registers an externally owned DeviceContext for `place`.
// Enforces uniqueness: inserting a second context for the same place is an
// error (AlreadyExists). The pool does NOT take ownership of `dev_ctx`;
// the caller must keep it alive for the pool's lifetime.
void DeviceContextPool::Insert(const Place& place,
const phi::DeviceContext* dev_ctx) {
auto it = context_map_.find(place);
// NOTE: ENFORCE_EQ on iterators — fails iff `place` is already mapped.
PADDLE_ENFORCE_EQ(it,
context_map_.end(),
phi::errors::AlreadyExists(
"The DeviceContext of %s already exists.", place));
context_map_[place] = dev_ctx;
// Mirrors the fluid-side DeviceContextPool into this phi-side pool so that
// every Get() after construction resolves to an already-registered context.
DeviceContextPool::DeviceContextPool() {
  // We need to make sure that the correct value exists
  // whenever we get the DeviceContext from DeviceContextPool
  const auto& device_contexts =
      paddle::platform::DeviceContextPool::Instance().device_contexts();
  for (const auto& iter : device_contexts) {
    const auto& place = iter.first;
    // only get CPU and GPU DeviceContext now, add other DeviceContext type
    // later if needed
    bool mirrored = platform::is_cpu_place(place);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    mirrored = mirrored || platform::is_gpu_place(place);
#endif
    if (!mirrored) {
      continue;
    }
    // NOTE(review): shared_future::get() blocks until the future is ready,
    // so this presumably forces construction of each mirrored context here
    // rather than at first Get() — confirm this eager init is intended.
    const phi::DeviceContext* dev_ctx = iter.second.get().get();
    VLOG(3) << "Init phi DeviceContextPool: insert {" << place << ", "
            << dev_ctx << "}";
    context_map_[place] = dev_ctx;
  }
}

} // namespace experimental
Expand Down

0 comments on commit 4c4e9aa

Please sign in to comment.