Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GNNE-1904 Support coordinate_mode and nearest_mode selection for nearest resize_image #1065

Merged
merged 13 commits into from
Aug 28, 2023
Merged
1 change: 1 addition & 0 deletions src/Native/include/nncase/kernels/kernel_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <cassert>
#include <cmath>
#include <cstddef>
#include <nncase/kernels/stackvm/resize_image.h>
#include <nncase/runtime/datatypes.h>
#include <numeric>

Expand Down
94 changes: 94 additions & 0 deletions src/Native/include/nncase/kernels/stackvm/resize_image.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/* Copyright 2019-2023 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once
#include <nncase/runtime/stackvm/opcode.h>

using namespace nncase::runtime::stackvm;

using get_coordinate_func_t = float (*)(float, float, float, float, float,
float);
using get_nearest_pixel_func_t = int64_t (*)(float);

get_coordinate_func_t get_coordinate_from_resized(
image_resize_transformation_mode_t coordinate_transform_mode);

get_nearest_pixel_func_t
get_nearest_pixel_from_origin(image_resize_nearest_mode_t nearest_mode);

inline get_coordinate_func_t get_coordinate_from_resized(
image_resize_transformation_mode_t coordinate_transform_mode) {
switch (coordinate_transform_mode) {
case image_resize_transformation_mode_t::asymmetric:
return [](float x_resized, float x_scale, float, float, float, float) {
return x_resized * x_scale;
};
case image_resize_transformation_mode_t::pytorch_half_pixel:
return [](float x_resized, float x_scale, float length_resized, float,
float, float) {
return length_resized > 1 ? (x_resized + 0.5f) * x_scale - 0.5f
: 0.0f;
};
case image_resize_transformation_mode_t::align_corners:
return [](float x_resized, float, float length_resized,
float length_original, float, float) {
return length_resized == 1 ? 0
: x_resized * (length_original - 1) /
(length_resized - 1);
};
case image_resize_transformation_mode_t::tfcrop_and_resize:
return [](float x_resized, float, float length_resized,
float length_original, float roi_start, float roi_end) {
auto orig =
length_resized > 1
? roi_start * (length_original - 1) +
(x_resized * (roi_end - roi_start) *
(length_original - 1)) /
(length_resized - 1)
: 0.5 * (roi_start + roi_end) * (length_original - 1);
return static_cast<float>(orig);
};
default: // "image_resize_transformation_mode_t::half_pixel"
return [](float x_resized, float x_scale, float, float, float, float) {
return ((x_resized + 0.5f) * x_scale) - 0.5f;
};
}
}

inline get_nearest_pixel_func_t
get_nearest_pixel_from_origin(image_resize_nearest_mode_t nearest_mode) {
switch (nearest_mode) {
case image_resize_nearest_mode_t::round_prefer_ceil:
return [](float x_original) {
return static_cast<int64_t>(std::round(x_original));
};
case image_resize_nearest_mode_t::floor:
return [](float x_original) {
return static_cast<int64_t>(std::floor(x_original));
};
case image_resize_nearest_mode_t::ceil:
return [](float x_original) {
return static_cast<int64_t>(std::ceil(x_original));
};
default: // default is round_prefer_floor
return [](float x_original) {
// for half way cases prefer floor
if (x_original == static_cast<int64_t>(x_original) + 0.5f) {
return static_cast<int64_t>(std::floor(x_original));
}
return static_cast<int64_t>(std::round(x_original));
};
}
}
4 changes: 3 additions & 1 deletion src/Native/src/kernels/stackvm/optimized/opt_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
*/
#pragma once
#include <nncase/kernels/kernel_context.h>
#include <nncase/kernels/kernel_utils.h>
#include <nncase/runtime/datatypes.h>
#include <nncase/runtime/error.h>
#include <nncase/runtime/result.h>
#include <nncase/runtime/stackvm/opcode.h>
#include <nncase/tensor.h>
#include <nncase/value.h>

BEGIN_NS_NNCASE_KERNELS_MODULE(stackvm)
namespace optimized {

Expand Down Expand Up @@ -111,6 +111,8 @@ NNCASE_API result<void> resize_nearest_neighbor(
gsl::span<const size_t> in_shape, gsl::span<const size_t> in_strides,
gsl::span<const size_t> out_strides, int32_t out_h, int32_t out_w,
bool align_corners, bool half_pixel_centers,
get_coordinate_func_t get_coordinate_func,
get_nearest_pixel_func_t get_nearset_func,
kernel_context &context) noexcept;

NNCASE_API result<void>
Expand Down
36 changes: 25 additions & 11 deletions src/Native/src/kernels/stackvm/optimized/resize_image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,10 @@ result<void> resize_nearest_neighbor_impl(
const T *input, T *output, gsl::span<const size_t> in_shape,
NNCASE_UNUSED gsl::span<const size_t> in_strides,
NNCASE_UNUSED gsl::span<const size_t> out_strides, int32_t out_h,
int32_t out_w, bool align_corners, bool half_pixel_centers,
int32_t out_w, NNCASE_UNUSED bool align_corners,
NNCASE_UNUSED bool half_pixel_centers,
get_coordinate_func_t get_coordinate_func,
get_nearest_pixel_func_t get_nearset_func,
NNCASE_UNUSED kernel_context &context) noexcept {
auto scales = kernels::detail::get_resize_scales(in_shape, out_h, out_w,
align_corners);
Expand All @@ -110,15 +113,23 @@ result<void> resize_nearest_neighbor_impl(
auto *output_ptr = begin_output_ptr + oc * out_image_size;

for (int oy = 0; oy < out_h; oy++) {
auto in_y = kernels::detail::get_nearest_neighbor(
oy, in_shape[2], height_scale, align_corners,
half_pixel_centers);
auto iy = get_coordinate_func(oy, height_scale, out_h,
in_shape[2], 0, 0);
int64_t in_y = get_nearset_func(iy);
if (in_y < 0)
in_y = 0;
if (in_y >= in_shape[2])
in_y = in_shape[2] - 1;
auto *in_row = input_ptr + in_y * in_shape[3];

for (int ox = 0; ox < out_w; ox++) {
auto in_x = kernels::detail::get_nearest_neighbor(
ox, in_shape[3], width_scale, align_corners,
half_pixel_centers);
auto ix = get_coordinate_func(ox, width_scale, out_w,
in_shape[3], 0, 0);
int64_t in_x = get_nearset_func(ix);
if (in_x < 0)
in_x = 0;
if (in_x >= in_shape[3])
in_x = in_shape[3] - 1;
*output_ptr++ = in_row[in_x];
}
}
Expand Down Expand Up @@ -264,10 +275,11 @@ inline result<void> resize_bilinear_impl(
half_pixel_centers, context);

#define RESIZE_NEAREST_NEIGHBOR_IMPL(type) \
resize_nearest_neighbor_impl(reinterpret_cast<const type *>(input), \
reinterpret_cast<type *>(output), in_shape, \
in_strides, out_strides, out_h, out_w, \
align_corners, half_pixel_centers, context);
resize_nearest_neighbor_impl( \
reinterpret_cast<const type *>(input), \
reinterpret_cast<type *>(output), in_shape, in_strides, out_strides, \
out_h, out_w, align_corners, half_pixel_centers, get_coordinate_func, \
get_nearset_func, context);

result<void> optimized::resize_bilinear(
typecode_t type, const gsl::byte *input, gsl::byte *output,
Expand All @@ -283,6 +295,8 @@ result<void> optimized::resize_nearest_neighbor(
gsl::span<const size_t> in_shape, gsl::span<const size_t> in_strides,
gsl::span<const size_t> out_strides, int32_t out_h, int32_t out_w,
bool align_corners, bool half_pixel_centers,
get_coordinate_func_t get_coordinate_func,
get_nearest_pixel_func_t get_nearset_func,
kernel_context &context) noexcept {
FP_OR_Q_IMPL(type, RESIZE_NEAREST_NEIGHBOR_IMPL);
}
3 changes: 3 additions & 0 deletions src/Native/src/kernels/stackvm/reference/ref_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#pragma once
#include <nncase/kernels/apply.h>
#include <nncase/kernels/kernel_context.h>
#include <nncase/kernels/kernel_utils.h>
#include <nncase/runtime/datatypes.h>
#include <nncase/runtime/error.h>
#include <nncase/runtime/result.h>
Expand Down Expand Up @@ -345,6 +346,8 @@ NNCASE_API result<void> resize_nearest_neighbor(
gsl::span<const size_t> in_shape, gsl::span<const size_t> in_strides,
gsl::span<const size_t> out_strides, int32_t out_h, int32_t out_w,
bool align_corners, bool half_pixel_centers,
get_coordinate_func_t get_coordinate_func,
get_nearest_pixel_func_t get_nearset_func,
kernel_context &context) noexcept;

NNCASE_API result<void> reverse_sequence(
Expand Down
36 changes: 25 additions & 11 deletions src/Native/src/kernels/stackvm/reference/resize_image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,10 @@ template <class T>
result<void> resize_nearest_neighbor_impl(
const T *input, T *output, gsl::span<const size_t> in_shape,
gsl::span<const size_t> in_strides, gsl::span<const size_t> out_strides,
int32_t out_h, int32_t out_w, bool align_corners, bool half_pixel_centers,
int32_t out_h, int32_t out_w, NNCASE_UNUSED bool align_corners,
NNCASE_UNUSED bool half_pixel_centers,
get_coordinate_func_t get_coordinate_func,
get_nearest_pixel_func_t get_nearset_func,
NNCASE_UNUSED kernel_context &context) noexcept {
auto scales = kernels::detail::get_resize_scales(in_shape, out_h, out_w,
align_corners);
Expand All @@ -106,16 +109,24 @@ result<void> resize_nearest_neighbor_impl(
in_index[1] = oc;
out_index[1] = oc;
for (size_t oy = 0; oy < (size_t)out_h; oy++) {
auto in_y = kernels::detail::get_nearest_neighbor(
oy, in_shape[2], height_scale, align_corners,
half_pixel_centers);
auto iy = get_coordinate_func(oy, height_scale, out_h,
in_shape[2], 0, 0);
int64_t in_y = get_nearset_func(iy);
if (in_y < 0)
in_y = 0;
if (in_y >= in_shape[2])
in_y = in_shape[2] - 1;
in_index[2] = in_y;
out_index[2] = oy;

for (size_t ox = 0; ox < (size_t)out_w; ox++) {
auto in_x = kernels::detail::get_nearest_neighbor(
ox, in_shape[3], width_scale, align_corners,
half_pixel_centers);
auto ix = get_coordinate_func(ox, width_scale, out_w,
in_shape[3], 0, 0);
int64_t in_x = get_nearset_func(ix);
if (in_x < 0)
in_x = 0;
if (in_x >= in_shape[3])
in_x = in_shape[3] - 1;
in_index[3] = in_x;
out_index[3] = ox;
output[offset(out_strides, out_index)] =
Expand Down Expand Up @@ -154,10 +165,11 @@ result<void> resize_nearest_neighbor_impl(
half_pixel_centers, context);

#define RESIZE_NEAREST_NEIGHBOR_IMPL(type) \
resize_nearest_neighbor_impl(reinterpret_cast<const type *>(input), \
reinterpret_cast<type *>(output), in_shape, \
in_strides, out_strides, out_h, out_w, \
align_corners, half_pixel_centers, context);
resize_nearest_neighbor_impl( \
reinterpret_cast<const type *>(input), \
reinterpret_cast<type *>(output), in_shape, in_strides, out_strides, \
out_h, out_w, align_corners, half_pixel_centers, get_coordinate_func, \
get_nearset_func, context);
} // namespace

result<void> nncase::kernels::stackvm::reference::resize_bilinear(
Expand All @@ -174,6 +186,8 @@ result<void> nncase::kernels::stackvm::reference::resize_nearest_neighbor(
gsl::span<const size_t> in_shape, gsl::span<const size_t> in_strides,
gsl::span<const size_t> out_strides, int32_t out_h, int32_t out_w,
bool align_corners, bool half_pixel_centers,
get_coordinate_func_t get_coordinate_func,
get_nearest_pixel_func_t get_nearset_func,
kernel_context &context) noexcept {
FP_OR_Q_IMPL(type, RESIZE_NEAREST_NEIGHBOR_IMPL);
}
6 changes: 5 additions & 1 deletion src/Native/src/kernels/stackvm/tensor_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -841,11 +841,15 @@ result<value_t> nncase::kernels::stackvm::resize_image(
new_size_value[2], new_size_value[3], align_corner,
half_pixel, context);
} else if (resize_mode == image_resize_mode_t::nearest_neighbor) {
auto get_coordinate_func =
get_coordinate_from_resized(transformation_mode);
auto get_nearset_func = get_nearest_pixel_from_origin(nearest_mode);
CONTIGUOUS_KERNEL(resize_nearest_neighbor, input_tensor, tycode, in_mem,
out_mem, input_tensor->shape(),
input_tensor->strides(), output_tensor->strides(),
new_size_value[2], new_size_value[3], align_corner,
half_pixel, context);
half_pixel, get_coordinate_func, get_nearset_func,
context);
} else {
return err(nncase_errc::runtime_not_found);
}
Expand Down
Loading