diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 3479e1e7c36e..5c112c7dfce0 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2697,6 +2697,40 @@ def _impl_v10(cls, inputs, attr, params): @classmethod def _impl_v11(cls, inputs, attr, params): + scale = inputs[2] + scale_shape = infer_shape(scale) + if len(inputs) == 4: + assert ( + len(scale_shape) == 0 or scale_shape[0] == 0 + ), "One of scale or size should be passed, not both." + size = inputs[3] + else: + assert len(scale_shape) != 0, "One of scale or size should be passed." + size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale + return cls.v11_13_common(inputs, size, attr, params) + + @classmethod + def _impl_v13(cls, inputs, attr, params): + scale = inputs[2] + size = inputs[3] + if size is not None: + assert scale is None, "One of scale or size should be passed, not both." + else: + scale_type = infer_type(scale) + scale_shape = scale_type.checked_type.shape + scale_dtype = scale_type.checked_type.dtype + assert len(scale_shape) != 0, "One of scale or size should be passed." + size = _op.cast(shape_of(inputs[0]), scale_dtype) * scale + + return cls.v11_13_common(inputs, size, attr, params) + + @classmethod + def v11_13_common(cls, inputs, size, attr, params): + """ + Resize v11 and Resize v13 are identical except in how + they handle the passing of scale and size. This utility + provides the implementation for both + """ ndims = len(infer_shape(inputs[0])) mode = attr.get("mode").decode("ascii") if mode == "nearest": @@ -2715,16 +2749,6 @@ def _impl_v11(cls, inputs, attr, params): alpha = attr.get("cubic_coeff_a", -0.75) exclude = attr.get("exclude_outside", 0) - scale = inputs[2] - scale_shape = infer_shape(scale) - if len(inputs) == 4: - assert ( - len(scale_shape) == 0 or scale_shape[0] == 0 - ), "One of scale or size should be passed, not both." - size = inputs[3] - else: - assert len(scale_shape) != 0, "One of scale or size should be passed." - size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale out_size = fold_constant(_op.strided_slice(size, [2], [4])) out = None if ndims == 3: diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 2301747034dd..dd1c77330986 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -3970,6 +3970,7 @@ def verify(ishape, oshape, scales, mode, coord_trans="asymmetric", alpha=0.5, ex make_constant_node("scales", onnx.TensorProto.FLOAT, (len(scales),), scales), ] input_names = ["X", "roi", "scales"] + if oshape != []: nodes.append( make_constant_node("sizes", onnx.TensorProto.INT64, (len(oshape),), oshape) @@ -4954,15 +4955,7 @@ def verify_eyelike(indata): "test_reduce_sum_keepdims_random", "test_reduce_sum_negative_axes_keepdims_example", "test_reduce_sum_negative_axes_keepdims_random", - "test_resize_downsample_sizes_cubic", - "test_resize_downsample_sizes_linear_pytorch_half_pixel", - "test_resize_downsample_sizes_nearest", "test_resize_tf_crop_and_resize", - "test_resize_upsample_sizes_cubic", - "test_resize_upsample_sizes_nearest", - "test_resize_upsample_sizes_nearest_ceil_half_pixel", - "test_resize_upsample_sizes_nearest_floor_align_corners", - "test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric", "test_rnn_seq_length", "test_round", "test_scan9_sum",