Skip to content

Commit

Permalink
Copy tensor content through the API
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 558308631
  • Loading branch information
hhb authored and tensorflower-gardener committed Aug 19, 2023
1 parent 9db029d commit c9fafed
Showing 1 changed file with 46 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,8 @@ class TfCThunkDeviceContext final : public DeviceContext {

done = [params, done = std::move(done),
device_tensor](const absl::Status& status) -> void {
absl::Status tensor_status = status;
if (tensor_status.ok()) {
tensor_status = TF_TensorToTensor(params->device_tensor, device_tensor);
}
done(tensor_status);
// TODO: Find a way to convert device tensor.
done(status);
Destroy(params);
delete params;
};
Expand Down Expand Up @@ -207,74 +204,79 @@ class TfCThunkDeviceContext final : public DeviceContext {
const TF_DeviceContext thunk_;
};

void CopyTF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
// TODO: Convert through a lookup table for better API compatibility.
DataType dtype = static_cast<DataType>(TF_TensorType(src));
TensorShape tensor_shape;
int dim = TF_NumDims(src);
for (int i = 0; i < dim; ++i) {
tensor_shape.AddDim(TF_Dim(src, i));
}
*dst = Tensor(dtype, tensor_shape);

std::memcpy(dst->data(), TF_TensorData(src), TF_TensorByteSize(src));
}

// C-API thunk bridging TF_DeviceContext_CopyCPUTensorToDevice_Params to
// DeviceContext::CopyCPUTensorToDevice.
//
// NOTE(review): this span was a diff artifact mixing the pre- and post-commit
// lines (duplicate tensor declarations plus a dead early-return conversion
// path); reconstructed here as the coherent post-commit version.
//
// Ownership: `cpu_tensor` and `device_tensor` are heap-allocated so they
// outlive this call; the `done` callback deletes both after the async copy
// completes.
void CpuToDeviceThunk(void* context,
                      TF_DeviceContext_CopyCPUTensorToDevice_Params* params) {
  DeviceContext* device_context = static_cast<DeviceContext*>(context);
  Tensor* cpu_tensor = new Tensor();
  Tensor* device_tensor = new Tensor();
  tsl::StatusCallback done = [params, device_tensor,
                              cpu_tensor](absl::Status status) {
    delete cpu_tensor;
    absl::Status tensor_status;
    // TODO: find a way to convert device tensor.
    // params->device_tensor = TF_TensorFromTensor(*device_tensor,
    // &tensor_status);
    delete device_tensor;
    // Report the (default-OK) conversion status back through the C callback.
    FromC(params->done)(tensor_status);
  };
  // Copy the incoming C tensor into a C++ tensor the DeviceContext can use.
  CopyTF_TensorToTensor(params->cpu_tensor, cpu_tensor);
  bool sync_dst_compute = params->sync_dst_compute;
  device_context->CopyCPUTensorToDevice(cpu_tensor, /* device = */ nullptr,
                                        device_tensor, std::move(done),
                                        sync_dst_compute);
}

// Deep-copies a C++ Tensor into a newly allocated C-API TF_Tensor.
// The caller takes ownership of the returned TF_Tensor.
//
// Fix: the original leaked `dims` (`new int64_t[...]` with no delete[]) on
// every call; use an automatically managed std::vector instead.
TF_Tensor* CopyTensorToTF_Tensor(const Tensor& src) {
  // TODO: Convert through a lookup table for better API compatibility.
  TF_DataType dtype = static_cast<TF_DataType>(src.dtype());
  const TensorShape& shape = src.shape();
  // Collect dims and compute the total byte length (element size * product
  // of all dimensions) in one pass.
  std::vector<int64_t> dims(shape.dims());
  size_t len = TF_DataTypeSize(dtype);
  for (int i = 0; i < shape.dims(); ++i) {
    dims[i] = shape.dim_size(i);
    len *= dims[i];
  }
  TF_Tensor* tf_tensor = TF_AllocateTensor(dtype, dims.data(), shape.dims(), len);
  void* data = TF_TensorData(tf_tensor);
  std::memcpy(data, src.data(), len);
  return tf_tensor;
}

// C-API thunk bridging TF_DeviceContext_CopyDeviceTensorToCPU_Params to
// DeviceContext::CopyDeviceTensorToCPU.
//
// NOTE(review): this span was a diff artifact mixing the pre- and post-commit
// lines (duplicate declarations and a dead conversion/early-return path);
// reconstructed here as the coherent post-commit version.
//
// Ownership: both tensors are heap-allocated so they outlive this call; the
// `done` callback deletes them after the async copy completes and publishes
// the result through params->cpu_tensor.
void DeviceToCpuThunk(void* context,
                      TF_DeviceContext_CopyDeviceTensorToCPU_Params* params) {
  DeviceContext* device_context = static_cast<DeviceContext*>(context);
  Tensor* cpu_tensor = new Tensor();
  Tensor* device_tensor = new Tensor();
  tsl::StatusCallback done = [params, device_tensor,
                              cpu_tensor](absl::Status status) {
    delete device_tensor;
    // Hand the copied CPU tensor back across the C API boundary.
    params->cpu_tensor = CopyTensorToTF_Tensor(*cpu_tensor);
    delete cpu_tensor;
    FromC(params->done)(status);
  };
  std::string_view tensor_name(params->tensor_name, params->tensor_name_len);
  // TODO: Find a way to convert device tensor.
  device_context->CopyDeviceTensorToCPU(
      /* device_tensor = */ nullptr, tensor_name,
      /* device = */ nullptr, cpu_tensor, std::move(done));
}

// C-API thunk for same-device tensor copies.
//
// NOTE(review): this span was a diff artifact — the commit deleted the whole
// previous body (tensor allocation, conversion, CopyTensorInSameDevice call)
// and left only the fatal stub; reconstructed here as the post-commit version.
// Intentionally unimplemented until device-tensor conversion through the C
// API is available.
void SameDeviceThunk(void* context,
                     TF_DeviceContext_CopyTensorInSameDevice_Params* params) {
  LOG(FATAL) << "Unimplemented";  // Crash OK
}

TF_DeviceContext_CopyCPUTensorToDevice_Impl BindCpuToDevice(
Expand Down

0 comments on commit c9fafed

Please sign in to comment.