Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 14 additions & 33 deletions backends/vulkan/runtime/graph/ops/impl/Pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,38 +28,23 @@ void resize_max_pool2d_node(
size_t ndim = self.sizes().size();
std::vector<int64_t> new_out_sizes(ndim);

// Batch
// Batch, Channel
if (ndim == 4) {
new_out_sizes.at(ndim - 4) = self.sizes().at(ndim - 4);
}
// Channel
new_out_sizes.at(ndim - 3) = self.sizes().at(ndim - 3);

const auto kernel_size = reverse(*graph, extra_args[0]);
const auto stride = reverse(*graph, extra_args[1]);
const auto padding = reverse(*graph, extra_args[2]);
const auto dilation = reverse(*graph, extra_args[3]);
const bool ceil_mode = graph->get_val(extra_args[4]).toBool();

// Height
new_out_sizes.at(ndim - 2) = calc_out_size(
self.sizes().at(ndim - 2),
kernel_size.data[1],
stride.data[1],
padding.data[1],
dilation.data[1],
ceil_mode);
// Width
new_out_sizes.at(ndim - 1) = calc_out_size(
self.sizes().at(ndim - 1),
kernel_size.data[0],
stride.data[0],
padding.data[0],
dilation.data[0],
ceil_mode);

VK_CHECK_COND(new_out_sizes.at(ndim - 2) >= 1);
VK_CHECK_COND(new_out_sizes.at(ndim - 1) >= 1);
// Height, Width
const auto hw_sizes = calc_hw_out_sizes(
*graph,
self.sizes(),
extra_args[0],
extra_args[1],
extra_args[2],
extra_args[3],
extra_args[4]);
new_out_sizes.at(ndim - 2) = hw_sizes.at(0);
new_out_sizes.at(ndim - 1) = hw_sizes.at(1);

out.virtual_resize(new_out_sizes);
indices.virtual_resize(new_out_sizes);
Expand Down Expand Up @@ -96,12 +81,8 @@ void add_max_pool2d_node(
kernel_name << "max_pool2d";
apply_dtype_suffix(kernel_name, t_out);

KernelParams kernel_params{
reverse(graph, kernel_size),
reverse(graph, stride),
reverse(graph, padding),
reverse(graph, dilation),
};
KernelParams kernel_params =
create_kernel_params(graph, kernel_size, stride, padding, dilation);

graph.execute_nodes().emplace_back(new ExecuteNode(
graph,
Expand Down
62 changes: 59 additions & 3 deletions backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,25 @@

namespace vkcompute {

api::utils::ivec2
make_ivec2_int_list(ComputeGraph& graph, ValueRef vref, const bool reverse) {
return api::utils::make_ivec2(graph.get_val(vref).toIntList(), reverse);
}

KernelParams create_kernel_params(
ComputeGraph& graph,
const ValueRef kernel_size,
const ValueRef stride,
const ValueRef padding,
const ValueRef dilation) {
return {
make_ivec2_int_list(graph, kernel_size, /*reverse=*/true),
make_ivec2_int_list(graph, stride, /*reverse=*/true),
make_ivec2_int_list(graph, padding, /*reverse=*/true),
make_ivec2_int_list(graph, dilation, /*reverse=*/true),
};
}

int64_t calc_out_size(
const int64_t in_size,
const int64_t kernel,
Expand All @@ -26,9 +45,46 @@ int64_t calc_out_size(
return out_size;
}

api::utils::ivec2 reverse(ComputeGraph& graph, ValueRef vref) {
return api::utils::make_ivec2(
graph.get_val(vref).toIntList(), /*reverse=*/true);
std::vector<int64_t> calc_hw_out_sizes(
ComputeGraph& graph,
const std::vector<int64_t>& in_sizes,
const ValueRef kernel_size,
const ValueRef stride,
const ValueRef padding,
const ValueRef dilation,
const ValueRef ceil_mode) {
const int64_t ndim = in_sizes.size();
std::vector<int64_t> out_sizes(2);

const auto kernel_vec =
make_ivec2_int_list(graph, kernel_size, /*reverse=*/false);
const auto stride_vec = make_ivec2_int_list(graph, stride, /*reverse=*/false);
const auto padding_vec =
make_ivec2_int_list(graph, padding, /*reverse=*/false);
const auto dilation_vec =
make_ivec2_int_list(graph, dilation, /*reverse=*/false);

// Height
out_sizes.at(0) = calc_out_size(
in_sizes.at(ndim - 2),
kernel_vec.data[0],
stride_vec.data[0],
padding_vec.data[0],
dilation_vec.data[0],
ceil_mode);
// Width
out_sizes.at(1) = calc_out_size(
in_sizes.at(ndim - 1),
kernel_vec.data[1],
stride_vec.data[1],
padding_vec.data[1],
dilation_vec.data[1],
ceil_mode);

VK_CHECK_COND(out_sizes.at(0) >= 1);
VK_CHECK_COND(out_sizes.at(1) >= 1);

return out_sizes;
}

} // namespace vkcompute
17 changes: 15 additions & 2 deletions backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ struct KernelParams final {
api::utils::ivec2 dilation;
};

KernelParams create_kernel_params(
ComputeGraph& graph,
const ValueRef kernel_size,
const ValueRef stride,
const ValueRef padding,
const ValueRef dilation);

int64_t calc_out_size(
const int64_t in_size,
const int64_t kernel_size,
Expand All @@ -31,6 +38,12 @@ int64_t calc_out_size(
const int64_t dilation,
const bool ceil_mode);

api::utils::ivec2 reverse(ComputeGraph& graph, ValueRef vref);

std::vector<int64_t> calc_hw_out_sizes(
ComputeGraph& graph,
const std::vector<int64_t>& in_sizes,
const ValueRef kernel_size,
const ValueRef stride,
const ValueRef padding,
const ValueRef dilation,
const ValueRef ceil_mode);
} // namespace vkcompute