diff --git a/kernels/aten/functions.yaml b/kernels/aten/functions.yaml index 0024c676efa..44676f2c236 100644 --- a/kernels/aten/functions.yaml +++ b/kernels/aten/functions.yaml @@ -407,8 +407,6 @@ - op: upsample_bilinear2d.vec_out -- op: upsample_nearest2d.out - - op: upsample_nearest2d.vec_out - op: var.correction_out diff --git a/kernels/portable/cpu/op_upsample_nearest2d.cpp b/kernels/portable/cpu/op_upsample_nearest2d.cpp new file mode 100644 index 00000000000..93a88588d83 --- /dev/null +++ b/kernels/portable/cpu/op_upsample_nearest2d.cpp @@ -0,0 +1,93 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace torch { +namespace executor { +namespace native { + +using exec_aten::ArrayRef; +using exec_aten::optional; +using exec_aten::SizesType; + +namespace { +template +void upsample_nearest2d_kernel_impl( + const Tensor& in, + const float scale_h, + const float scale_w, + Tensor& out) { + const auto in_data = in.const_data_ptr(); + auto out_data = out.mutable_data_ptr(); + + auto in_plane = in_data; + for (auto n = 0; n < out.size(0); n++) { + for (auto c = 0; c < out.size(1); c++) { + for (auto h = 0; h < out.size(2); h++) { + for (auto w = 0; w < out.size(3); w++) { + const auto in_h = + nearest_neighbor_compute_source_index(scale_h, h, in.sizes()[2]); + const auto in_w = + nearest_neighbor_compute_source_index(scale_w, w, in.sizes()[3]); + + *out_data = in_plane[in_h * in.strides()[2] + in_w * in.strides()[3]]; + out_data++; + } + } + + in_plane += in.strides()[1]; + } + } +} +} // namespace + +Tensor& upsample_nearest2d_vec_out( + KernelRuntimeContext& ctx, + const Tensor& in, + const exec_aten::OptionalArrayRef output_size, + const exec_aten::OptionalArrayRef scale_factors, + Tensor& out) { + // Preconditions (checked in check_..._args): + // In and out tensors have same dtype. + // In and out tensors are rank 4 and have same dim[0] and dim[1]. + // In and out tensors are default dim order (NCHW). + ET_KERNEL_CHECK( + ctx, + check_upsample_nearest2d_args(in, output_size, scale_factors, out), + InvalidArgument, + out); + + double scale_h, scale_w; + + ET_KERNEL_CHECK_MSG( + ctx, + resize_upsample_2d( + in, output_size, scale_factors, scale_h, scale_w, out) == Error::Ok, + InvalidArgument, + out, + "Failed to resize output tensor"); + + const auto kernel_scale_h = area_pixel_compute_scale( + in.sizes()[2], out.sizes()[2], false, scale_h); + const auto kernel_scale_w = area_pixel_compute_scale( + in.sizes()[3], out.sizes()[3], false, scale_w); + + ET_SWITCH_REAL_TYPES( + in.scalar_type(), ctx, "upsample_nearest2d.out", CTYPE, [&]() { + upsample_nearest2d_kernel_impl( + in, kernel_scale_h, kernel_scale_w, out); + }); + + return out; +} + +} // namespace native +} // namespace executor +} // namespace torch diff --git a/kernels/portable/cpu/util/upsample_util.cpp b/kernels/portable/cpu/util/upsample_util.cpp index 2082b632d98..15cf28b663d 100644 --- a/kernels/portable/cpu/util/upsample_util.cpp +++ b/kernels/portable/cpu/util/upsample_util.cpp @@ -46,6 +46,14 @@ bool check_upsample_bilinear2d_args( return check_upsample_2d_common_args(in, output_size, scale_factors, out); } +bool check_upsample_nearest2d_args( + const Tensor& in, + const exec_aten::OptionalArrayRef& output_size, + const exec_aten::OptionalArrayRef& scale_factors, + Tensor& out) { + return check_upsample_2d_common_args(in, output_size, scale_factors, out); +} + Error resize_upsample_2d( const Tensor& in, const exec_aten::OptionalArrayRef& output_size, diff --git a/kernels/portable/cpu/util/upsample_util.h b/kernels/portable/cpu/util/upsample_util.h index 86ae8e77ecf..785b6f1cc87 100644 --- a/kernels/portable/cpu/util/upsample_util.h +++ b/kernels/portable/cpu/util/upsample_util.h @@ -28,6 +28,12 @@ bool check_upsample_bilinear2d_args( const exec_aten::OptionalArrayRef& scale_factors, Tensor& out); +bool check_upsample_nearest2d_args( + const Tensor& in, + const exec_aten::OptionalArrayRef& output_size, + const exec_aten::OptionalArrayRef& scale_factors, + Tensor& out); + Error resize_upsample_2d( const Tensor& in, const exec_aten::OptionalArrayRef& output_size, @@ -127,5 +133,17 @@ inline void compute_source_index_and_lambda( } } +// Ported from aten/src/ATen/native/UpSample.h +inline int64_t nearest_neighbor_compute_source_index( + const float scale, + int64_t dst_index, + int64_t input_size) { + // Index computation matching OpenCV INTER_NEAREST + // which is buggy and kept for BC + const int64_t src_index = + std::min(static_cast(floorf(dst_index * scale)), input_size - 1); + return src_index; +} + } // namespace executor } // namespace torch diff --git a/kernels/portable/functions.yaml b/kernels/portable/functions.yaml index ed203f9487b..96382eb497b 100644 --- a/kernels/portable/functions.yaml +++ b/kernels/portable/functions.yaml @@ -922,6 +922,11 @@ - arg_meta: null kernel_name: torch::executor::upsample_bilinear2d_vec_out +- op: upsample_nearest2d.vec_out + kernels: + - arg_meta: null + kernel_name: torch::executor::upsample_nearest2d_vec_out + - op: var.correction_out kernels: - arg_meta: null diff --git a/kernels/portable/test/TARGETS b/kernels/portable/test/TARGETS index 6255c37cbc8..adf6636be4f 100644 --- a/kernels/portable/test/TARGETS +++ b/kernels/portable/test/TARGETS @@ -21,6 +21,7 @@ runtime.cxx_library( deps = [ "//executorch/extension/aten_util:aten_bridge", "//executorch/kernels/portable/cpu:op_upsample_bilinear2d", + "//executorch/kernels/portable/cpu:op_upsample_nearest2d", "//executorch/runtime/core/exec_aten:lib", ], external_deps = [ @@ -40,3 +41,16 @@ python_unittest( "//caffe2:torch", ], ) + +python_unittest( + name = "op_upsample_nearest2d_test", + srcs = [ + "op_upsample_nearest2d_test.py", + ], + preload_deps = [ + ":aot_ops_test_lib", + ], + deps = [ + "//caffe2:torch", + ], +) diff --git a/kernels/portable/test/op_upsample_nearest2d_test.py b/kernels/portable/test/op_upsample_nearest2d_test.py new file mode 100644 index 00000000000..23a6bb54fbb --- /dev/null +++ b/kernels/portable/test/op_upsample_nearest2d_test.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import itertools +import unittest + +from typing import Optional, Sequence + +import torch + + +class UpsampleNearest2dTest(unittest.TestCase): + def run_upsample_test( + self, + inp: torch.Tensor, + output_size: Optional[Sequence[int]] = None, + scale_factors: Optional[Sequence[float]] = None, + atol=1e-7, + ) -> None: + aten_result = torch.nn.functional.interpolate( + inp, + size=output_size, + mode="nearest", + scale_factor=scale_factors, + ) + et_result = torch.zeros_like(aten_result) + et_result = torch.ops.et_test.upsample_nearest2d( + inp, + output_size=output_size, + scale_factors=scale_factors, + out=et_result, + ) + self.assertTrue( + torch.allclose(et_result, aten_result, atol=atol), + msg=f"ET: {et_result} \n ATen: {aten_result} \n Error: {et_result.to(torch.float) - aten_result.to(torch.float)}", + ) + + def test_upsample_nearest2d_aten_parity_f32(self): + N = [1, 2] + C = [1, 3] + H = [1, 3, 50, 1001] + W = [1, 2, 62, 1237] + OUT_H = [5, 21] + OUT_W = [7, 31] + + for n, c, h, w, out_h, out_w in itertools.product(N, C, H, W, OUT_H, OUT_W): + input = torch.randn(n, c, h, w) + self.run_upsample_test(input, output_size=(out_h, out_w)) + self.run_upsample_test(input, scale_factors=(out_h / h, out_w / w)) + + def test_upsample_nearest2d_aten_parity_u8(self): + N = [1, 2] + C = [1, 3] + H = [1, 3, 50, 1001] + W = [1, 2, 62, 1237] + OUT_H = [5, 21] + OUT_W = [7, 31] + + for n, c, h, w, out_h, out_w in itertools.product(N, C, H, W, OUT_H, OUT_W): + input = torch.randint(0, 255, (n, c, h, w), dtype=torch.uint8) + self.run_upsample_test(input, output_size=(out_h, out_w), atol=1) + self.run_upsample_test( + input, + scale_factors=(out_h / h, out_w / w), + atol=2, + ) diff --git a/kernels/portable/test/register_ops_aot_for_test.cpp b/kernels/portable/test/register_ops_aot_for_test.cpp index c8b63e736a0..40d4bb79ed1 100644 --- a/kernels/portable/test/register_ops_aot_for_test.cpp +++ b/kernels/portable/test/register_ops_aot_for_test.cpp @@ -47,6 +47,31 @@ Tensor& upsample_bilinear2d_vec_out_no_context( return ret; } + +Tensor& upsample_nearest2d_vec_out( + KernelRuntimeContext& ctx, + const Tensor& in, + const exec_aten::OptionalArrayRef output_size, + const exec_aten::OptionalArrayRef scale_factors, + Tensor& out); + +Tensor& upsample_nearest2d_vec_out_no_context( + const Tensor& in, + const exec_aten::OptionalArrayRef output_size, + const exec_aten::OptionalArrayRef scale_factors, + Tensor& out) { + KernelRuntimeContext ctx; + auto& ret = + upsample_nearest2d_vec_out(ctx, in, output_size, scale_factors, out); + + if (ctx.failure_state() != Error::Ok) { + throw std::runtime_error( + std::string("Kernel failed with error: ") + + std::to_string((int)ctx.failure_state())); + } + + return ret; +} // NOLINTEND(facebook-hte-ConstantArgumentPassByValue, // facebook-hte-ParameterMightThrowOnCopy) @@ -54,6 +79,9 @@ TORCH_LIBRARY(et_test, m) { m.def( "upsample_bilinear2d.vec_out(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors, *, Tensor(a!) out) -> Tensor(a!)", WRAP_TO_ATEN(upsample_bilinear2d_vec_out_no_context, 4)); + m.def( + "upsample_nearest2d.vec_out(Tensor input, SymInt[]? output_size, float[]? scale_factors, *, Tensor(a!) out) -> Tensor(a!)", + WRAP_TO_ATEN(upsample_nearest2d_vec_out_no_context, 3)); } } // namespace native diff --git a/kernels/test/op_upsample_nearest2d_test.cpp b/kernels/test/op_upsample_nearest2d_test.cpp new file mode 100644 index 00000000000..12301688002 --- /dev/null +++ b/kernels/test/op_upsample_nearest2d_test.cpp @@ -0,0 +1,408 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include // Declares the operator +#include +#include +#include +#include +#include + +#include + +#include + +using exec_aten::optional; +using exec_aten::ScalarType; +using exec_aten::Tensor; +using torch::executor::testing::SupportedFeatures; +using torch::executor::testing::TensorFactory; + +#ifdef USE_ATEN_LIB +template +using OptionalArrayRef = std::optional>; +#else +using exec_aten::OptionalArrayRef; +#endif + +class OpUpsampleNearest2dTest : public OperatorTest { + protected: + Tensor& op_upsample_nearest2d_out( + const Tensor& in, + const OptionalArrayRef output_size, + const OptionalArrayRef scale_factors, + Tensor& out) { + return torch::executor::aten::upsample_nearest2d_outf( + context_, in, output_size, scale_factors, out); + } + + template + void test_upsample_nearest2d_dtype() { + TensorFactory tf; + + const auto input = tf.make({1, 1, 2, 2}, {1, 2, 3, 4}); + std::array output_size = {4, 4}; + auto out = tf.zeros({1, 1, 4, 4}); + + op_upsample_nearest2d_out( + input, + OptionalArrayRef({output_size.data(), output_size.size()}), + true, + {}, + out); + + const auto expected = + tf.make({1, 1, 4, 4}, {1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4}); + + EXPECT_TENSOR_CLOSE(out, expected); + } +}; + +TEST_F(OpUpsampleNearest2dTest, SmokeTest) { + TensorFactory tf; + + const auto input = tf.make( + {1, 1, 2, 2}, + { + 0.1, + 0.2, + 1.1, + 1.2, + }); + std::array output_size = {4, 4}; + auto out = tf.zeros({1, 1, 4, 4}); + + op_upsample_nearest2d_out( + input, + OptionalArrayRef({output_size.data(), output_size.size()}), + {}, + out); + + const auto expected = tf.make( + {1, 1, 4, 4}, + { + 0.1, + 0.1, + 0.2, + 0.2, + 0.1, + 0.1, + 0.2, + 0.2, + 1.1, + 1.1, + 1.2, + 1.2, + 1.1, + 1.1, + 1.2, + 1.2, + }); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST_F(OpUpsampleNearest2dTest, SmokeTestScale) { + TensorFactory tf; + + const auto input = tf.make( + {1, 1, 2, 2}, + { + 0.1, + 0.2, + 1.1, + 1.2, + }); + auto out = tf.zeros({1, 1, 4, 4}); + std::array scale_factors = {2, 2}; + + op_upsample_nearest2d_out( + input, + {}, + OptionalArrayRef({scale_factors.data(), scale_factors.size()}), + out); + + const auto expected = tf.make( + {1, 1, 4, 4}, + { + 0.1, + 0.1, + 0.2, + 0.2, + 0.1, + 0.1, + 0.2, + 0.2, + 1.1, + 1.1, + 1.2, + 1.2, + 1.1, + 1.1, + 1.2, + 1.2, + }); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST_F(OpUpsampleNearest2dTest, UpsampleSimpleFractional) { + TensorFactory tf; + + const auto input = tf.make( + {1, 1, 2, 2}, + { + 0.1, + 0.2, + 1.1, + 1.2, + }); + std::array output_size = {5, 9}; + auto out = tf.zeros({1, 1, 5, 9}); + + op_upsample_nearest2d_out( + input, + OptionalArrayRef({output_size.data(), output_size.size()}), + {}, + out); + + const auto expected = tf.make( + {1, 1, 5, 9}, {0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1, + 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, + 0.2, 0.2, 0.2, 1.1, 1.1, 1.1, 1.1, 1.1, 1.2, 1.2, 1.2, 1.2, + 1.1, 1.1, 1.1, 1.1, 1.1, 1.2, 1.2, 1.2, 1.2}); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST_F(OpUpsampleNearest2dTest, UpsampleSimpleFractionalScale) { + TensorFactory tf; + + const auto input = tf.make( + {1, 1, 2, 2}, + { + 0.1, + 0.2, + 1.1, + 1.2, + }); + auto out = tf.zeros({1, 1, 5, 9}); + std::array scale_factors = {5 / 2.0, 9 / 2.0}; + + op_upsample_nearest2d_out( + input, + {}, + OptionalArrayRef({scale_factors.data(), scale_factors.size()}), + out); + + const auto expected = tf.make( + {1, 1, 5, 9}, {0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1, + 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, + 0.2, 0.2, 0.2, 1.1, 1.1, 1.1, 1.1, 1.1, 1.2, 1.2, 1.2, 1.2, + 1.1, 1.1, 1.1, 1.1, 1.1, 1.2, 1.2, 1.2, 1.2}); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST_F(OpUpsampleNearest2dTest, MultiBatchAndChannel) { + TensorFactory tf; + const auto input = tf.make( + {2, 2, 2, 2}, + { + 0.1, + 0.2, + 1.1, + 1.2, + 2.1, + 2.2, + 3.1, + 3.2, + 4.1, + 4.2, + 5.1, + 5.2, + 6.1, + 6.2, + 7.1, + 7.2, + }); + std::array output_size = {4, 4}; + auto out = tf.zeros({2, 2, 4, 4}); + + op_upsample_nearest2d_out( + input, + OptionalArrayRef({output_size.data(), output_size.size()}), + {}, + out); + + const auto expected = tf.make( + {2, 2, 4, 4}, + { + 0.1, 0.1, 0.2, 0.2, 0.1, 0.1, 0.2, 0.2, 1.1, 1.1, 1.2, 1.2, 1.1, + 1.1, 1.2, 1.2, 2.1, 2.1, 2.2, 2.2, 2.1, 2.1, 2.2, 2.2, 3.1, 3.1, + 3.2, 3.2, 3.1, 3.1, 3.2, 3.2, 4.1, 4.1, 4.2, 4.2, 4.1, 4.1, 4.2, + 4.2, 5.1, 5.1, 5.2, 5.2, 5.1, 5.1, 5.2, 5.2, 6.1, 6.1, 6.2, 6.2, + 6.1, 6.1, 6.2, 6.2, 7.1, 7.1, 7.2, 7.2, 7.1, 7.1, 7.2, 7.2, + }); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST_F(OpUpsampleNearest2dTest, DType) { +#define TEST_ENTRY(ctype, dtype) \ + test_upsample_nearest2d_dtype(); \ + ET_FORALL_REAL_TYPES(TEST_ENTRY); +#undef TEST_ENTRY +} + +TEST_F(OpUpsampleNearest2dTest, MismatchedOutputSizeDies) { + if (SupportedFeatures::get()->output_resize) { + GTEST_SKIP() + << "The current kernel supports implicitly resizing output tensor"; + } + TensorFactory tf; + + const auto input = tf.ones({1, 1, 1, 2}); + std::array output_size = {1, 4}; + auto out = tf.zeros({1, 1, 1, 5}); + + ET_EXPECT_KERNEL_FAILURE( + context_, + op_upsample_nearest2d_out( + input, + OptionalArrayRef({output_size.data(), output_size.size()}), + {}, + out)); +} + +TEST_F(OpUpsampleNearest2dTest, InvalidInputRankDies) { + TensorFactory tf; + + const auto input = tf.ones({1, 1, 2}); + std::array output_size = {1, 4}; + auto out = tf.zeros({1, 1, 1, 4}); + + ET_EXPECT_KERNEL_FAILURE( + context_, + op_upsample_nearest2d_out( + input, + OptionalArrayRef({output_size.data(), output_size.size()}), + {}, + out)); +} + +TEST_F(OpUpsampleNearest2dTest, InvalidOutputRankDies) { + TensorFactory tf; + + const auto input = tf.ones({1, 1, 2}); + std::array output_size = {1, 4}; + auto out = tf.zeros({1, 1, 4}); + + ET_EXPECT_KERNEL_FAILURE( + context_, + op_upsample_nearest2d_out( + input, + OptionalArrayRef({output_size.data(), output_size.size()}), + {}, + out)); +} + +TEST_F(OpUpsampleNearest2dTest, MissingOutputSizeOrScaleDies) { + TensorFactory tf; + + const auto input = tf.ones({1, 1, 2}); + auto out = tf.zeros({1, 1, 4}); + + ET_EXPECT_KERNEL_FAILURE( + context_, op_upsample_nearest2d_out(input, {}, {}, out)); +} + +TEST_F(OpUpsampleNearest2dTest, BothOutputSizeAndScaleDies) { + TensorFactory tf; + + const auto input = tf.ones({1, 1, 1, 2}); + std::array output_size = {1, 4}; + std::array scale_factors = {1, 2}; + auto out = tf.zeros({1, 1, 1, 4}); + + ET_EXPECT_KERNEL_FAILURE( + context_, + op_upsample_nearest2d_out( + input, + OptionalArrayRef({output_size.data(), output_size.size()}), + OptionalArrayRef( + {scale_factors.data(), scale_factors.size()}), + out)); +} + +TEST_F(OpUpsampleNearest2dTest, MismatchedDTypeDies) { + TensorFactory tf; + TensorFactory tf2; + + const auto input = tf.ones({1, 1, 2}); + std::array output_size = {1, 4}; + auto out = tf2.zeros({1, 1, 4}); + + ET_EXPECT_KERNEL_FAILURE( + context_, + op_upsample_nearest2d_out( + input, + OptionalArrayRef({output_size.data(), output_size.size()}), + {}, + out)); +} + +TEST_F(OpUpsampleNearest2dTest, ComputedOutputSizeMatchesExpected) { + // Computed output sizes (from input size * scales) must match PyTorch + // eager-mode - multiplied as double and cast (truncated) to an integral type. + // See + // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/UpSample.cpp + TensorFactory tf; + + // Test case format: { in_h, in_w, scale_h, scale_w, out_h, out_w } + std::vector> + test_cases = { + {10, 10, 9.99999, 9.55, 99, 95}, + {10, 10, 9.99999999, 0.1, 99, 1}, + }; + + for (const auto& test_case : test_cases) { + const auto [in_h, in_w, scale_h, scale_w, out_h, out_w] = test_case; + + const auto input = tf.ones({1, 1, in_h, in_w}); + std::array scale_factors = {scale_h, scale_w}; + auto out = tf.zeros({1, 1, out_h, out_w}); + + op_upsample_nearest2d_out( + input, + {}, + OptionalArrayRef({scale_factors.data(), scale_factors.size()}), + out); + + const auto expected = tf.ones({1, 1, out_h, out_w}); + + EXPECT_TENSOR_EQ(out, expected); + } +} + +TEST_F(OpUpsampleNearest2dTest, ZeroComputedOutputSizeDies) { + TensorFactory tf; + + const auto input = tf.ones({1, 1, 1, 2}); + std::array scale_factors = {1, 0.25}; + auto out = tf.zeros({1, 1, 1, 4}); + + ET_EXPECT_KERNEL_FAILURE( + context_, + op_upsample_nearest2d_out( + input, + {}, + OptionalArrayRef( + {scale_factors.data(), scale_factors.size()}), + out)); +} diff --git a/kernels/test/targets.bzl b/kernels/test/targets.bzl index d4f5769dbe6..db14173050b 100644 --- a/kernels/test/targets.bzl +++ b/kernels/test/targets.bzl @@ -347,6 +347,7 @@ def define_common_targets(): _common_op_test("op_unbind_copy_test", ["aten", "portable"]) _common_op_test("op_unsqueeze_copy_test", ["aten", "portable"]) _common_op_test("op_upsample_bilinear2d_test", ["aten", "portable"]) + _common_op_test("op_upsample_nearest2d_test", ["aten", "portable"]) _common_op_test("op_var_test", ["aten", "portable"]) _common_op_test("op_view_copy_test", ["aten", "portable"]) _common_op_test("op_where_test", ["aten", "portable"]) diff --git a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl index 9f2a3834027..f5ddae06b6a 100644 --- a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl +++ b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl @@ -1235,6 +1235,12 @@ ATEN_OPS = ( "//executorch/kernels/portable/cpu/util:upsample_util", ], ), + op_target( + name = "op_upsample_nearest2d", + deps = [ + "//executorch/kernels/portable/cpu/util:upsample_util", + ], + ), op_target( name = "op_var", deps = [