Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tensorflow/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -936,6 +936,7 @@ tf_cc_shared_object(
"//tensorflow/core/common_runtime:core_cpu_impl",
"//tensorflow/core:framework_internal_impl",
"//tensorflow/core/common_runtime/gpu:gpu_runtime_impl",
"//tensorflow/core/common_runtime/pluggable_device:pluggable_device_runtime_impl",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl",
"//tensorflow/core:lib_internal_impl",
"//tensorflow/core/profiler:profiler_impl",
Expand Down
5 changes: 4 additions & 1 deletion tensorflow/c/experimental/stream_executor/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ cc_library(
"stream_executor.h",
"stream_executor_internal.h",
],
visibility = ["//tensorflow/c:__subpackages__"],
visibility = [
"//tensorflow/c:__subpackages__",
"//tensorflow/core/common_runtime/pluggable_device:__subpackages__",
],
deps = [
"//tensorflow/c:c_api_macros",
"//tensorflow/c:tf_status",
Expand Down
4 changes: 4 additions & 0 deletions tensorflow/c/experimental/stream_executor/stream_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,10 @@ port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn,
std::move(timer_fns)));
SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform(
std::move(cplatform)));

// TODO(annarev): Add pluggable device registration here.
// TODO(annarev): Return `use_bfc_allocator` value in some way so that it is
// available in `PluggableDeviceProcessState` once the latter is checked in.
return port::Status::OK();
}
} // namespace stream_executor
45 changes: 42 additions & 3 deletions tensorflow/core/common_runtime/pluggable_device/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ load(
"//tensorflow/core/platform:rules_cc.bzl",
"cc_library",
)
load(
"//tensorflow/core/platform:build_config_root.bzl",
"if_static",
)

package(
default_visibility = [
Expand Down Expand Up @@ -43,6 +47,7 @@ cc_library(
deps = [
":pluggable_device_bfc_allocator",
":pluggable_device_init_impl",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_lib",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
Expand All @@ -53,6 +58,8 @@ cc_library(
"//tensorflow/core/common_runtime/device:device_event_mgr",
"//tensorflow/core/platform:stream_executor",
"//tensorflow/core/platform:tensor_float_32_utils",
"//tensorflow/stream_executor:event",
"//tensorflow/stream_executor:kernel",
],
alwayslink = 1,
)
Expand All @@ -69,18 +76,50 @@ cc_library(
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
":pluggable_device_runtime_impl",
"//tensorflow/c/experimental/stream_executor",
],
"//tensorflow/c/experimental/stream_executor:stream_executor_internal",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime:bfc_allocator",
"//tensorflow/core/common_runtime:dma_helper",
"//tensorflow/core/common_runtime:local_device",
"//tensorflow/core/common_runtime:process_state",
"//tensorflow/core/common_runtime:shared_counter",
"//tensorflow/core/platform:stream_executor",
"//tensorflow/stream_executor:event",
"//tensorflow/stream_executor:kernel",
] + if_static([
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jzhoulon and @penpornk
This MacOS test failure commit in pluggable impl seems to have caused regression on mac platforms to register pluggable device. When _pywrap_tensorflow .so loads up the plugin it goes through the device initialization fine then it hands over to libtensorflow_framework dylib to query the plugin handle using the Platform name in MultiPlatformManager . And during this part it fails with "Platform " not found. Seems like a linker issue where the plugin registry is not getting shared across the so and dylib.

Reverting this locally workarounds it. Will create a PR with proper fix.

":pluggable_device_runtime_impl",
"//tensorflow/core/common_runtime:copy_tensor",
]),
)

cc_library(
name = "pluggable_device_runtime",
hdrs = [":pluggable_device_runtime_headers"],
linkstatic = 1,
deps = [
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime:core_cpu",
"//tensorflow/core/common_runtime:dma_helper",
"//tensorflow/core/common_runtime:shared_counter",
"//tensorflow/core/lib/core:status",
"//tensorflow/core/platform:stream_executor",
"//tensorflow/stream_executor:event",
"//tensorflow/stream_executor:kernel",
] + if_static([
":pluggable_device_runtime_impl",
],
"//tensorflow/core/common_runtime:bfc_allocator",
"//tensorflow/core/common_runtime:process_state",
"//tensorflow/core/common_runtime:local_device",
]),
)

cc_library(
Expand Down
You are viewing a condensed version of this merge commit. You can view the full changes here.