diff --git a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp index f1ea1df7ff9..e0d20523986 100644 --- a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp +++ b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp @@ -43,16 +43,32 @@ PrepackNode::PrepackNode( graph.update_descriptor_counts(noop_shader_, /*execute = */ false); } -void PrepackNode::encode(ComputeGraph* graph) { - api::Context* const context = graph->context(); - - TensorRef& tref = graph->get_val(tref_).toTensorRef(); +api::StorageBuffer PrepackNode::create_staging_buffer(ComputeGraph* graph) { vTensor& packed = graph->get_val(packed_).toTensor(); + // If no TensorRef is provided, create a staging buffer of zeros according to + // the vTensor metadata. + if (graph->get_val(tref_).isNone()) { + size_t numel = api::utils::multiply_integers(packed.sizes()); + api::StorageBuffer staging(graph->context(), packed.dtype(), numel); + size_t nbytes = numel * api::element_size(packed.dtype()); + set_staging_zeros(staging, nbytes); + return staging; + } + + TensorRef& tref = graph->get_val(tref_).toTensorRef(); size_t numel = api::utils::multiply_integers(tref.sizes); api::StorageBuffer staging(graph->context(), tref.dtype, numel); size_t nbytes = numel * api::element_size(tref.dtype); copy_ptr_to_staging(tref.data, staging, nbytes); + return staging; +} + +void PrepackNode::encode(ComputeGraph* graph) { + api::Context* const context = graph->context(); + + vTensor& packed = graph->get_val(packed_).toTensor(); + api::StorageBuffer staging = create_staging_buffer(graph); std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock(); @@ -76,7 +92,7 @@ void PrepackNode::encode(ComputeGraph* graph) { } // Submit a compute shader that performs a no-op with the packed tensor in - // order to trigger a image layout transition from GENERAL to + // order to trigger an image layout transition from GENERAL to // READ_ONLY_OPTIMAL. 
This ensures that future uses of the tensor will be // bound with the correct image layout. { diff --git a/backends/vulkan/runtime/graph/ops/PrepackNode.h b/backends/vulkan/runtime/graph/ops/PrepackNode.h index f894e179407..d0ff8afa660 100644 --- a/backends/vulkan/runtime/graph/ops/PrepackNode.h +++ b/backends/vulkan/runtime/graph/ops/PrepackNode.h @@ -48,6 +48,9 @@ class PrepackNode final { const ValueRef packed_; // TODO(T180906457): allow re-computing param buffers. std::vector<std::shared_ptr<api::UniformParamsBuffer>> params_; + + private: + api::StorageBuffer create_staging_buffer(ComputeGraph* graph); }; } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp b/backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp index 5b83e7d31fa..18e7fa95f18 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp @@ -52,12 +52,16 @@ void resize_conv2d_node( out.virtual_resize(new_out_sizes); } -ValueRef prepack_biases(ComputeGraph& graph, const ValueRef vref) { - if (graph.get_val(vref).isNone()) { - VK_THROW("aten.convolution.default: Null bias is not supported yet!"); - } +ValueRef prepack_biases( + ComputeGraph& graph, + const ValueRef vref, + const ValueRef weight, + const bool transposed) { + TensorRef& tref = graph.get_val(weight).toTensorRef(); + const int64_t out_channels = transposed ? 
tref.sizes.at(1) : tref.sizes.at(0); - ValueRef v = graph.add_tensor_like(vref, api::kTexture2D, api::kWidthPacked); + ValueRef v = graph.add_tensor( + {out_channels}, tref.dtype, api::kTexture2D, api::kWidthPacked); vTensor& t = graph.get_val(v).toTensor(); api::ShaderInfo shader = get_nchw_to_image_shader(t); @@ -296,7 +300,7 @@ void add_conv2d_node( ValueRef arg_in = prepack_if_tensor_ref(graph, in); ValueRef arg_weight = prepack_weights(graph, weight, method); - ValueRef arg_bias = prepack_biases(graph, bias); + ValueRef arg_bias = prepack_biases(graph, bias, weight, transposed_val); vTensor& t_in = graph.get_val(arg_in).toTensor(); vTensor& t_out = graph.get_val(out).toTensor(); diff --git a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp index 94228321f79..383841be75c 100644 --- a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp @@ -89,6 +89,12 @@ void copy_staging_to_ptr( memcpy_from_mapping(mapping, dst, nbytes, staging.dtype()); } +void set_staging_zeros(api::StorageBuffer& staging, const size_t nbytes) { + api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::WRITE); + uint8_t* data_ptr = mapping.template data<uint8_t>(); + memset(data_ptr, 0, staging.nbytes()); +} + api::ShaderInfo get_nchw_to_image_shader(const vTensor& v_dst) { if (v_dst.is_quantized()) { VK_THROW("Quantized Tensors are currently not supported!"); diff --git a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h index 0634d8d02e7..0bcbff5d74e 100644 --- a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h +++ b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h @@ -25,6 +25,8 @@ void copy_staging_to_ptr( void* dst, const size_t nbytes); +void set_staging_zeros(api::StorageBuffer& staging, const size_t nbytes); + // // Functions to get shaders // diff --git 
a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index 8ba695524cd..dd2142eee47 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -601,3 +601,30 @@ def forward(self, x): sample_inputs, memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED], ) + + def test_vulkan_backend_conv2d_bias_false(self): + class Conv2dModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d( + in_channels=6, + out_channels=8, + kernel_size=(3, 3), + padding=(2, 3), + stride=(1, 2), + dilation=1, + groups=1, + bias=False, + ) + + def forward(self, x): + return self.conv(x) + + conv2d_module = Conv2dModule() + sample_inputs = (torch.randn(size=(1, 6, 40, 50), dtype=torch.float32),) + + self.lower_module_and_test_output( + conv2d_module, + sample_inputs, + memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED], + )