@@ -318,8 +318,8 @@ api::UniformParamsBuffer make_metadata_uniform(
318318  }
319319
320320  vTensor::BufferMetadata metadata{
321-       api::utils::make_nchw_uvec4  (sizes),
322-       api::utils::make_nchw_uvec4  (strides),
321+       api::utils::make_whcn_uvec4  (sizes),
322+       api::utils::make_whcn_uvec4  (strides),
323323      api::utils::safe_downcast<uint32_t >(sizes.size ()),
324324      api::utils::safe_downcast<uint32_t >(api::utils::multiply_integers (sizes)),
325325  };
@@ -347,12 +347,13 @@ vTensor::vTensor(
347347      strides_{calc_strides (sizes, memory_layout_, storage_type)},
348348      gpu_sizes_{calc_gpu_sizes (sizes, memory_layout_, storage_type)},
349349      gpu_strides_{calc_strides (gpu_sizes_, memory_layout_, storage_type)},
350-       //  Vulkan uniform buffer containing sizes and stride info
351-       metadata_uniform_{make_metadata_uniform (
352-           context,
353-           gpu_sizes_,
354-           gpu_strides_,
355-           storage_type)},
350+       virtual_extents_ (
351+           create_image_extents (gpu_sizes_, storage_type, memory_layout)),
352+       //  Utility Uniform Buffers that can be passed to shaders as arguments
353+       metadata_uniform_(),
354+       cpu_sizes_uniform_(nullptr ),
355+       gpu_sizes_uniform_(nullptr ),
356+       extents_uniform_(nullptr ),
356357      //  Construct Tensor storage
357358      view_(std::make_shared<vTensorStorage>(
358359          context,
@@ -377,12 +378,13 @@ vTensor::vTensor(
377378      strides_{calc_strides (sizes, memory_layout_, storage_type)},
378379      gpu_sizes_{calc_gpu_sizes (sizes, memory_layout_, storage_type)},
379380      gpu_strides_{calc_strides (gpu_sizes_, memory_layout_, storage_type)},
381+       virtual_extents_ (
382+           create_image_extents (gpu_sizes_, storage_type, memory_layout)),
380383      //  Vulkan uniform buffer containing sizes and stride info
381-       metadata_uniform_{make_metadata_uniform (
382-           context,
383-           gpu_sizes_,
384-           gpu_strides_,
385-           storage_type)},
384+       metadata_uniform_(),
385+       cpu_sizes_uniform_(nullptr ),
386+       gpu_sizes_uniform_(nullptr ),
387+       extents_uniform_(nullptr ),
386388      //  Quantization params
387389      is_quantized_{true },
388390      q_scale_{q_scale},
@@ -425,10 +427,47 @@ api::VulkanBuffer& vTensor::buffer(
425427  return  view_->buffer_ ;
426428}
427429
430+ api::VulkanBuffer& vTensor::buffer_metadata () {
431+   if  (!metadata_uniform_.buffer ()) {
432+     metadata_uniform_ = make_metadata_uniform (
433+         view_->context_ , gpu_sizes_, gpu_strides_, storage_type ());
434+   }
435+   return  metadata_uniform_.buffer ();
436+ }
437+ 
438+ std::shared_ptr<api::UniformParamsBuffer> vTensor::cpu_sizes_ubo () {
439+   if  (!cpu_sizes_uniform_) {
440+     cpu_sizes_uniform_.reset (new  api::UniformParamsBuffer (
441+         view_->context_ , api::utils::make_whcn_ivec4 (sizes_)));
442+   }
443+   return  cpu_sizes_uniform_;
444+ }
445+ 
446+ std::shared_ptr<api::UniformParamsBuffer> vTensor::gpu_sizes_ubo () {
447+   if  (!gpu_sizes_uniform_) {
448+     gpu_sizes_uniform_.reset (new  api::UniformParamsBuffer (
449+         view_->context_ , api::utils::make_whcn_ivec4 (gpu_sizes_)));
450+   }
451+   return  gpu_sizes_uniform_;
452+ }
453+ 
454+ std::shared_ptr<api::UniformParamsBuffer> vTensor::extents_ubo () {
455+   if  (!extents_uniform_) {
456+     extents_uniform_.reset (new  api::UniformParamsBuffer (
457+         view_->context_ ,
458+         api::utils::uvec4 (
459+             {view_->extents_ .data [0 ],
460+              view_->extents_ .data [1 ],
461+              view_->extents_ .data [2 ],
462+              1u })));
463+   }
464+   return  extents_uniform_;
465+ }
466+ 
428467vTensor::BufferMetadata vTensor::get_cpu_buffer_metadata () const  {
429468  return  {
430-       api::utils::make_nchw_uvec4  (sizes_),
431-       api::utils::make_nchw_uvec4  (strides_),
469+       api::utils::make_whcn_uvec4  (sizes_),
470+       api::utils::make_whcn_uvec4  (strides_),
432471      api::utils::safe_downcast<uint32_t >(sizes_.size ()),
433472      api::utils::safe_downcast<uint32_t >(
434473          api::utils::multiply_integers (sizes_)),
@@ -473,6 +512,65 @@ void vTensor::bind_allocation(const api::MemoryAllocation& allocation) {
473512  }
474513}
475514
515+ void  vTensor::update_size_metadata (const  std::vector<int64_t >& new_sizes) {
516+   sizes_ = new_sizes;
517+   gpu_sizes_ = calc_gpu_sizes (sizes_, memory_layout_, storage_type ());
518+   virtual_extents_ =
519+       create_image_extents (gpu_sizes_, storage_type (), memory_layout_);
520+ 
521+   if  (cpu_sizes_uniform_) {
522+     cpu_sizes_uniform_->update (api::utils::make_whcn_ivec4 (sizes_));
523+   }
524+ 
525+   if  (gpu_sizes_uniform_) {
526+     gpu_sizes_uniform_->update (api::utils::make_whcn_ivec4 (gpu_sizes_));
527+   }
528+ 
529+   if  (extents_uniform_) {
530+     extents_uniform_->update (api::utils::uvec4 (
531+         {virtual_extents_.data [0 ],
532+          virtual_extents_.data [1 ],
533+          virtual_extents_.data [2 ],
534+          1u }));
535+   }
536+ }
537+ 
538+ void  vTensor::reallocate (const  std::vector<int64_t >& new_sizes) {
539+   update_size_metadata (new_sizes);
540+   view_->discard_and_reallocate (
541+       calc_gpu_sizes (new_sizes, memory_layout_, storage_type ()),
542+       memory_layout_,
543+       dtype_);
544+ }
545+ 
546+ void  vTensor::virtual_resize (const  std::vector<int64_t >& new_sizes) {
547+   update_size_metadata (new_sizes);
548+   if  (storage_type () == api::StorageType::BUFFER) {
549+     if  (gpu_nbytes () > view_->buffer_ .mem_size ()) {
550+       VK_THROW (
551+           " Cannot virtual_resize a vTensor with sizes that require a larger " 
552+           " buffer! reallocate() should be used instead." 
553+     }
554+   } else  {
555+     bool  valid_resize = true ;
556+     if  (virtual_extents_.data [0 ] > view_->extents_ .data [0 ]) {
557+       valid_resize = false ;
558+     }
559+     if  (virtual_extents_.data [1 ] > view_->extents_ .data [1 ]) {
560+       valid_resize = false ;
561+     }
562+     if  (virtual_extents_.data [2 ] > view_->extents_ .data [2 ]) {
563+       valid_resize = false ;
564+     }
565+ 
566+     if  (!valid_resize) {
567+       VK_THROW (
568+           " Cannot virtual_resize a vTensor with sizes that require a larger " 
569+           " image texture! reallocate() should be used instead." 
570+     }
571+   }
572+ }
573+ 
476574// 
477575//  vTensorStorage
478576// 
@@ -569,11 +667,16 @@ vTensorStorage::vTensorStorage(
569667      last_access_{} {}
570668
571669vTensorStorage::~vTensorStorage () {
670+   flush ();
671+ }
672+ 
673+ void  vTensorStorage::flush () {
572674  if  (image_) {
573675    context_->register_image_cleanup (image_);
574676  } else  if  (buffer_) {
575677    context_->register_buffer_cleanup (buffer_);
576678  }
679+   last_access_ = {};
577680}
578681
579682void  vTensorStorage::transition (
@@ -663,6 +766,28 @@ void add_buffer_barrier(
663766  }
664767}
665768
769+ void  vTensorStorage::discard_and_reallocate (
770+     const  std::vector<int64_t >& gpu_sizes,
771+     const  api::GPUMemoryLayout gpu_memory_layout,
772+     const  api::ScalarType dtype) {
773+   const  bool  image_owns_memory = image_.owns_memory ();
774+   const  bool  buffer_owns_memory = buffer_.owns_memory ();
775+ 
776+   flush ();
777+ 
778+   extents_ = create_image_extents (gpu_sizes, storage_type_, gpu_memory_layout);
779+   image_ = allocate_image (
780+       context_,
781+       extents_,
782+       storage_type_,
783+       api::to_vkformat (dtype),
784+       image_owns_memory);
785+ 
786+   buffer_length_ = api::utils::multiply_integers (gpu_sizes);
787+   buffer_ = allocate_buffer (
788+       context_, buffer_length_, storage_type_, dtype, buffer_owns_memory);
789+ }
790+ 
666791} //  namespace vulkan
667792} //  namespace native
668793} //  namespace at
0 commit comments