Skip to content

Commit 7987db7

Browse files
shoumikhinfacebook-github-bot
authored andcommitted
Add overload to create atensor view from TensorPtr. (#14943)
Summary: . Reviewed By: larryliu0820 Differential Revision: D84259596
1 parent bdc526b commit 7987db7

File tree

2 files changed

+78
-10
lines changed

2 files changed

+78
-10
lines changed

extension/tensor/tensor_ptr.h

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -338,19 +338,37 @@ inline TensorPtr make_tensor_ptr(const executorch::aten::Tensor& tensor) {
338338
#ifndef USE_ATEN_LIB
339339
std::vector<executorch::aten::DimOrderType>(
340340
tensor.dim_order().begin(), tensor.dim_order().end()),
341-
std::vector<executorch::aten::StridesType>(
342-
tensor.strides().begin(), tensor.strides().end()),
343-
tensor.scalar_type(),
344-
tensor.shape_dynamism()
345341
#else // USE_ATEN_LIB
346342
{},
343+
#endif // USE_ATEN_LIB
347344
std::vector<executorch::aten::StridesType>(
348345
tensor.strides().begin(), tensor.strides().end()),
349346
tensor.scalar_type()
347+
#ifndef USE_ATEN_LIB
348+
,
349+
tensor.shape_dynamism()
350350
#endif // USE_ATEN_LIB
351351
);
352352
}
353353

354+
/**
355+
* Creates a TensorPtr to manage a new Tensor with the same properties
356+
* as the Tensor referenced by the given TensorPtr, sharing the same data
357+
* without owning it.
358+
*
359+
* This is a convenience overload equivalent to make_tensor_ptr(*tensor_ptr).
360+
* It does not extend the lifetime of the underlying buffer; if the original
361+
* owner releases the storage, all views aliasing it become dangling.
362+
*
363+
* @param tensor_ptr The TensorPtr whose underlying Tensor is used to initialize
364+
* the returned view.
365+
* @return A new TensorPtr managing a Tensor with the same properties as the
366+
* original.
367+
*/
368+
inline TensorPtr make_tensor_ptr(const TensorPtr& tensor_ptr) {
369+
return make_tensor_ptr(*tensor_ptr);
370+
}
371+
354372
/**
355373
* Creates a TensorPtr that manages a new Tensor with the same properties
356374
* as the given Tensor, but with a copy of the data owned by the returned

extension/tensor/test/tensor_ptr_test.cpp

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ TEST_F(TensorPtrTest, TensorSharingImplResizingAffectsBothVector) {
347347
TEST_F(TensorPtrTest, MakeTensorPtrFromExistingTensorInt32) {
348348
std::vector<int32_t> data = {1, 2, 3, 4};
349349
auto tensor = make_tensor_ptr({2, 2}, data);
350-
auto new_tensor = make_tensor_ptr(*tensor);
350+
auto new_tensor = make_tensor_ptr(tensor);
351351

352352
EXPECT_EQ(new_tensor->dim(), tensor->dim());
353353
EXPECT_EQ(new_tensor->size(0), tensor->size(0));
@@ -360,7 +360,7 @@ TEST_F(TensorPtrTest, MakeTensorPtrFromExistingTensorInt32) {
360360
TEST_F(TensorPtrTest, CloneTensorPtrFromExistingTensorInt32) {
361361
std::vector<int32_t> data = {1, 2, 3, 4};
362362
auto tensor = make_tensor_ptr({2, 2}, std::move(data));
363-
auto cloned_tensor = clone_tensor_ptr(*tensor);
363+
auto cloned_tensor = clone_tensor_ptr(tensor);
364364

365365
EXPECT_EQ(cloned_tensor->dim(), tensor->dim());
366366
EXPECT_EQ(cloned_tensor->size(0), tensor->size(0));
@@ -373,6 +373,56 @@ TEST_F(TensorPtrTest, CloneTensorPtrFromExistingTensorInt32) {
373373
EXPECT_EQ(cloned_tensor->scalar_type(), executorch::aten::ScalarType::Int);
374374
}
375375

376+
TEST_F(TensorPtrTest, MakeTensorPtrFromTensorPtrInt32) {
377+
std::vector<int32_t> data = {1, 2, 3, 4};
378+
auto tensor = make_tensor_ptr({2, 2}, data);
379+
auto new_tensor = make_tensor_ptr(tensor);
380+
381+
EXPECT_EQ(new_tensor->dim(), tensor->dim());
382+
EXPECT_EQ(new_tensor->size(0), tensor->size(0));
383+
EXPECT_EQ(new_tensor->size(1), tensor->size(1));
384+
EXPECT_EQ(
385+
new_tensor->const_data_ptr<int32_t>(), tensor->const_data_ptr<int32_t>());
386+
EXPECT_EQ(new_tensor->scalar_type(), executorch::aten::ScalarType::Int);
387+
}
388+
389+
TEST_F(TensorPtrTest, MakeTensorPtrFromTensorPtrDouble) {
390+
std::vector<double> data = {1.0, 2.0, 3.0, 4.0};
391+
auto tensor = make_tensor_ptr({2, 2}, data);
392+
auto new_tensor = make_tensor_ptr(tensor);
393+
394+
EXPECT_EQ(new_tensor->dim(), tensor->dim());
395+
EXPECT_EQ(new_tensor->size(0), tensor->size(0));
396+
EXPECT_EQ(new_tensor->size(1), tensor->size(1));
397+
EXPECT_EQ(
398+
new_tensor->const_data_ptr<double>(), tensor->const_data_ptr<double>());
399+
EXPECT_EQ(new_tensor->scalar_type(), executorch::aten::ScalarType::Double);
400+
}
401+
402+
TEST_F(TensorPtrTest, MakeTensorPtrFromTensorPtrInt64) {
403+
std::vector<int64_t> data = {100, 200, 300, 400};
404+
auto tensor = make_tensor_ptr({2, 2}, data);
405+
auto new_tensor = make_tensor_ptr(tensor);
406+
407+
EXPECT_EQ(new_tensor->dim(), tensor->dim());
408+
EXPECT_EQ(new_tensor->size(0), tensor->size(0));
409+
EXPECT_EQ(new_tensor->size(1), tensor->size(1));
410+
EXPECT_EQ(
411+
new_tensor->const_data_ptr<int64_t>(), tensor->const_data_ptr<int64_t>());
412+
EXPECT_EQ(new_tensor->scalar_type(), executorch::aten::ScalarType::Long);
413+
}
414+
415+
TEST_F(TensorPtrTest, MakeTensorPtrFromTensorPtrNull) {
416+
auto tensor = make_tensor_ptr({2, 2}, nullptr);
417+
auto new_tensor = make_tensor_ptr(tensor);
418+
419+
EXPECT_EQ(new_tensor->dim(), tensor->dim());
420+
EXPECT_EQ(new_tensor->size(0), tensor->size(0));
421+
EXPECT_EQ(new_tensor->size(1), tensor->size(1));
422+
EXPECT_EQ(new_tensor->const_data_ptr(), tensor->const_data_ptr());
423+
EXPECT_EQ(new_tensor->const_data_ptr(), nullptr);
424+
}
425+
376426
TEST_F(TensorPtrTest, CloneTensorPtrFromTensorPtrInt32) {
377427
std::vector<int32_t> data = {1, 2, 3, 4};
378428
auto tensor = make_tensor_ptr({2, 2}, std::move(data));
@@ -392,7 +442,7 @@ TEST_F(TensorPtrTest, CloneTensorPtrFromTensorPtrInt32) {
392442
TEST_F(TensorPtrTest, MakeTensorPtrFromExistingTensorDouble) {
393443
std::vector<double> data = {1.0, 2.0, 3.0, 4.0};
394444
auto tensor = make_tensor_ptr({2, 2}, data);
395-
auto new_tensor = make_tensor_ptr(*tensor);
445+
auto new_tensor = make_tensor_ptr(tensor);
396446

397447
EXPECT_EQ(new_tensor->dim(), tensor->dim());
398448
EXPECT_EQ(new_tensor->size(0), tensor->size(0));
@@ -405,7 +455,7 @@ TEST_F(TensorPtrTest, MakeTensorPtrFromExistingTensorDouble) {
405455
TEST_F(TensorPtrTest, CloneTensorPtrFromExistingTensorDouble) {
406456
std::vector<double> data = {1.0, 2.0, 3.0, 4.0};
407457
auto tensor = make_tensor_ptr({2, 2}, std::move(data));
408-
auto cloned_tensor = clone_tensor_ptr(*tensor);
458+
auto cloned_tensor = clone_tensor_ptr(tensor);
409459

410460
EXPECT_EQ(cloned_tensor->dim(), tensor->dim());
411461
EXPECT_EQ(cloned_tensor->size(0), tensor->size(0));
@@ -437,7 +487,7 @@ TEST_F(TensorPtrTest, CloneTensorPtrFromTensorPtrDouble) {
437487
TEST_F(TensorPtrTest, MakeTensorPtrFromExistingTensorInt64) {
438488
std::vector<int64_t> data = {100, 200, 300, 400};
439489
auto tensor = make_tensor_ptr({2, 2}, data);
440-
auto new_tensor = make_tensor_ptr(*tensor);
490+
auto new_tensor = make_tensor_ptr(tensor);
441491

442492
EXPECT_EQ(new_tensor->dim(), tensor->dim());
443493
EXPECT_EQ(new_tensor->size(0), tensor->size(0));
@@ -450,7 +500,7 @@ TEST_F(TensorPtrTest, MakeTensorPtrFromExistingTensorInt64) {
450500
TEST_F(TensorPtrTest, CloneTensorPtrFromExistingTensorInt64) {
451501
std::vector<int64_t> data = {100, 200, 300, 400};
452502
auto tensor = make_tensor_ptr({2, 2}, std::move(data));
453-
auto cloned_tensor = clone_tensor_ptr(*tensor);
503+
auto cloned_tensor = clone_tensor_ptr(tensor);
454504

455505
EXPECT_EQ(cloned_tensor->dim(), tensor->dim());
456506
EXPECT_EQ(cloned_tensor->size(0), tensor->size(0));

0 commit comments

Comments
 (0)