@@ -682,6 +682,95 @@ AOTITorchError aoti_torch__reinterpret_tensor(
   return Error::Ok;
 }
 
+AOTITorchError aoti_torch_new_tensor_handle(
+    Tensor* orig_handle,
+    Tensor** new_handle) {
+  // Validate input parameters
+  ET_CHECK_OR_RETURN_ERROR(
+      orig_handle != nullptr,
+      InvalidArgument,
+      "aoti_torch_new_tensor_handle failed: orig_handle is null");
+
+  ET_CHECK_OR_RETURN_ERROR(
+      new_handle != nullptr,
+      InvalidArgument,
+      "aoti_torch_new_tensor_handle failed: new_handle is null");
+
+  // Get metadata from the original tensor
+  int64_t* sizes_ptr;
+  int64_t* strides_ptr;
+  int32_t dtype;
+  int32_t device_type;
+  int32_t device_index;
+
+  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_sizes(orig_handle, &sizes_ptr));
+  ET_CHECK_OK_OR_RETURN_ERROR(
+      aoti_torch_get_strides(orig_handle, &strides_ptr));
+  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_dtype(orig_handle, &dtype));
+  ET_CHECK_OK_OR_RETURN_ERROR(
+      aoti_torch_get_device_type(orig_handle, &device_type));
+  ET_CHECK_OK_OR_RETURN_ERROR(
+      aoti_torch_get_device_index(orig_handle, &device_index));
+
+  int64_t ndim = orig_handle->dim();
+
+  // Validate dtype
+  ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(dtype));
+
+  // Ensure device_index is always 0
+  ET_CHECK_OR_RETURN_ERROR(
+      device_index == 0,
+      InvalidArgument,
+      "device_index must be 0, got: %d",
+      device_index);
+
+  // Get the original data pointer from the source tensor
+  void* data_ptr = orig_handle->mutable_data_ptr();
+  ET_CHECK_OR_RETURN_ERROR(
+      data_ptr != nullptr,
+      InvalidArgument,
+      "Source tensor has null data pointer");
+
+  // Check if the given memory is in the map
+  auto memory_it = memory_to_n_tensor.find(data_ptr);
+  ET_CHECK_OR_RETURN_ERROR(
+      memory_it != memory_to_n_tensor.end(),
+      InvalidArgument,
+      "Memory address %p is not being tracked by reference counting system",
+      data_ptr);
+
+  // Convert sizes and strides to vectors
+  std::vector<SizesType> sizes = convert_sizes_to_vector(ndim, sizes_ptr);
+  std::vector<StridesType> strides =
+      convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);
+
+  // Create new tensor that shares the same memory as the original.
+  // This is similar to PyTorch's Tensor copy constructor: it creates a new
+  // tensor object that shares the same underlying storage.
+  std::shared_ptr<Tensor> tensor = make_tensor(
+      sizes, // Same sizes as original
+      data_ptr, // Share the same memory from source tensor
+      {}, // dim_order (empty, will be auto-generated)
+      strides, // Same strides as original
+      dtype_to_scalar_type(dtype) // Same dtype as original
+  );
+
+  ET_CHECK_OR_RETURN_ERROR(
+      tensor != nullptr, InvalidArgument, "Failed to create new tensor handle");
+
+  // Store the tensor so it doesn't get destroyed
+  tensors.insert(tensor);
+
+  *new_handle = tensor.get();
+
+  // Increment the reference count for this memory address, but only if the
+  // memory is owned by tensors; NOT_OWN memory stays NOT_OWN.
+  memory_to_n_tensor[data_ptr] = memory_to_n_tensor[data_ptr] == NOT_OWN
+      ? NOT_OWN
+      : memory_to_n_tensor[data_ptr] + 1;
+
+  return Error::Ok;
+}
 } // extern "C"
 
 } // namespace executorch::backends::cuda
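
A minimal usage sketch (not part of this PR) of how a caller might exercise the new entry point. It assumes only the shim functions and macros already used in the diff above (aoti_torch_new_tensor_handle, aoti_torch_get_sizes, ET_CHECK_OK_OR_RETURN_ERROR, Error::Ok) and would live in the same translation unit; duplicate_handle_example is a hypothetical helper name.

// Illustrative only: duplicate an existing tensor handle and observe that the
// two handles alias the same storage. `duplicate_handle_example` is a
// hypothetical helper, not code from this PR.
AOTITorchError duplicate_handle_example(Tensor* existing) {
  Tensor* alias = nullptr;

  // Create a second handle that shares the storage of `existing`; the shim
  // bumps the reference count for the underlying memory (unless it is
  // NOT_OWN).
  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_new_tensor_handle(existing, &alias));

  // Both handles report the same sizes because they view the same storage.
  int64_t* orig_sizes = nullptr;
  int64_t* alias_sizes = nullptr;
  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_sizes(existing, &orig_sizes));
  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_sizes(alias, &alias_sizes));

  // `alias` can now be used like any other handle; releasing it through the
  // backend's delete entry point decrements the shared reference count.
  return Error::Ok;
}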