From 4336380fd000aace78781808d3a891304e4e7f7b Mon Sep 17 00:00:00 2001 From: Jorge Pineda Date: Wed, 3 Apr 2024 12:13:27 -0700 Subject: [PATCH] [ET-VK] Refactor Pool.cpp This change adds more lines than it subtracts, but it'll be worth it once we reuse the methods for `aten.convolution`. Differential Revision: [D55706057](https://our.internmc.facebook.com/intern/diff/D55706057/) [ghstack-poisoned] --- .../vulkan/runtime/graph/ops/impl/Pool.cpp | 47 +++++--------- .../graph/ops/impl/utils/KernelUtils.cpp | 62 ++++++++++++++++++- .../graph/ops/impl/utils/KernelUtils.h | 17 ++++- 3 files changed, 88 insertions(+), 38 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/impl/Pool.cpp b/backends/vulkan/runtime/graph/ops/impl/Pool.cpp index fdc17762fd8..2932080092a 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Pool.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Pool.cpp @@ -28,38 +28,23 @@ void resize_max_pool2d_node( size_t ndim = self.sizes().size(); std::vector new_out_sizes(ndim); - // Batch + // Batch, Channel if (ndim == 4) { new_out_sizes.at(ndim - 4) = self.sizes().at(ndim - 4); } - // Channel new_out_sizes.at(ndim - 3) = self.sizes().at(ndim - 3); - const auto kernel_size = reverse(*graph, extra_args[0]); - const auto stride = reverse(*graph, extra_args[1]); - const auto padding = reverse(*graph, extra_args[2]); - const auto dilation = reverse(*graph, extra_args[3]); - const bool ceil_mode = graph->get_val(extra_args[4]).toBool(); - - // Height - new_out_sizes.at(ndim - 2) = calc_out_size( - self.sizes().at(ndim - 2), - kernel_size.data[1], - stride.data[1], - padding.data[1], - dilation.data[1], - ceil_mode); - // Width - new_out_sizes.at(ndim - 1) = calc_out_size( - self.sizes().at(ndim - 1), - kernel_size.data[0], - stride.data[0], - padding.data[0], - dilation.data[0], - ceil_mode); - - VK_CHECK_COND(new_out_sizes.at(ndim - 2) >= 1); - VK_CHECK_COND(new_out_sizes.at(ndim - 1) >= 1); + // Height, Width + const auto hw_sizes = calc_hw_out_sizes( + *graph, + self.sizes(), + extra_args[0], + extra_args[1], + extra_args[2], + extra_args[3], + extra_args[4]); + new_out_sizes.at(ndim - 2) = hw_sizes.at(0); + new_out_sizes.at(ndim - 1) = hw_sizes.at(1); out.virtual_resize(new_out_sizes); indices.virtual_resize(new_out_sizes); @@ -96,12 +81,8 @@ void add_max_pool2d_node( kernel_name << "max_pool2d"; apply_dtype_suffix(kernel_name, t_out); - KernelParams kernel_params{ - reverse(graph, kernel_size), - reverse(graph, stride), - reverse(graph, padding), - reverse(graph, dilation), - }; + KernelParams kernel_params = + create_kernel_params(graph, kernel_size, stride, padding, dilation); graph.execute_nodes().emplace_back(new ExecuteNode( graph, diff --git a/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.cpp b/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.cpp index 86371d3c2d8..e8a507abb5a 100644 --- a/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.cpp @@ -10,6 +10,25 @@ namespace vkcompute { +api::utils::ivec2 +make_ivec2_int_list(ComputeGraph& graph, ValueRef vref, const bool reverse) { + return api::utils::make_ivec2(graph.get_val(vref).toIntList(), reverse); +} + +KernelParams create_kernel_params( + ComputeGraph& graph, + const ValueRef kernel_size, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation) { + return { + make_ivec2_int_list(graph, kernel_size, /*reverse=*/true), + make_ivec2_int_list(graph, stride, /*reverse=*/true), + make_ivec2_int_list(graph, padding, /*reverse=*/true), + make_ivec2_int_list(graph, dilation, /*reverse=*/true), + }; +} + int64_t calc_out_size( const int64_t in_size, const int64_t kernel, @@ -26,9 +45,46 @@ int64_t calc_out_size( return out_size; } -api::utils::ivec2 reverse(ComputeGraph& graph, ValueRef vref) { - return api::utils::make_ivec2( - graph.get_val(vref).toIntList(), /*reverse=*/true); +std::vector calc_hw_out_sizes( + ComputeGraph& graph, + const std::vector& in_sizes, + const ValueRef kernel_size, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation, + const ValueRef ceil_mode) { + const int64_t ndim = in_sizes.size(); + std::vector out_sizes(2); + + const auto kernel_vec = + make_ivec2_int_list(graph, kernel_size, /*reverse=*/false); + const auto stride_vec = make_ivec2_int_list(graph, stride, /*reverse=*/false); + const auto padding_vec = + make_ivec2_int_list(graph, padding, /*reverse=*/false); + const auto dilation_vec = + make_ivec2_int_list(graph, dilation, /*reverse=*/false); + + // Height + out_sizes.at(0) = calc_out_size( + in_sizes.at(ndim - 2), + kernel_vec.data[0], + stride_vec.data[0], + padding_vec.data[0], + dilation_vec.data[0], + ceil_mode); + // Width + out_sizes.at(1) = calc_out_size( + in_sizes.at(ndim - 1), + kernel_vec.data[1], + stride_vec.data[1], + padding_vec.data[1], + dilation_vec.data[1], + ceil_mode); + + VK_CHECK_COND(out_sizes.at(0) >= 1); + VK_CHECK_COND(out_sizes.at(1) >= 1); + + return out_sizes; } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h b/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h index 6e6763dc574..799c1551a85 100644 --- a/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h +++ b/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h @@ -23,6 +23,13 @@ struct KernelParams final { api::utils::ivec2 dilation; }; +KernelParams create_kernel_params( + ComputeGraph& graph, + const ValueRef kernel_size, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation); + int64_t calc_out_size( const int64_t in_size, const int64_t kernel_size, @@ -31,6 +38,12 @@ int64_t calc_out_size( const int64_t dilation, const bool ceil_mode); -api::utils::ivec2 reverse(ComputeGraph& graph, ValueRef vref); - +std::vector calc_hw_out_sizes( + ComputeGraph& graph, + const std::vector& in_sizes, + const ValueRef kernel_size, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation, + const ValueRef ceil_mode); } // namespace vkcompute