Skip to content
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tensorflow/c/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,7 @@ tf_cuda_library(
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:core",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/common_runtime/pluggable_device:pluggable_device_plugin_init",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/platform",
"//tensorflow/core/platform:blocking_counter",
Expand Down
5 changes: 4 additions & 1 deletion tensorflow/c/c_api_experimental.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/node_def.pb.h"
Expand Down Expand Up @@ -777,7 +778,9 @@ TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
} else {
status->status =
env->LoadDynamicLibrary(library_filename, &lib_handle->lib_handle);
if (!status->status.ok()) {
if (status->status.ok()) {
tensorflow::RegisterPluggableDevicePlugin(lib_handle->lib_handle);
} else {
delete lib_handle;
return nullptr;
}
Expand Down
14 changes: 8 additions & 6 deletions tensorflow/c/experimental/stream_executor/stream_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -743,7 +743,8 @@ port::StatusOr<std::unique_ptr<StreamExecutor>> CPlatform::GetUncachedExecutor(
return result;
}

port::Status InitStreamExecutorPlugin(void* dso_handle) {
port::Status InitStreamExecutorPlugin(void* dso_handle, string* device_type,
string* platform_name) {
tensorflow::Env* env = tensorflow::Env::Default();

// Step 1: Load symbol for `TF_InitPlugin`
Expand All @@ -753,10 +754,12 @@ port::Status InitStreamExecutorPlugin(void* dso_handle) {

// Step 2: Call `TF_InitPlugin`
auto init_fn = reinterpret_cast<SEInitPluginFn>(dso_symbol);
return InitStreamExecutorPlugin(init_fn);
return InitStreamExecutorPlugin(init_fn, device_type, platform_name);
}

port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn) {
port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn,
string* device_type,
string* platform_name) {
SE_PlatformRegistrationParams params{
SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE};
SP_Platform platform{SP_PLATFORM_STRUCT_SIZE};
Expand Down Expand Up @@ -806,16 +809,15 @@ port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn) {
TF_RETURN_IF_ERROR(ValidateSPTimerFns(timer_fns));

// Register new platform
std::string platform_name = std::string(platform.name);
std::unique_ptr<stream_executor::CPlatform> cplatform(
new stream_executor::CPlatform(
std::move(platform), params.destroy_platform, std::move(platform_fns),
params.destroy_platform_fns, std::move(device_fns), std::move(se),
std::move(timer_fns)));
SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform(
std::move(cplatform)));

// TODO(annarev): Add pluggable device registration here.
*device_type = std::string(platform.type);
*platform_name = std::string(platform.name);
return port::Status::OK();
}
} // namespace stream_executor
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,16 @@ namespace stream_executor {
typedef void (*SEInitPluginFn)(SE_PlatformRegistrationParams* const,
TF_Status* const);

// Registers StreamExecutor platform.
port::Status InitStreamExecutorPlugin(void* dso_handle);
// Registers StreamExecutor platform. `device_type` and `platform_name` are
// output parameters.
port::Status InitStreamExecutorPlugin(void* dso_handle, string* device_type,
string* platform_name);

// Allow registering a StreamExecutor plugin using a function (used for
// testing).
port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn);
port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn,
string* device_type,
string* platform_name);

struct TFStatusDeleter {
void operator()(TF_Status* s) const { TF_DeleteStatus(s); }
Expand Down
24 changes: 18 additions & 6 deletions tensorflow/c/experimental/stream_executor/stream_executor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,9 @@ TEST(StreamExecutor, SuccessfulRegistration) {
params->destroy_platform = destroy_platform;
params->destroy_platform_fns = destroy_platform_fns;
};
port::Status status = InitStreamExecutorPlugin(plugin_init);
string device_type, platform_name;
port::Status status =
InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name);
TF_ASSERT_OK(status);
port::StatusOr<Platform*> maybe_platform =
MultiPlatformManager::PlatformWithName("MY_DEVICE");
Expand All @@ -239,7 +241,9 @@ TEST(StreamExecutor, NameNotSet) {
params->destroy_platform_fns = destroy_platform_fns;
};

port::Status status = InitStreamExecutorPlugin(plugin_init);
string device_type, platform_name;
port::Status status =
InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name);
ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
ASSERT_EQ(status.error_message(), "'name' field in SP_Platform must be set.");
}
Expand All @@ -254,7 +258,9 @@ TEST(StreamExecutor, InvalidNameWithSemicolon) {
params->destroy_platform_fns = destroy_platform_fns;
};

port::Status status = InitStreamExecutorPlugin(plugin_init);
string device_type, platform_name;
port::Status status =
InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name);
ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
EXPECT_THAT(
status.error_message(),
Expand All @@ -271,7 +277,9 @@ TEST(StreamExecutor, InvalidNameWithSlash) {
params->destroy_platform_fns = destroy_platform_fns;
};

port::Status status = InitStreamExecutorPlugin(plugin_init);
string device_type, platform_name;
port::Status status =
InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name);
ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
EXPECT_THAT(status.error_message(),
testing::ContainsRegex("Device name/type 'INVALID/' must match"));
Expand All @@ -287,7 +295,9 @@ TEST(StreamExecutor, CreateDeviceNotSet) {
params->destroy_platform_fns = destroy_platform_fns;
};

port::Status status = InitStreamExecutorPlugin(plugin_init);
string device_type, platform_name;
port::Status status =
InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name);
ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
ASSERT_EQ(status.error_message(),
"'create_device' field in SP_PlatformFns must be set.");
Expand All @@ -303,7 +313,9 @@ TEST(StreamExecutor, UnifiedMemoryAllocateNotSet) {
params->destroy_platform_fns = destroy_platform_fns;
};

port::Status status = InitStreamExecutorPlugin(plugin_init);
string device_type, platform_name;
port::Status status =
InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name);
ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
ASSERT_EQ(
status.error_message(),
Expand Down
1 change: 1 addition & 0 deletions tensorflow/core/common_runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ cc_library(
deps = [
":core_cpu",
"//tensorflow/core/common_runtime/gpu:gpu_runtime",
"//tensorflow/core/common_runtime/pluggable_device:pluggable_device_runtime",
] + if_libtpu(["//tensorflow/core/tpu:tpu_runtime"]),
)

Expand Down
3 changes: 1 addition & 2 deletions tensorflow/core/common_runtime/copy_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,9 @@ class CopyTensor {
}
};

private:
// Register a function for copying between two specific DeviceTypes.
// Note: This should only be called via the constructor of
// CopyTensor::Registration.
// CopyTensor::Registration or from PluggableDevice implementation.
static Status Register(DeviceType sender_device_type,
DeviceType receiver_device_type,
CopyFunction copy_function);
Expand Down
46 changes: 46 additions & 0 deletions tensorflow/core/common_runtime/device/device_id_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,52 @@ class DeviceIdUtil {
<< " id: " << platform_device_id
<< ", visible device count: " << visible_device_count;
}

// Parse `visible_device_list` into a list of platform Device ids.
//
// An empty `visible_device_list` yields the identity mapping
// 0..visible_device_count-1.  Otherwise the comma-separated entries are
// parsed, range-checked against `visible_device_count`, and checked for
// duplicates.  Returns InvalidArgument on any malformed entry.
static Status ParseVisibleDeviceList(
    const string& visible_device_list, const int visible_device_count,
    std::vector<PlatformDeviceId>* visible_device_order) {
  visible_device_order->clear();

  if (visible_device_list.empty()) {
    // No explicit list: visible-to-virtual mapping is the identity.
    visible_device_order->resize(visible_device_count);
    std::iota(visible_device_order->begin(), visible_device_order->end(), 0);
  } else {
    // The user supplied an explicit remapping of visible to virtual devices.
    const std::vector<string> entries =
        str_util::Split(visible_device_list, ',');
    for (const string& entry : entries) {
      int32 parsed_id;
      if (!strings::safe_strto32(entry, &parsed_id)) {
        return errors::InvalidArgument(
            "Could not parse entry in 'visible_device_list': '", entry,
            "'. visible_device_list = ", visible_device_list);
      }
      const bool out_of_range =
          parsed_id < 0 || parsed_id >= visible_device_count;
      if (out_of_range) {
        return errors::InvalidArgument(
            "'visible_device_list' listed an invalid Device id '", parsed_id,
            "' but visible device count is ", visible_device_count);
      }
      visible_device_order->push_back(PlatformDeviceId(parsed_id));
    }
  }

  // Reject duplicates: a set collapses repeated ids, so a size mismatch
  // means at least one entry appeared twice.
  const std::set<PlatformDeviceId> unique_ids(visible_device_order->begin(),
                                              visible_device_order->end());
  if (unique_ids.size() != visible_device_order->size()) {
    return errors::InvalidArgument(
        "visible_device_list contained a duplicate entry: ",
        visible_device_list);
  }
  return Status::OK();
}
};

} // namespace tensorflow
Expand Down
58 changes: 6 additions & 52 deletions tensorflow/core/common_runtime/gpu/gpu_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ limitations under the License.
#include <tuple>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/common_runtime/device/device_event_mgr.h"
#include "tensorflow/core/common_runtime/device/device_id_utils.h"
#include "tensorflow/core/common_runtime/device_factory.h"
Expand All @@ -58,9 +57,10 @@ limitations under the License.
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#if GOOGLE_CUDA
#include "third_party/gpus/cudnn/cudnn.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#include "third_party/gpus/cudnn/cudnn.h"
#elif TENSORFLOW_USE_ROCM
#include "tensorflow/core/platform/rocm.h"
#endif
Expand Down Expand Up @@ -223,7 +223,7 @@ class EigenGpuStreamDevice : public ::Eigen::StreamInterface {
OpKernelContext* context_;

TF_DISALLOW_COPY_AND_ASSIGN(EigenGpuStreamDevice);
};
};  // EigenGpuStreamDevice

// This factory helps to ensure that different GPU device objects that refer to
// the same physical device and stream group id use the same stream group
Expand Down Expand Up @@ -776,53 +776,6 @@ class ConcretePerOpGpuDevice : public PerOpGpuDevice {
Eigen::GpuDevice device_;
};

// Parse 'visible_device_list' into a list of platform GPU ids.
//
// On success, `visible_gpu_order` holds the platform GPU ids in the order
// the user listed them, or 0..N-1 when `visible_device_list` is empty.
// Returns InvalidArgument for unparsable entries, out-of-range ids, or
// duplicate ids.
Status ParseVisibleDeviceList(const string& visible_device_list,
                              std::vector<PlatformGpuId>* visible_gpu_order) {
  visible_gpu_order->clear();
  se::Platform* gpu_manager = GPUMachineManager();
  // Hoist the device count: it is loop-invariant and was previously
  // re-queried on every iteration and error path.
  const int visible_device_count = gpu_manager->VisibleDeviceCount();

  // If the user wants to remap the visible to virtual GPU mapping,
  // check for that here.
  if (visible_device_list.empty()) {
    visible_gpu_order->resize(visible_device_count);
    // By default, visible to virtual mapping is unchanged (0, 1, 2, ...).
    // std::iota replaces the hand-rolled std::generate + mutable counter.
    std::iota(visible_gpu_order->begin(), visible_gpu_order->end(), 0);
  } else {
    const std::vector<string> order_str =
        str_util::Split(visible_device_list, ',');
    for (const string& platform_gpu_id_str : order_str) {
      int32 platform_gpu_id;
      if (!strings::safe_strto32(platform_gpu_id_str, &platform_gpu_id)) {
        return errors::InvalidArgument(
            "Could not parse entry in 'visible_device_list': '",
            platform_gpu_id_str,
            "'. visible_device_list = ", visible_device_list);
      }
      if (platform_gpu_id < 0 || platform_gpu_id >= visible_device_count) {
        return errors::InvalidArgument(
            "'visible_device_list' listed an invalid GPU id '", platform_gpu_id,
            "' but visible device count is ", visible_device_count);
      }
      visible_gpu_order->push_back(PlatformGpuId(platform_gpu_id));
    }
  }

  // Validate no repeats.
  std::set<PlatformGpuId> visible_device_set(visible_gpu_order->begin(),
                                             visible_gpu_order->end());
  if (visible_device_set.size() != visible_gpu_order->size()) {
    return errors::InvalidArgument(
        "visible_device_list contained a duplicate entry: ",
        visible_device_list);
  }
  return Status::OK();
}

Status VerifyVirtualDeviceSettings(
const size_t num_gpus_to_use, const GPUOptions& gpu_options,
const std::vector<PlatformGpuId>& visible_gpu_order,
Expand Down Expand Up @@ -1166,8 +1119,9 @@ Status BaseGPUDeviceFactory::CreateDevices(
// because it treats an empty gpu_options.visible_device_list as 'all GPUs
// are visible'.
if (num_gpus_to_use > 0) {
TF_RETURN_IF_ERROR(ParseVisibleDeviceList(gpu_options.visible_device_list(),
&visible_gpu_order));
TF_RETURN_IF_ERROR(DeviceIdUtil::ParseVisibleDeviceList(
gpu_options.visible_device_list(), gpu_manager->VisibleDeviceCount(),
&visible_gpu_order));
bool new_gpu_found = false;
for (int i = 0; i < visible_gpu_order.size(); ++i) {
int visible_gpu_id = visible_gpu_order[i].value();
Expand Down
5 changes: 4 additions & 1 deletion tensorflow/core/common_runtime/graph_execution_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/placer.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/device_factory.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
Expand Down Expand Up @@ -395,7 +396,9 @@ bool IsFeedAndFetchSupported(DataType dtype, const string& device_type) {
// TODO(ashankar): Instead of a allowlist here, perhaps we could query
// the kernel registry for _Arg and _Retval kernels instead.
if (device_type == DEVICE_CPU) return true;
if (device_type != DEVICE_GPU) return false;
if (device_type != DEVICE_GPU &&
!DeviceFactory::IsPluggableDevice(device_type))
return false;
switch (dtype) {
case DT_BFLOAT16:
case DT_BOOL:
Expand Down
4 changes: 3 additions & 1 deletion tensorflow/core/common_runtime/memory_types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.

#include <utility>

#include "tensorflow/core/framework/device_factory.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/node_builder.h"
Expand Down Expand Up @@ -48,7 +49,8 @@ struct EndpointEq {
static Status ProcessMemoryTypes(
const DeviceType& device_type, const Graph* g,
const std::function<Status(const Edge*, MemoryType, MemoryType)>& fn) {
if (device_type != DEVICE_GPU) {
if (device_type != DEVICE_GPU &&
!DeviceFactory::IsPluggableDevice(DeviceTypeString(device_type))) {
// On non-GPU devices, HOST_MEMORY and DEVICE_MEMORY are always compatible.
return Status::OK();
}
Expand Down
Loading