diff --git a/backends/vulkan/runtime/graph/ops/impl/Pool.cpp b/backends/vulkan/runtime/graph/ops/impl/Pool.cpp index fdc17762fd8..1b2276ad7cc 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Pool.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Pool.cpp @@ -28,38 +28,23 @@ void resize_max_pool2d_node( size_t ndim = self.sizes().size(); std::vector new_out_sizes(ndim); - // Batch + // Batch, Channel if (ndim == 4) { new_out_sizes.at(ndim - 4) = self.sizes().at(ndim - 4); } - // Channel new_out_sizes.at(ndim - 3) = self.sizes().at(ndim - 3); - const auto kernel_size = reverse(*graph, extra_args[0]); - const auto stride = reverse(*graph, extra_args[1]); - const auto padding = reverse(*graph, extra_args[2]); - const auto dilation = reverse(*graph, extra_args[3]); - const bool ceil_mode = graph->get_val(extra_args[4]).toBool(); - - // Height - new_out_sizes.at(ndim - 2) = calc_out_size( - self.sizes().at(ndim - 2), - kernel_size.data[1], - stride.data[1], - padding.data[1], - dilation.data[1], - ceil_mode); - // Width - new_out_sizes.at(ndim - 1) = calc_out_size( - self.sizes().at(ndim - 1), - kernel_size.data[0], - stride.data[0], - padding.data[0], - dilation.data[0], - ceil_mode); - - VK_CHECK_COND(new_out_sizes.at(ndim - 2) >= 1); - VK_CHECK_COND(new_out_sizes.at(ndim - 1) >= 1); + // Height, Width + const auto new_out_sizes_hw = calc_out_sizes_hw( + *graph, + self.sizes(), + extra_args[0], + extra_args[1], + extra_args[2], + extra_args[3], + extra_args[4]); + new_out_sizes.at(ndim - 2) = new_out_sizes_hw.at(0); + new_out_sizes.at(ndim - 1) = new_out_sizes_hw.at(1); out.virtual_resize(new_out_sizes); indices.virtual_resize(new_out_sizes); @@ -96,12 +81,8 @@ void add_max_pool2d_node( kernel_name << "max_pool2d"; apply_dtype_suffix(kernel_name, t_out); - KernelParams kernel_params{ - reverse(graph, kernel_size), - reverse(graph, stride), - reverse(graph, padding), - reverse(graph, dilation), - }; + KernelParams kernel_params = + create_kernel_params(graph, kernel_size, stride, padding, dilation); graph.execute_nodes().emplace_back(new ExecuteNode( graph, diff --git a/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.cpp b/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.cpp index 86371d3c2d8..d1d006f39f9 100644 --- a/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.cpp @@ -10,25 +10,80 @@ namespace vkcompute { +api::utils::ivec2 make_ivec2_from_list(ComputeGraph& graph, ValueRef vref) { + return api::utils::make_ivec2( + graph.get_val(vref).toIntList(), /*reverse = */ true); +} + +KernelParams create_kernel_params( + ComputeGraph& graph, + const ValueRef kernel_size, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation) { + return { + make_ivec2_from_list(graph, kernel_size), + make_ivec2_from_list(graph, stride), + make_ivec2_from_list(graph, padding), + make_ivec2_from_list(graph, dilation), + }; +} + int64_t calc_out_size( const int64_t in_size, - const int64_t kernel, + const int64_t kernel_size, const int64_t stride, const int64_t padding, const int64_t dilation, const bool ceil_mode) { int64_t c = ceil_mode ? stride - 1 : 0; int64_t out_size = - (in_size + 2 * padding - dilation * (kernel - 1) - 1 + c) / stride + 1; + (in_size + 2 * padding - dilation * (kernel_size - 1) - 1 + c) / stride + + 1; if (ceil_mode && (out_size - 1) * stride >= in_size + padding) { --out_size; } return out_size; } -api::utils::ivec2 reverse(ComputeGraph& graph, ValueRef vref) { - return api::utils::make_ivec2( - graph.get_val(vref).toIntList(), /*reverse=*/true); +std::vector calc_out_sizes_hw( + ComputeGraph& graph, + const std::vector& in_sizes, + const ValueRef kernel_size, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation, + const ValueRef ceil_mode) { + const int64_t ndim = in_sizes.size(); + std::vector out_sizes(2); + + const auto kernel_vec = make_ivec2_from_list(graph, kernel_size); + const auto stride_vec = make_ivec2_from_list(graph, stride); + const auto padding_vec = make_ivec2_from_list(graph, padding); + const auto dilation_vec = make_ivec2_from_list(graph, dilation); + const bool ceil_mode_val = graph.get_val(ceil_mode).toBool(); + + // Height + out_sizes.at(0) = calc_out_size( + in_sizes.at(ndim - 2), + kernel_vec.data[1], + stride_vec.data[1], + padding_vec.data[1], + dilation_vec.data[1], + ceil_mode_val); + // Width + out_sizes.at(1) = calc_out_size( + in_sizes.at(ndim - 1), + kernel_vec.data[0], + stride_vec.data[0], + padding_vec.data[0], + dilation_vec.data[0], + ceil_mode_val); + + VK_CHECK_COND(out_sizes.at(0) >= 1); + VK_CHECK_COND(out_sizes.at(1) >= 1); + + return out_sizes; } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h b/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h index 6e6763dc574..b5e946e9413 100644 --- a/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h +++ b/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h @@ -23,14 +23,20 @@ struct KernelParams final { api::utils::ivec2 dilation; }; -int64_t calc_out_size( - const int64_t in_size, - const int64_t kernel_size, - const int64_t stride, - const int64_t padding, - const int64_t dilation, - const bool ceil_mode); - -api::utils::ivec2 reverse(ComputeGraph& graph, ValueRef vref); +KernelParams create_kernel_params( + ComputeGraph& graph, + const ValueRef kernel_size, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation); + +std::vector calc_out_sizes_hw( + ComputeGraph& graph, + const std::vector& in_sizes, + const ValueRef kernel_size, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation, + const ValueRef ceil_mode); } // namespace vkcompute