Skip to content

Commit d5e65f7

Browse files
shoumikhinfacebook-github-bot
authored andcommitted
Allow custom sizes, dim order and strides for tensor view. (#14944)
Summary: . Differential Revision: D84259597
1 parent 7987db7 commit d5e65f7

File tree

2 files changed

+272
-32
lines changed

2 files changed

+272
-32
lines changed

extension/tensor/tensor_ptr.h

Lines changed: 49 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -323,26 +323,46 @@ inline TensorPtr make_tensor_ptr(
323323
}
324324

325325
/**
326-
* Creates a TensorPtr to manage a new Tensor with the same properties
327-
* as the given Tensor, sharing the same data without owning it.
326+
* Creates a TensorPtr to manage a new Tensor that aliases the given Tensor's
327+
* storage, with optional metadata overrides. Shape dynamism is inherited from
328+
* the source tensor.
328329
*
329-
* @param tensor The Tensor whose properties are used to create a new TensorPtr.
330-
* @return A new TensorPtr managing a Tensor with the same properties as the
331-
* original.
330+
* If an override is provided (non-empty), it is passed as-is. If an override is
331+
* empty, the corresponding metadata is reused from the source tensor when it
332+
* fits; otherwise it is left empty for the core factory to derive a valid
333+
* configuration. If `dim_order` is empty but `strides` is provided, `dim_order`
334+
* is left empty so the core may infer it from the provided strides.
335+
*
336+
* @param tensor The source tensor to alias.
337+
* @param sizes Optional sizes override.
338+
* @param dim_order Optional dimension order override.
339+
* @param strides Optional strides override.
340+
* @return A TensorPtr aliasing the same storage with requested metadata.
332341
*/
333-
inline TensorPtr make_tensor_ptr(const executorch::aten::Tensor& tensor) {
334-
return make_tensor_ptr(
335-
std::vector<executorch::aten::SizesType>(
336-
tensor.sizes().begin(), tensor.sizes().end()),
337-
tensor.mutable_data_ptr(),
342+
inline TensorPtr make_tensor_ptr(
343+
const executorch::aten::Tensor& tensor,
344+
std::vector<executorch::aten::SizesType> sizes = {},
345+
std::vector<executorch::aten::DimOrderType> dim_order = {},
346+
std::vector<executorch::aten::StridesType> strides = {}) {
347+
if (sizes.empty()) {
348+
sizes.assign(tensor.sizes().begin(), tensor.sizes().end());
349+
}
350+
const auto same_rank = sizes.size() == static_cast<size_t>(tensor.dim());
351+
const auto same_shape = same_rank &&
352+
std::equal(sizes.begin(), sizes.end(), tensor.sizes().begin());
338353
#ifndef USE_ATEN_LIB
339-
std::vector<executorch::aten::DimOrderType>(
340-
tensor.dim_order().begin(), tensor.dim_order().end()),
341-
#else // USE_ATEN_LIB
342-
{},
354+
if (dim_order.empty() && strides.empty() && same_rank) {
355+
dim_order.assign(tensor.dim_order().begin(), tensor.dim_order().end());
356+
}
343357
#endif // USE_ATEN_LIB
344-
std::vector<executorch::aten::StridesType>(
345-
tensor.strides().begin(), tensor.strides().end()),
358+
if (strides.empty() && same_shape) {
359+
strides.assign(tensor.strides().begin(), tensor.strides().end());
360+
}
361+
return make_tensor_ptr(
362+
std::move(sizes),
363+
tensor.mutable_data_ptr(),
364+
std::move(dim_order),
365+
std::move(strides),
346366
tensor.scalar_type()
347367
#ifndef USE_ATEN_LIB
348368
,
@@ -352,21 +372,21 @@ inline TensorPtr make_tensor_ptr(const executorch::aten::Tensor& tensor) {
352372
}
353373

354374
/**
355-
* Creates a TensorPtr to manage a new Tensor with the same properties
356-
* as the Tensor referenced by the given TensorPtr, sharing the same data
357-
* without owning it.
375+
* Convenience overload identical to make_tensor_ptr(*tensor_ptr, ...).
358376
*
359-
* This is a convenience overload equivalent to make_tensor_ptr(*tensor_ptr).
360-
* It does not extend the lifetime of the underlying buffer; if the original
361-
* owner releases the storage, all views aliasing it become dangling.
362-
*
363-
* @param tensor_ptr The TensorPtr whose underlying Tensor is used to initialize
364-
* the returned view.
365-
* @return A new TensorPtr managing a Tensor with the same properties as the
366-
* original.
377+
* @param tensor_ptr The source tensor pointer to alias.
378+
* @param sizes Optional sizes override.
379+
* @param dim_order Optional dimension order override.
380+
* @param strides Optional strides override.
381+
* @return A TensorPtr aliasing the same storage with requested metadata.
367382
*/
368-
inline TensorPtr make_tensor_ptr(const TensorPtr& tensor_ptr) {
369-
return make_tensor_ptr(*tensor_ptr);
383+
inline TensorPtr make_tensor_ptr(
384+
const TensorPtr& tensor_ptr,
385+
std::vector<executorch::aten::SizesType> sizes = {},
386+
std::vector<executorch::aten::DimOrderType> dim_order = {},
387+
std::vector<executorch::aten::StridesType> strides = {}) {
388+
return make_tensor_ptr(
389+
*tensor_ptr, std::move(sizes), std::move(dim_order), std::move(strides));
370390
}
371391

372392
/**

extension/tensor/test/tensor_ptr_test.cpp

Lines changed: 223 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,226 @@ TEST_F(TensorPtrTest, MakeTensorPtrFromExistingTensorInt32) {
357357
EXPECT_EQ(new_tensor->scalar_type(), executorch::aten::ScalarType::Int);
358358
}
359359

360+
TEST_F(TensorPtrTest, MakeViewOverrideSizesRankIncrease) {
361+
std::vector<float> data = {1, 2, 3, 4, 5, 6};
362+
auto tensor = make_tensor_ptr({2, 3}, std::move(data));
363+
auto view = make_tensor_ptr(tensor, {1, 2, 3});
364+
365+
EXPECT_EQ(view->dim(), 3);
366+
EXPECT_EQ(view->size(0), 1);
367+
EXPECT_EQ(view->size(1), 2);
368+
EXPECT_EQ(view->size(2), 3);
369+
EXPECT_EQ(view->const_data_ptr<float>(), tensor->const_data_ptr<float>());
370+
EXPECT_EQ(view->strides()[0], 6);
371+
EXPECT_EQ(view->strides()[1], 3);
372+
EXPECT_EQ(view->strides()[2], 1);
373+
}
374+
375+
TEST_F(TensorPtrTest, MakeViewOverrideSizesSameRankRecomputesStrides) {
376+
float data[12] = {0};
377+
auto tensor = make_tensor_ptr({3, 4}, data);
378+
auto view = make_tensor_ptr(tensor, {4, 3});
379+
380+
EXPECT_EQ(view->dim(), 2);
381+
EXPECT_EQ(view->size(0), 4);
382+
EXPECT_EQ(view->size(1), 3);
383+
EXPECT_EQ(view->strides()[0], 3);
384+
EXPECT_EQ(view->strides()[1], 1);
385+
}
386+
387+
TEST_F(TensorPtrTest, MakeViewOverrideDimOrderOnly) {
388+
float data[6] = {0};
389+
auto tensor = make_tensor_ptr({2, 3}, data);
390+
auto view = make_tensor_ptr(tensor, {}, {1, 0}, {});
391+
392+
EXPECT_EQ(view->dim(), 2);
393+
EXPECT_EQ(view->size(0), 2);
394+
EXPECT_EQ(view->size(1), 3);
395+
EXPECT_EQ(view->strides()[0], 1);
396+
EXPECT_EQ(view->strides()[1], 2);
397+
}
398+
399+
TEST_F(TensorPtrTest, MakeViewOverrideStridesOnlyInfersDimOrder) {
400+
float data[12] = {0};
401+
auto tensor = make_tensor_ptr({3, 4}, data);
402+
auto view = make_tensor_ptr(tensor, {}, {}, {1, 3});
403+
404+
EXPECT_EQ(view->dim(), 2);
405+
EXPECT_EQ(view->size(0), 3);
406+
EXPECT_EQ(view->size(1), 4);
407+
EXPECT_EQ(view->strides()[0], 1);
408+
EXPECT_EQ(view->strides()[1], 3);
409+
}
410+
411+
TEST_F(TensorPtrTest, MakeViewReuseMetadataWhenShapeSame) {
412+
float data[12] = {0};
413+
auto tensor = make_tensor_ptr({3, 4}, data, {1, 0}, {1, 3});
414+
auto view = make_tensor_ptr(tensor, {3, 4});
415+
416+
EXPECT_EQ(view->dim(), 2);
417+
EXPECT_EQ(view->size(0), 3);
418+
EXPECT_EQ(view->size(1), 4);
419+
EXPECT_EQ(view->strides()[0], 1);
420+
EXPECT_EQ(view->strides()[1], 3);
421+
}
422+
423+
TEST_F(TensorPtrTest, MakeViewShapeChangeWithExplicitOldStridesExpectDeath) {
424+
float data[12] = {0};
425+
auto tensor = make_tensor_ptr({3, 4}, data);
426+
std::vector<executorch::aten::StridesType> old_strides(
427+
tensor->strides().begin(), tensor->strides().end());
428+
429+
ET_EXPECT_DEATH(
430+
{ auto _ = make_tensor_ptr(tensor, {2, 6}, {}, old_strides); }, "");
431+
}
432+
433+
TEST_F(TensorPtrTest, MakeViewInvalidDimOrderExpectDeath) {
434+
float data[12] = {0};
435+
auto tensor = make_tensor_ptr({3, 4}, data);
436+
437+
ET_EXPECT_DEATH(
438+
{ auto _ = make_tensor_ptr(tensor, {3, 4}, {2, 1}, {1, 4}); }, "");
439+
}
440+
441+
TEST_F(TensorPtrTest, MakeViewFromTensorPtrConvenienceOverload) {
442+
float data[12] = {0};
443+
auto tensor = make_tensor_ptr({3, 4}, data);
444+
auto view = make_tensor_ptr(tensor, {}, {1, 0}, {});
445+
446+
EXPECT_EQ(view->dim(), 2);
447+
EXPECT_EQ(view->size(0), 3);
448+
EXPECT_EQ(view->size(1), 4);
449+
EXPECT_EQ(view->strides()[0], 1);
450+
EXPECT_EQ(view->strides()[1], 3);
451+
}
452+
453+
TEST_F(TensorPtrTest, MakeViewRankDecreaseFlatten) {
454+
float data[6] = {1, 2, 3, 4, 5, 6};
455+
auto tensor = make_tensor_ptr(
456+
{2, 3},
457+
data,
458+
{},
459+
{},
460+
executorch::aten::ScalarType::Float,
461+
executorch::aten::TensorShapeDynamism::DYNAMIC_UNBOUND);
462+
auto view = make_tensor_ptr(tensor, {6});
463+
EXPECT_EQ(view->dim(), 1);
464+
EXPECT_EQ(view->size(0), 6);
465+
EXPECT_EQ(view->strides()[0], 1);
466+
EXPECT_NE(tensor->unsafeGetTensorImpl(), view->unsafeGetTensorImpl());
467+
EXPECT_EQ(resize_tensor_ptr(view, {3, 2}), Error::Ok);
468+
EXPECT_EQ(view->dim(), 2);
469+
EXPECT_EQ(view->size(0), 3);
470+
EXPECT_EQ(view->size(1), 2);
471+
EXPECT_EQ(tensor->size(0), 2);
472+
EXPECT_EQ(tensor->size(1), 3);
473+
}
474+
475+
TEST_F(TensorPtrTest, MakeViewFromScalarAliasAnd1D) {
476+
float scalar_value = 7.f;
477+
auto tensor = make_tensor_ptr({}, &scalar_value);
478+
auto alias = make_tensor_ptr(tensor);
479+
EXPECT_EQ(alias->dim(), 0);
480+
EXPECT_EQ(alias->numel(), 1);
481+
auto reshaped = make_tensor_ptr(tensor, {1});
482+
EXPECT_EQ(reshaped->dim(), 1);
483+
EXPECT_EQ(reshaped->size(0), 1);
484+
EXPECT_EQ(reshaped->strides()[0], 1);
485+
ET_EXPECT_DEATH({ auto unused = make_tensor_ptr(tensor, {}, {0}, {}); }, "");
486+
ET_EXPECT_DEATH({ auto unused = make_tensor_ptr(tensor, {}, {}, {1}); }, "");
487+
}
488+
489+
TEST_F(TensorPtrTest, MakeViewFromZeroSizeNonNullDataToOneElement) {
490+
float data[1] = {123.0f};
491+
auto tensor = make_tensor_ptr({0}, data);
492+
auto view = make_tensor_ptr(tensor, {1});
493+
494+
EXPECT_EQ(view->dim(), 1);
495+
EXPECT_EQ(view->size(0), 1);
496+
EXPECT_EQ(view->numel(), 1);
497+
498+
EXPECT_EQ(view->const_data_ptr<float>(), data);
499+
EXPECT_EQ(view->strides().size(), 1);
500+
EXPECT_EQ(view->strides()[0], 1);
501+
502+
EXPECT_EQ(view->const_data_ptr<float>()[0], 123.0f);
503+
504+
view->mutable_data_ptr<float>()[0] = 456.0f;
505+
EXPECT_EQ(data[0], 456.0f);
506+
}
507+
508+
TEST_F(TensorPtrTest, MakeViewExplicitDimOrderAndStridesShapeChange) {
509+
float data[6] = {0};
510+
auto tensor = make_tensor_ptr({2, 3}, data);
511+
auto view = make_tensor_ptr(tensor, {3, 2}, {1, 0}, {1, 3});
512+
EXPECT_EQ(view->dim(), 2);
513+
EXPECT_EQ(view->size(0), 3);
514+
EXPECT_EQ(view->size(1), 2);
515+
EXPECT_EQ(view->strides()[0], 1);
516+
EXPECT_EQ(view->strides()[1], 3);
517+
}
518+
519+
TEST_F(TensorPtrTest, TensorUint8dataInt16Type) {
520+
std::vector<int16_t> int16_values = {-1, 2, -3, 4};
521+
auto byte_pointer = reinterpret_cast<const uint8_t*>(int16_values.data());
522+
std::vector<uint8_t> byte_data(
523+
byte_pointer, byte_pointer + int16_values.size() * sizeof(int16_t));
524+
auto tensor = make_tensor_ptr(
525+
{4}, std::move(byte_data), executorch::aten::ScalarType::Short);
526+
EXPECT_EQ(tensor->dim(), 1);
527+
EXPECT_EQ(tensor->size(0), 4);
528+
auto int16_data = tensor->const_data_ptr<int16_t>();
529+
EXPECT_EQ(int16_data[0], -1);
530+
EXPECT_EQ(int16_data[1], 2);
531+
EXPECT_EQ(int16_data[2], -3);
532+
EXPECT_EQ(int16_data[3], 4);
533+
}
534+
535+
TEST_F(TensorPtrTest, MakeView3DDimOrderOnly) {
536+
float data[24] = {0};
537+
auto tensor = make_tensor_ptr({2, 3, 4}, data);
538+
auto view = make_tensor_ptr(tensor, {}, {2, 0, 1}, {});
539+
EXPECT_EQ(view->dim(), 3);
540+
EXPECT_EQ(view->size(0), 2);
541+
EXPECT_EQ(view->size(1), 3);
542+
EXPECT_EQ(view->size(2), 4);
543+
EXPECT_EQ(view->strides()[0], 3);
544+
EXPECT_EQ(view->strides()[1], 1);
545+
EXPECT_EQ(view->strides()[2], 6);
546+
}
547+
548+
#ifndef USE_ATEN_LIB
549+
TEST_F(TensorPtrTest, MakeViewDynamismPropagationResizeAlias) {
550+
float data[12] = {0};
551+
auto tensor = make_tensor_ptr(
552+
{3, 4},
553+
data,
554+
{},
555+
{},
556+
executorch::aten::ScalarType::Float,
557+
executorch::aten::TensorShapeDynamism::DYNAMIC_UNBOUND);
558+
auto alias = make_tensor_ptr(tensor);
559+
EXPECT_EQ(resize_tensor_ptr(alias, {2, 6}), Error::Ok);
560+
EXPECT_EQ(alias->size(0), 2);
561+
EXPECT_EQ(alias->size(1), 6);
562+
EXPECT_EQ(tensor->size(0), 3);
563+
EXPECT_EQ(tensor->size(1), 4);
564+
}
565+
566+
TEST_F(TensorPtrTest, MakeViewSameRankShapeChangeCopiesDimOrder) {
567+
float data[24] = {0};
568+
auto tensor = make_tensor_ptr({2, 3, 4}, data, {2, 0, 1}, {3, 1, 6});
569+
auto view = make_tensor_ptr(tensor, {4, 2, 3});
570+
EXPECT_EQ(view->dim(), 3);
571+
EXPECT_EQ(view->size(0), 4);
572+
EXPECT_EQ(view->size(1), 2);
573+
EXPECT_EQ(view->size(2), 3);
574+
EXPECT_EQ(view->strides()[0], 2);
575+
EXPECT_EQ(view->strides()[1], 1);
576+
EXPECT_EQ(view->strides()[2], 8);
577+
}
578+
#endif
579+
360580
TEST_F(TensorPtrTest, CloneTensorPtrFromExistingTensorInt32) {
361581
std::vector<int32_t> data = {1, 2, 3, 4};
362582
auto tensor = make_tensor_ptr({2, 2}, std::move(data));
@@ -803,7 +1023,7 @@ TEST_F(TensorPtrTest, TensorDeducedScalarType) {
8031023
EXPECT_EQ(tensor->const_data_ptr<double>()[3], 4.0);
8041024
}
8051025

806-
TEST_F(TensorPtrTest, TensorUint8BufferWithFloatScalarType) {
1026+
TEST_F(TensorPtrTest, TensorUint8dataWithFloatScalarType) {
8071027
std::vector<uint8_t> data(
8081028
4 * executorch::aten::elementSize(executorch::aten::ScalarType::Float));
8091029

@@ -827,14 +1047,14 @@ TEST_F(TensorPtrTest, TensorUint8BufferWithFloatScalarType) {
8271047
EXPECT_EQ(tensor->const_data_ptr<float>()[3], 4.0f);
8281048
}
8291049

830-
TEST_F(TensorPtrTest, TensorUint8BufferTooSmallExpectDeath) {
1050+
TEST_F(TensorPtrTest, TensorUint8dataTooSmallExpectDeath) {
8311051
std::vector<uint8_t> data(
8321052
2 * executorch::aten::elementSize(executorch::aten::ScalarType::Float));
8331053
ET_EXPECT_DEATH(
8341054
{ auto tensor = make_tensor_ptr({2, 2}, std::move(data)); }, "");
8351055
}
8361056

837-
TEST_F(TensorPtrTest, TensorUint8BufferTooLargeExpectDeath) {
1057+
TEST_F(TensorPtrTest, TensorUint8dataTooLargeExpectDeath) {
8381058
std::vector<uint8_t> data(
8391059
5 * executorch::aten::elementSize(executorch::aten::ScalarType::Float));
8401060
ET_EXPECT_DEATH({ auto _ = make_tensor_ptr({2, 2}, std::move(data)); }, "");

0 commit comments

Comments
 (0)