diff --git a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp index bedd1f1d8fe..60d1982d97e 100644 --- a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp +++ b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp @@ -39,10 +39,8 @@ void PrepackNode::encode(ComputeGraph* graph) { TensorRef tref = graph->get_val(tref_).toTensorRef(); vTensor packed = graph->get_val(packed_).toTensor(); - // TODO: Extract to standalone function, to support other types of prepacking. - api::StorageBuffer staging( - graph->context(), packed.dtype(), packed.gpu_nbytes()); size_t numel = api::utils::multiply_integers(tref.sizes); + api::StorageBuffer staging(graph->context(), tref.dtype, numel); size_t nbytes = numel * api::element_size(tref.dtype); copy_ptr_to_staging(tref.data, staging, nbytes);