Skip to content

Commit 98c50ce

Browse files
姜怡文it-is-a-robot
姜怡文
authored and committed
!13038 Separate tensor.resize_ & storage.resize_
Merge pull request !13038 from 姜怡文/master_resize
1 parent f31b1de commit 98c50ce

File tree

3 files changed

+32
-19
lines changed

3 files changed

+32
-19
lines changed

test/npu/test_combine_tensors.py

+6
Original file line number | Diff line number | Diff line change
@@ -81,6 +81,12 @@ def test_storage_resize(self, dtype, device="npu"):
8181
self.assertEqual(tensor_new.storage(), tensor.storage())
8282
idx += torch_npu.get_storage_size(tensor)
8383

84+
@Dtypes(torch.half, torch.float, torch.bfloat16)
85+
def test_untyped_storage_resize(self, dtype, device="npu"):
86+
a = torch.randn((2, 2, 2, 2, 2), device=device, dtype=dtype)
87+
a.untyped_storage().resize_(0)
88+
a.untyped_storage().resize_(128)
89+
8490
@Dtypes(torch.half, torch.float, torch.bfloat16)
8591
def test_combine_tensors(self, dtype, device="npu"):
8692
x = torch.zeros((2, 2, 2, 2), device=device, dtype=dtype)

torch_npu/csrc/aten/common/ResizeNpu.h

+25-18
Original file line number | Diff line number | Diff line change
@@ -14,10 +14,13 @@
1414
namespace at_npu {
1515
namespace native {
1616

17+
// Refresh storage_desc to ND if set force_refresh = true,
18+
// mainly used in storage.resize_
1719
static void storage_resize_npu(
1820
torch_npu::NPUStorageImpl& storage,
1921
ptrdiff_t size,
20-
c10::IntArrayRef new_size)
22+
c10::IntArrayRef new_size,
23+
bool force_refresh = false)
2124
{
2225
if (!storage.resizable()) {
2326
TORCH_CHECK(false, "Trying to resize storage that is not resizable", OPS_ERROR(ErrCode::NOT_SUPPORT));
@@ -50,25 +53,29 @@ static void storage_resize_npu(
5053
};
5154
// It is necessary to properly refresh the storage according to sizes and strides,
5255
// not just new sizes.
53-
int64_t new_data_numel = c10::multiply_integers(resize_shape);
54-
int64_t new_shape_numel = c10::multiply_integers(new_size);
55-
const c10::IntArrayRef &refresh_size = new_data_numel > new_shape_numel ? resize_shape : new_size;
56-
57-
// 计算连续场景下size对应的stride值
58-
int64_t dim_ = static_cast<int64_t>(refresh_size.size());
59-
c10::SmallVector<int64_t, 5> new_stride(dim_);
60-
if (dim_ > 0) {
61-
int64_t last_idx = dim_ - 1;
62-
new_stride[last_idx] = 1;
63-
for (auto i = last_idx - 1; i >= 0; --i) {
64-
new_stride[i] = new_stride[i + 1] * std::max<int64_t>(refresh_size[i + 1], 1);
56+
if (force_refresh) {
57+
int64_t new_data_numel = c10::multiply_integers(resize_shape);
58+
int64_t new_shape_numel = c10::multiply_integers(new_size);
59+
const c10::IntArrayRef &refresh_size = new_data_numel > new_shape_numel ? resize_shape : new_size;
60+
61+
// 计算连续场景下size对应的stride值
62+
int64_t dim_ = static_cast<int64_t>(refresh_size.size());
63+
c10::SmallVector<int64_t, 5> new_stride(dim_);
64+
if (dim_ > 0) {
65+
int64_t last_idx = dim_ - 1;
66+
new_stride[last_idx] = 1;
67+
for (auto i = last_idx - 1; i >= 0; --i) {
68+
new_stride[i] = new_stride[i + 1] * std::max<int64_t>(refresh_size[i + 1], 1);
69+
}
6570
}
66-
}
6771

68-
storage_desc.base_sizes_ = refresh_size;
69-
storage_desc.base_strides_ = new_stride;
70-
storage_desc.npu_format_ = ACL_FORMAT_ND;
71-
storage_desc.storage_sizes_ = storage_desc.base_sizes_;
72+
storage_desc.base_sizes_ = refresh_size;
73+
storage_desc.base_strides_ = new_stride;
74+
storage_desc.npu_format_ = ACL_FORMAT_ND;
75+
storage_desc.storage_sizes_ = storage_desc.base_sizes_;
76+
} else {
77+
StorageDescHelper::UpdateDesc(storage_desc, resize_shape, new_size);
78+
}
7279

7380
if (old_data != nullptr) {
7481
ptrdiff_t copy_size = old_size;

torch_npu/csrc/core/npu/NPUHooksInterface.cpp

+1-1
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,7 @@ void NPUHooksInterface::resizePrivateUse1Bytes(const c10::Storage &storage, size
3838

3939
auto itemsize = storage_impl->npu_desc_.data_type_.itemsize();
4040
std::vector<int64_t> new_size = {static_cast<int64_t>(new_bytes) / (ptrdiff_t)itemsize};
41-
at_npu::native::storage_resize_npu(*storage_impl, new_bytes, new_size);
41+
at_npu::native::storage_resize_npu(*storage_impl, new_bytes, new_size, true);
4242
}
4343

4444
at::PrivateUse1HooksInterface* get_npu_hooks()

0 commit comments

Comments
 (0)