|
14 | 14 | namespace at_npu {
|
15 | 15 | namespace native {
|
16 | 16 |
|
| 17 | +// Refresh storage_desc to ND if set force_refresh = true, |
| 18 | +// mainly used in storage.resize_ |
17 | 19 | static void storage_resize_npu(
|
18 | 20 | torch_npu::NPUStorageImpl& storage,
|
19 | 21 | ptrdiff_t size,
|
20 |
| - c10::IntArrayRef new_size) |
| 22 | + c10::IntArrayRef new_size, |
| 23 | + bool force_refresh = false) |
21 | 24 | {
|
22 | 25 | if (!storage.resizable()) {
|
23 | 26 | TORCH_CHECK(false, "Trying to resize storage that is not resizable", OPS_ERROR(ErrCode::NOT_SUPPORT));
|
@@ -50,25 +53,29 @@ static void storage_resize_npu(
|
50 | 53 | };
|
51 | 54 | // It is necessary to properly refresh the storage according to sizes and strides,
|
52 | 55 | // not just new sizes.
|
53 |
| - int64_t new_data_numel = c10::multiply_integers(resize_shape); |
54 |
| - int64_t new_shape_numel = c10::multiply_integers(new_size); |
55 |
| - const c10::IntArrayRef &refresh_size = new_data_numel > new_shape_numel ? resize_shape : new_size; |
56 |
| - |
57 |
| - // 计算连续场景下size对应的stride值 |
58 |
| - int64_t dim_ = static_cast<int64_t>(refresh_size.size()); |
59 |
| - c10::SmallVector<int64_t, 5> new_stride(dim_); |
60 |
| - if (dim_ > 0) { |
61 |
| - int64_t last_idx = dim_ - 1; |
62 |
| - new_stride[last_idx] = 1; |
63 |
| - for (auto i = last_idx - 1; i >= 0; --i) { |
64 |
| - new_stride[i] = new_stride[i + 1] * std::max<int64_t>(refresh_size[i + 1], 1); |
| 56 | + if (force_refresh) { |
| 57 | + int64_t new_data_numel = c10::multiply_integers(resize_shape); |
| 58 | + int64_t new_shape_numel = c10::multiply_integers(new_size); |
| 59 | + const c10::IntArrayRef &refresh_size = new_data_numel > new_shape_numel ? resize_shape : new_size; |
| 60 | + |
| 61 | + // 计算连续场景下size对应的stride值 |
| 62 | + int64_t dim_ = static_cast<int64_t>(refresh_size.size()); |
| 63 | + c10::SmallVector<int64_t, 5> new_stride(dim_); |
| 64 | + if (dim_ > 0) { |
| 65 | + int64_t last_idx = dim_ - 1; |
| 66 | + new_stride[last_idx] = 1; |
| 67 | + for (auto i = last_idx - 1; i >= 0; --i) { |
| 68 | + new_stride[i] = new_stride[i + 1] * std::max<int64_t>(refresh_size[i + 1], 1); |
| 69 | + } |
65 | 70 | }
|
66 |
| - } |
67 | 71 |
|
68 |
| - storage_desc.base_sizes_ = refresh_size; |
69 |
| - storage_desc.base_strides_ = new_stride; |
70 |
| - storage_desc.npu_format_ = ACL_FORMAT_ND; |
71 |
| - storage_desc.storage_sizes_ = storage_desc.base_sizes_; |
| 72 | + storage_desc.base_sizes_ = refresh_size; |
| 73 | + storage_desc.base_strides_ = new_stride; |
| 74 | + storage_desc.npu_format_ = ACL_FORMAT_ND; |
| 75 | + storage_desc.storage_sizes_ = storage_desc.base_sizes_; |
| 76 | + } else { |
| 77 | + StorageDescHelper::UpdateDesc(storage_desc, resize_shape, new_size); |
| 78 | + } |
72 | 79 |
|
73 | 80 | if (old_data != nullptr) {
|
74 | 81 | ptrdiff_t copy_size = old_size;
|
|
0 commit comments