Skip to content

Commit bee30ac

Browse files
introduce aoti_torch_new_tensor_handle shim for cuda backend (#15861)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #15857 by @Gasoonjia — please use that PR as the source of truth for the details, comments, and reviews. ghstack PR base: https://github.com/pytorch/executorch/tree/gh/gasoonjia/65/base. ghstack PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/65/head. Merge bot PR base: https://github.com/pytorch/executorch/tree/main. Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/65/orig. Differential Revision: [D87254968](https://our.internmc.facebook.com/intern/diff/D87254968/). @diff-train-skip-merge. Co-authored-by: gasoonjia <[email protected]>
1 parent 7c746f7 commit bee30ac

File tree

8 files changed

+688
-11
lines changed

8 files changed

+688
-11
lines changed

backends/aoti/common_shims.cpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -238,14 +238,6 @@ aoti_torch_clone(Tensor* self, Tensor** ret_new_tensor) {
238238
return Error::Internal;
239239
}
240240

241-
AOTI_SHIM_EXPORT AOTITorchError
242-
aoti_torch_new_tensor_handle(Tensor* orig_handle, Tensor** new_handle) {
243-
(void)orig_handle;
244-
(void)new_handle;
245-
throw std::runtime_error("Not implemented");
246-
return Error::Internal;
247-
}
248-
249241
AOTI_SHIM_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob(
250242
void* data_ptr,
251243
int64_t ndim,

backends/aoti/common_shims.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,6 @@ aoti_torch_clone_preserve_strides(Tensor* self, Tensor** ret_new_tensor);
9494
AOTI_SHIM_EXPORT AOTITorchError
9595
aoti_torch_clone(Tensor* self, Tensor** ret_new_tensor);
9696

97-
AOTI_SHIM_EXPORT AOTITorchError
98-
aoti_torch_new_tensor_handle(Tensor* orig_handle, Tensor** new_handle);
99-
10097
AOTI_SHIM_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob(
10198
void* data_ptr,
10299
int64_t ndim,

backends/apple/metal/runtime/shims/memory.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,15 @@ AOTITorchError aoti_torch__reinterpret_tensor(
506506
return Error::Ok;
507507
}
508508

509+
AOTITorchError aoti_torch_new_tensor_handle(
    Tensor* orig_handle,
    Tensor** new_handle) {
  // Placeholder: the Metal backend does not implement this shim yet.
  (void)new_handle;
  (void)orig_handle;
  throw std::runtime_error("Not implemented");
  return Error::Internal; // unreachable; keeps the signature well-formed
}
517+
509518
// Cleanup function for clearing global state
510519
void cleanup_memory() {
511520
// Use aoti_torch_delete_tensor_object to properly delete each tensor

backends/apple/metal/runtime/shims/memory.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ AOTITorchError aoti_torch__reinterpret_tensor(
6464
int64_t storage_offset,
6565
AOTITensorHandle* ret_new_tensor);
6666

67+
AOTITorchError aoti_torch_new_tensor_handle(
68+
Tensor* orig_handle,
69+
Tensor** new_handle);
70+
6771
void cleanup_memory();
6872

6973
} // extern "C"

backends/cuda/runtime/shims/memory.cpp

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -682,6 +682,95 @@ AOTITorchError aoti_torch__reinterpret_tensor(
682682
return Error::Ok;
683683
}
684684

685+
AOTITorchError aoti_torch_new_tensor_handle(
    Tensor* orig_handle,
    Tensor** new_handle) {
  // Creates a new tensor handle that aliases (shares storage with) the
  // original tensor, similar to PyTorch's Tensor copy constructor. No data
  // is copied; only the reference count on the underlying memory may change.
  //
  // @param orig_handle Source tensor (must not be null).
  // @param new_handle  Out-param receiving the aliasing handle (must not be
  //                    null).
  // @return Error::Ok on success, Error::InvalidArgument on bad input or
  //         untracked memory.

  // Validate input parameters
  ET_CHECK_OR_RETURN_ERROR(
      orig_handle != nullptr,
      InvalidArgument,
      "aoti_torch_new_tensor_handle failed: orig_handle is null");

  ET_CHECK_OR_RETURN_ERROR(
      new_handle != nullptr,
      InvalidArgument,
      "aoti_torch_new_tensor_handle failed: new_handle is null");

  // Get metadata from the original tensor
  int64_t* sizes_ptr;
  int64_t* strides_ptr;
  int32_t dtype;
  int32_t device_type;
  int32_t device_index;

  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_sizes(orig_handle, &sizes_ptr));
  ET_CHECK_OK_OR_RETURN_ERROR(
      aoti_torch_get_strides(orig_handle, &strides_ptr));
  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_dtype(orig_handle, &dtype));
  ET_CHECK_OK_OR_RETURN_ERROR(
      aoti_torch_get_device_type(orig_handle, &device_type));
  ET_CHECK_OK_OR_RETURN_ERROR(
      aoti_torch_get_device_index(orig_handle, &device_index));

  int64_t ndim = orig_handle->dim();

  // Validate dtype
  ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(dtype));

  // This backend only supports a single device, so device_index must be 0.
  ET_CHECK_OR_RETURN_ERROR(
      device_index == 0,
      InvalidArgument,
      "device_index must be 0, got: %d",
      device_index);

  // Get the original data pointer from the source tensor
  void* data_ptr = orig_handle->mutable_data_ptr();
  ET_CHECK_OR_RETURN_ERROR(
      data_ptr != nullptr,
      InvalidArgument,
      "Source tensor has null data pointer");

  // The memory must already be tracked by the shim's reference-counting map;
  // otherwise we cannot safely extend its lifetime.
  auto memory_it = memory_to_n_tensor.find(data_ptr);
  ET_CHECK_OR_RETURN_ERROR(
      memory_it != memory_to_n_tensor.end(),
      InvalidArgument,
      "Memory address %p is not being tracked by reference counting system",
      data_ptr);

  // Convert sizes and strides to vectors
  std::vector<SizesType> sizes = convert_sizes_to_vector(ndim, sizes_ptr);
  std::vector<StridesType> strides =
      convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);

  // Create new tensor that shares the same memory as the original
  // This is similar to PyTorch's Tensor copy constructor - creates a new
  // tensor object that shares the same underlying storage
  std::shared_ptr<Tensor> tensor = make_tensor(
      sizes, // Same sizes as original
      data_ptr, // Share the same memory from source tensor
      {}, // dim_order (empty, will be auto-generated)
      strides, // Same strides as original
      dtype_to_scalar_type(dtype) // Same dtype as original
  );

  ET_CHECK_OR_RETURN_ERROR(
      tensor != nullptr, InvalidArgument, "Failed to create new tensor handle");

  // Store the tensor so it doesn't get destroyed
  tensors.insert(tensor);

  *new_handle = tensor.get();

  // Increment the reference count for this memory address only if it is owned
  // by tensor; NOT_OWN memory stays NOT_OWN. Reuse the iterator from the
  // lookup above instead of re-hashing data_ptr with operator[] three times.
  if (memory_it->second != NOT_OWN) {
    ++memory_it->second;
  }

  return Error::Ok;
}
685774
} // extern "C"
686775

687776
} // namespace executorch::backends::cuda

backends/cuda/runtime/shims/memory.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,31 @@ AOTI_SHIM_EXPORT AOTITorchError aoti_torch__reinterpret_tensor(
140140
AOTI_SHIM_EXPORT AOTITorchError
141141
aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking);
142142

143+
/**
144+
* Creates a new tensor handle from an existing one.
145+
*
146+
* This function creates a new tensor object that shares the same underlying
147+
* memory as the original tensor. Similar to PyTorch's Tensor copy constructor,
148+
* it creates a new handle/reference to the same data without performing a deep
149+
* copy.
150+
*
151+
* The new tensor will:
152+
* - Share the same memory/storage as the original tensor
153+
* - Have the same shape, strides, and dtype as the original
154+
* - Increment the reference count for the underlying memory (if owned)
155+
*
156+
* @param orig_handle Original tensor to create a new handle from (must not be
157+
* null)
158+
* @param new_handle Output pointer to store the new tensor handle (must not be
159+
* null)
160+
*
161+
* @return Error::Ok on success, appropriate error code on failure:
162+
* - Error::InvalidArgument: null pointers or invalid parameters
163+
*/
164+
AOTITorchError aoti_torch_new_tensor_handle(
165+
Tensor* orig_handle,
166+
Tensor** new_handle);
167+
143168
// Function to clear all tensors from internal storage
144169
AOTI_SHIM_EXPORT void clear_all_tensors();
145170
} // extern "C"

backends/cuda/runtime/shims/tests/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,4 @@ def define_common_targets():
3434
cuda_shim_cpp_unittest("aoti_torch_copy_")
3535
cuda_shim_cpp_unittest("aoti_torch_cuda_guard")
3636
cuda_shim_cpp_unittest("aoti_torch_cuda__weight_int4pack_mm")
37+
cuda_shim_cpp_unittest("aoti_torch_new_tensor_handle")

0 commit comments

Comments
 (0)