From c9fafed9bc8cb0238a775fd4a0680e648c06b5b6 Mon Sep 17 00:00:00 2001
From: Haibo Huang <hhb@google.com>
Date: Fri, 18 Aug 2023 19:51:15 -0700
Subject: [PATCH] Copy tensor content through the API

PiperOrigin-RevId: 558308631
---
 .../c/tf_device_context_c_api_conversions.cc  | 90 ++++++++++---------
 1 file changed, 46 insertions(+), 44 deletions(-)

diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c/tf_device_context_c_api_conversions.cc b/tensorflow/core/common_runtime/next_pluggable_device/c/tf_device_context_c_api_conversions.cc
index 1a0c23b8d8a58a..905381a451bb57 100644
--- a/tensorflow/core/common_runtime/next_pluggable_device/c/tf_device_context_c_api_conversions.cc
+++ b/tensorflow/core/common_runtime/next_pluggable_device/c/tf_device_context_c_api_conversions.cc
@@ -118,11 +118,8 @@ class TfCThunkDeviceContext final : public DeviceContext {
     done = [params, done = std::move(done),
             device_tensor](const absl::Status& status) -> void {
-      absl::Status tensor_status = status;
-      if (tensor_status.ok()) {
-        tensor_status = TF_TensorToTensor(params->device_tensor, device_tensor);
-      }
-      done(tensor_status);
+      // TODO: Find a way to convert device tensor.
+      done(status);
       Destroy(params);
       delete params;
     };
 
@@ -207,74 +204,79 @@ class TfCThunkDeviceContext final : public DeviceContext {
   const TF_DeviceContext thunk_;
 };
 
+void CopyTF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
+  // TODO: Convert through a lookup table for better API compatibility.
+  DataType dtype = static_cast<DataType>(TF_TensorType(src));
+  TensorShape tensor_shape;
+  int dim = TF_NumDims(src);
+  for (int i = 0; i < dim; ++i) {
+    tensor_shape.AddDim(TF_Dim(src, i));
+  }
+  *dst = Tensor(dtype, tensor_shape);
+
+  std::memcpy(dst->data(), TF_TensorData(src), TF_TensorByteSize(src));
+}
+
 void CpuToDeviceThunk(void* context,
                       TF_DeviceContext_CopyCPUTensorToDevice_Params* params) {
   DeviceContext* device_context = static_cast<DeviceContext*>(context);
-  Tensor *cpu_tensor = new Tensor(), *device_tensor = new Tensor();
+  Tensor* cpu_tensor = new Tensor();
+  Tensor* device_tensor = new Tensor();
   tsl::StatusCallback done = [params, device_tensor,
                               cpu_tensor](absl::Status status) {
     delete cpu_tensor;
     absl::Status tensor_status;
-    params->device_tensor = TF_TensorFromTensor(*device_tensor, &tensor_status);
+    // TODO: find a way to convert device tensor.
+    // params->device_tensor = TF_TensorFromTensor(*device_tensor,
+    // &tensor_status);
     delete device_tensor;
     FromC(params->done)(tensor_status);
   };
-  absl::Status tensor_status;
-  tensor_status = TF_TensorToTensor(params->cpu_tensor, cpu_tensor);
-  if (!tensor_status.ok()) {
-    done(tensor_status);
-    return;
-  }
+  CopyTF_TensorToTensor(params->cpu_tensor, cpu_tensor);
   bool sync_dst_compute = params->sync_dst_compute;
   device_context->CopyCPUTensorToDevice(cpu_tensor, /* device = */ nullptr,
                                         device_tensor, std::move(done),
                                         sync_dst_compute);
 }
 
+TF_Tensor* CopyTensorToTF_Tensor(const Tensor& src) {
+  // TODO: Convert through a lookup table for better API compatibility.
+  TF_DataType dtype = static_cast<TF_DataType>(src.dtype());
+  const TensorShape& shape = src.shape();
+  int64_t* dims = new int64_t[shape.dims()];
+  size_t len = TF_DataTypeSize(dtype);
+  for (int i = 0; i < shape.dims(); ++i) {
+    dims[i] = shape.dim_size(i);
+    len *= dims[i];
+  }
+  TF_Tensor* tf_tensor = TF_AllocateTensor(dtype, dims, shape.dims(), len);
+  void* data = TF_TensorData(tf_tensor);
+  std::memcpy(data, src.data(), len);
+  return tf_tensor;
+}
+
 void DeviceToCpuThunk(void* context,
                       TF_DeviceContext_CopyDeviceTensorToCPU_Params* params) {
   DeviceContext* device_context = static_cast<DeviceContext*>(context);
-  Tensor *cpu_tensor = new Tensor(), *device_tensor = new Tensor();
+  Tensor* cpu_tensor = new Tensor();
+  Tensor* device_tensor = new Tensor();
   tsl::StatusCallback done = [params, device_tensor,
                               cpu_tensor](absl::Status status) {
     delete device_tensor;
-    absl::Status tensor_status;
-    params->cpu_tensor = TF_TensorFromTensor(*cpu_tensor, &tensor_status);
+    params->cpu_tensor = CopyTensorToTF_Tensor(*cpu_tensor);
     delete cpu_tensor;
-    FromC(params->done)(tensor_status);
+    FromC(params->done)(status);
   };
-  absl::Status tensor_status;
-  tensor_status = TF_TensorToTensor(params->device_tensor, device_tensor);
-  if (!tensor_status.ok()) {
-    done(tensor_status);
-    return;
-  }
   std::string_view tensor_name(params->tensor_name, params->tensor_name_len);
-  device_context->CopyDeviceTensorToCPU(device_tensor, tensor_name,
-                                        /* device = */ nullptr, cpu_tensor,
-                                        std::move(done));
+  // TODO: Find a way to convert device tensor.
+  device_context->CopyDeviceTensorToCPU(
+      /* device_tensor = */ nullptr, tensor_name,
+      /* device = */ nullptr, cpu_tensor, std::move(done));
 }
 
 void SameDeviceThunk(void* context,
                      TF_DeviceContext_CopyTensorInSameDevice_Params* params) {
-  DeviceContext* device_context = static_cast<DeviceContext*>(context);
-  Tensor *input_tensor = new Tensor(), *output_tensor = new Tensor();
-  tsl::StatusCallback done = [params, input_tensor,
-                              output_tensor](absl::Status status) {
-    delete input_tensor;
-    absl::Status tensor_status;
-    params->output_tensor = TF_TensorFromTensor(*output_tensor, &tensor_status);
-    delete output_tensor;
-    FromC(params->done)(tensor_status);
-  };
-  absl::Status tensor_status;
-  tensor_status = TF_TensorToTensor(params->input_tensor, input_tensor);
-  if (!tensor_status.ok()) {
-    done(tensor_status);
-    return;
-  }
-  device_context->CopyTensorInSameDevice(input_tensor, /* device = */ nullptr,
-                                         output_tensor, std::move(done));
+  LOG(FATAL) << "Unimplemented";  // Crash OK
 }
 
 TF_DeviceContext_CopyCPUTensorToDevice_Impl BindCpuToDevice(