From e45ffb5e87d5c9c312b124fc135f121f2a31c01d Mon Sep 17 00:00:00 2001 From: KJlaccHoeUM9l Date: Mon, 30 Jan 2023 16:49:18 +0300 Subject: [PATCH 01/23] init convertor for DFT --- python/tvm/relay/frontend/onnx.py | 32 +++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 2a1890627225..c78db7c2299c 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -4877,6 +4877,37 @@ def _impl_v1(cls, inputs, attr, params): return mm_out +class DFT(OnnxOpConverter): + """Operator converter for discrete Fourier transform (DFT).""" + + @classmethod + def _impl_v17(cls, inputs, attr, params): + # ************************* Read attrs ************************* + axis = attr.get("axis") + inverse = attr.get("inverse") + onesided = attr.get("onesided") + + # ************************* Read inputs ************************ + input_tensor = inputs[0] + dft_length = inputs[1] + + # ************************* Parse inputs *********************** + t1 = ["float16", "float32", "float64"] + t2 = ["int32", "int64"] + + # input + assert infer_type(input_tensor).checked_type.dtype in t1 + input_shape = infer_shape(input_tensor) + assert len(input_shape) >= 3 + n = len(input_shape) - 2 + + # dft_length + if dft_length is not None: + raise NotImplementedError("dft_length") + + raise NotImplementedError("DFT") + + class NonMaxSuppression(OnnxOpConverter): """Operator converter for NonMaxSuppression.""" @@ -6696,6 +6727,7 @@ def _get_convert_map(opset): "Scan": Scan.get_converter(opset), # ML "LinearRegressor": LinearRegressor.get_converter(opset), + "DFT": DFT.get_converter(opset), # Sequence operators "SequenceConstruct": SequenceConstruct.get_converter(opset), "SequenceEmpty": SequenceEmpty.get_converter(opset), From 4a554610dc7207d716c85e234ff25c8bd3ce5d3a Mon Sep 17 00:00:00 2001 From: KJlaccHoeUM9l Date: Mon, 30 Jan 2023 16:50:41 +0300 Subject: [PATCH 02/23] init test for DFT --- tests/python/frontend/onnx/test_forward.py | 77 ++++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 293f4d38e649..dcd5de948862 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -7933,6 +7933,83 @@ def verify_linear_regressor(a_shape, c_shape, i_shape, targets=1, batch=1): verify_linear_regressor((1, 4), (3), (1)) +@tvm.testing.parametrize_targets +def test_dft(target, dev): + """test_dft""" + def verify_dft( + _axis, + _inverse, + _onesided, + _input, + _dft_length=None, + ): + input_names = ["input"] + if _dft_length is not None: + input_names.append("_dft_length") + + node = onnx.helper.make_node( + "DFT", + inputs=input_names, + outputs=["output"], + # domain="com.microsoft", + axis=_axis, + inverse=_inverse, + onesided=_onesided, + ) + + inputs_info = [ + helper.make_tensor_value_info("input", TensorProto.FLOAT, input_shape), + ] + if _dft_length is not None: + inputs_info.append( + helper.make_tensor_value_info( + "dft_length", TensorProto.INT32, [] + ) + ) + + graph = helper.make_graph( + [node], + "dft_test", + inputs=inputs_info, + outputs=[ + helper.make_tensor_value_info("output", TensorProto.FLOAT, output_shape), + ], + ) + + model = helper.make_model(graph, producer_name="dft_test") + + inputs = [_input] + if _dft_length is not None: + inputs.append(_dft_length) + + verify_with_ort_with_inputs( + model, + inputs, + [input_shape], + target=target, + dev=dev, + rtol=1e-4, + atol=1e-4, + ) + + axis = 1 + inverse = 0 + onesided = 0 + + batch_size = 1 + n = 3 + D = 7 + + input_shape = [batch_size] + n * [D] + [1] + output_shape = [batch_size] + n * [D] + [2] + if onesided == 1: + output_shape[axis] = output_shape[axis] // 2 + 1 + + input_tensor = np.random.normal(size=input_shape).astype("float32") + + verify_dft(axis, inverse, onesided, input_tensor) + + @tvm.testing.parametrize_targets def test_sequence(target, dev): """test_sequence""" From 72b6530b726cf19c6b0188a470aab38996fdbbde Mon Sep 17 00:00:00 2001 From: KJlaccHoeUM9l Date: Wed, 1 Feb 2023 12:21:27 +0300 Subject: [PATCH 03/23] init DFT operator in Relay --- include/tvm/relay/attrs/transform.h | 22 ++++++++++++ python/tvm/relay/frontend/onnx.py | 5 ++- python/tvm/relay/op/_transform.py | 42 ++++++++++++++++++++++ python/tvm/relay/op/strategy/generic.py | 30 ++++++++++++++++ python/tvm/relay/op/transform.py | 5 +++ python/tvm/topi/stft.py | 9 +++++ src/relay/op/tensor/transform.cc | 47 +++++++++++++++++++++++++ 7 files changed, 159 insertions(+), 1 deletion(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index b5333961ebf9..d39e3d267f88 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -605,6 +605,28 @@ struct StftAttrs : public tvm::AttrsNode { } }; // struct StftAttrs +/*! \brief Attributes used in DFT operator */ +struct DFTAttrs : public tvm::AttrsNode { + Integer axis; + Bool inverse = Bool(false); + Bool onesided = Bool(false); + + TVM_DECLARE_ATTRS(DFTAttrs, "relay.attrs.DFTAttrs") { + TVM_ATTR_FIELD(axis) + .describe("The axis on which to perform the DFT") + .set_default(1); + TVM_ATTR_FIELD(inverse) + .describe("Whether to perform the inverse discrete fourier transform") + .set_default(Bool(false)); + TVM_ATTR_FIELD(onesided) + .describe( + "If onesided is True, only values for w in [0, 1, 2, ..., floor(n_fft/2) + 1] " + "are returned because the real-to-complex Fourier transform satisfies the conjugate " + "symmetry, i.e., X[m, w] = X[m,w]=X[m,n_fft-w]*") + .set_default(Bool(false)); + } +}; // struct DFTAttrs + struct TriluAttrs : public tvm::AttrsNode { bool upper; diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index c78db7c2299c..f7d9be1796e7 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -4905,7 +4905,10 @@ def _impl_v17(cls, inputs, attr, params): if dft_length is not None: raise NotImplementedError("dft_length") - raise NotImplementedError("DFT") + assert inverse == 0, "inverse not supported" + assert onesided == 0, "onesided not supported" + + return _op.dft(input_tensor, axis, inverse, onesided) class NonMaxSuppression(OnnxOpConverter): diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index e40179ed2d03..06be646ffd87 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -191,6 +191,48 @@ def stft_shape_func(attrs, inputs, _): ] +# DFT +@_reg.register_compute("dft") +def compute_dft(attrs, inputs, output_type): + """Compute definition of DFT""" + # TODO(agladyshev): output_type not used, add dft_length? + return topi.dft( + inputs[0], + attrs.axis, + attrs.inverse, + attrs.onesided, + ) + + +_reg.register_strategy("dft", strategy.dft_strategy) + + +@script +def _dft_shape_func(data, axis, onesided): + n = len(data.shape) + output_shape = output_tensor((n,), "int64") + for i in const_range(n): + output_shape[i] = int64(data.shape[i]) + output_shape[n - 1] = int64(2) + + if onesided: + output_shape[axis] = int64(output_shape[axis] // int64(2)) + int64(1) + + return output_shape + + +@_reg.register_shape_func("dft", True) +def dft_shape_func(attrs, inputs, _): + """ + Shape func for DFT. + """ + return [ + _dft_shape_func( + inputs[0], convert(attrs.axis), convert(attrs.onesided) + ) + ] + + # trilu _reg.register_strategy("trilu", strategy.trilu_strategy) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 4e0448f1799b..5e55f9b536e7 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1468,6 +1468,36 @@ def _compute_stft(attrs, inputs, output_type): return _compute_stft +# dft +@override_native_generic_func("dft_strategy") +def dft_strategy(attrs, outs, out_type, target): + """DFT generic strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_dft(topi.dft), + wrap_topi_schedule(topi.generic.schedule_extern), + name="dft.generic", + ) + return strategy + + +def wrap_compute_dft(topi_compute): + """Wrap DFT compute""" + + # TODO(agladyshev): output_type not used + def _compute_dft(attrs, inputs, output_type): + return [ + topi_compute( + inputs[0], + attrs.axis, + attrs.inverse, + attrs.onesided, + ) + ] + + return _compute_dft + + # trilu @override_native_generic_func("trilu_strategy") def trilu_strategy(attrs, outs, out_type, target): diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 3df13da04426..35dc1f82fe73 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1987,6 +1987,11 @@ def stft( return _make.stft(data, n_fft, hop_length, win_length, window, normalized, onesided) +# TODO(agladyshev): add description, dft_length? +def dft(data, axis=1, inverse=False, onesided=False): + return _make.dft(data, axis, inverse, onesided) + + def trilu(data, k, upper=True): """Given a 2-D matrix or batches of 2-D matrices, returns the upper or lower triangular part of the tensor. diff --git a/python/tvm/topi/stft.py b/python/tvm/topi/stft.py index b59c0245a052..2c6fa285e828 100644 --- a/python/tvm/topi/stft.py +++ b/python/tvm/topi/stft.py @@ -123,3 +123,12 @@ def gen_ir( name="stft_cpu", tag="stft_cpu", ) + + +def dft( + data: te.Tensor, + axis: int, + inverse: bool, + onesided: bool, +): + raise NotImplementedError(data) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 01e5a7f5f359..40934f2e33bc 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1922,6 +1922,53 @@ RELAY_REGISTER_OP("stft") .set_support_level(3) .set_attr("TOpPattern", kOpaque); +// DFT +TVM_REGISTER_NODE_TYPE(DFTAttrs); +bool DFTRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { + // types: [data, output] + // TODO(agladyshev): add support for dft_length input? + ICHECK_EQ(types.size(), 2) << "Expects two types, one for the input and another for the output"; + const auto* data = types[0].as(); + if (data == nullptr) { + ICHECK(types[0].as()) + << "DFT: expect input type to be TensorType but get " << types[0]; + return false; + } + + const auto* param = attrs.as(); + + std::vector output_shape(data->shape.begin(), data->shape.end()); + output_shape[data->shape.size() - 1] = 2; + + if (param->onesided) { + output_shape[param->axis.IntValue()] = floordiv(output_shape[param->axis.IntValue()], 2) + 1; + } + + reporter->Assign(types[1], TensorType({output_shape}, data->dtype)); + + return true; +} + +Expr MakeDFT(Expr data, int axis, bool inverse, bool onesided) { + auto attrs = make_object(); + attrs->axis = axis; + attrs->inverse = Bool(inverse); + attrs->onesided = Bool(onesided); + static const Op& op = Op::Get("dft"); + return Call(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.dft").set_body_typed(MakeDFT); + +RELAY_REGISTER_OP("dft") + .describe( + R"doc(Computes the discrete Fourier transform of input.)doc" TVM_ADD_FILELINE) + .set_num_inputs(1) + .add_argument("input", "Tensor", "The input tensor.") + .set_support_level(3) + .set_attr("TOpPattern", kOpaque) + .add_type_rel("DFT", DFTRel); + // meshgrid operator TVM_REGISTER_NODE_TYPE(MeshgridAttrs); From df9dfc34093e84c3145071643f295f5c4e65f970 Mon Sep 17 00:00:00 2001 From: KJlaccHoeUM9l Date: Wed, 15 Feb 2023 13:32:38 +0300 Subject: [PATCH 04/23] update topi implementation for DFT --- include/tvm/relay/attrs/transform.h | 26 +++++----- python/tvm/relay/op/_transform.py | 45 +++++++---------- python/tvm/relay/op/strategy/generic.py | 5 +- python/tvm/relay/op/transform.py | 6 +-- python/tvm/topi/stft.py | 56 +++++++++++++++++++-- src/relay/op/tensor/transform.cc | 45 ++++++++--------- tests/python/topi/python/test_topi_dft.py | 59 +++++++++++++++++++++++ 7 files changed, 171 insertions(+), 71 deletions(-) create mode 100644 tests/python/topi/python/test_topi_dft.py diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index d39e3d267f88..d46413d12bc1 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -607,23 +607,27 @@ struct StftAttrs : public tvm::AttrsNode { /*! \brief Attributes used in DFT operator */ struct DFTAttrs : public tvm::AttrsNode { - Integer axis; +// Integer n_fft; +// Integer axis; Bool inverse = Bool(false); - Bool onesided = Bool(false); +// Bool onesided = Bool(false); TVM_DECLARE_ATTRS(DFTAttrs, "relay.attrs.DFTAttrs") { - TVM_ATTR_FIELD(axis) - .describe("The axis on which to perform the DFT") - .set_default(1); +// TVM_ATTR_FIELD(n_fft) +// .describe("The size of Fourier transform") +// .set_default(-1); +// TVM_ATTR_FIELD(axis) +// .describe("The axis on which to perform the DFT") +// .set_default(1); TVM_ATTR_FIELD(inverse) .describe("Whether to perform the inverse discrete fourier transform") .set_default(Bool(false)); - TVM_ATTR_FIELD(onesided) - .describe( - "If onesided is True, only values for w in [0, 1, 2, ..., floor(n_fft/2) + 1] " - "are returned because the real-to-complex Fourier transform satisfies the conjugate " - "symmetry, i.e., X[m, w] = X[m,w]=X[m,n_fft-w]*") - .set_default(Bool(false)); +// TVM_ATTR_FIELD(onesided) +// .describe( +// "If onesided is True, only values for w in [0, 1, 2, ..., floor(n_fft/2) + 1] " +// "are returned because the real-to-complex Fourier transform satisfies the conjugate " +// "symmetry, i.e., X[m, w] = X[m,w]=X[m,n_fft-w]*") +// .set_default(Bool(false)); } }; // struct DFTAttrs diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 06be646ffd87..3c2fd7d41436 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -195,42 +195,33 @@ def stft_shape_func(attrs, inputs, _): @_reg.register_compute("dft") def compute_dft(attrs, inputs, output_type): """Compute definition of DFT""" - # TODO(agladyshev): output_type not used, add dft_length? + # TODO(agladyshev): output_type not used return topi.dft( inputs[0], - attrs.axis, + inputs[1], attrs.inverse, - attrs.onesided, ) _reg.register_strategy("dft", strategy.dft_strategy) -@script -def _dft_shape_func(data, axis, onesided): - n = len(data.shape) - output_shape = output_tensor((n,), "int64") - for i in const_range(n): - output_shape[i] = int64(data.shape[i]) - output_shape[n - 1] = int64(2) - - if onesided: - output_shape[axis] = int64(output_shape[axis] // int64(2)) + int64(1) - - return output_shape - - -@_reg.register_shape_func("dft", True) -def dft_shape_func(attrs, inputs, _): - """ - Shape func for DFT. - """ - return [ - _dft_shape_func( - inputs[0], convert(attrs.axis), convert(attrs.onesided) - ) - ] +# TODO(agladyshev): remove? +# @script +# def _dft_shape_func(re_data, im_data): +# return (re_data.shape, im_data.shape) +# +# +# @_reg.register_shape_func("dft", True) +# def dft_shape_func(attrs, inputs, _): +# """ +# Shape func for DFT. +# """ +# return [ +# _dft_shape_func( +# inputs[0], inputs[1], +# ) +# ] # trilu diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 5e55f9b536e7..6448964a0987 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1485,13 +1485,12 @@ def wrap_compute_dft(topi_compute): """Wrap DFT compute""" # TODO(agladyshev): output_type not used - def _compute_dft(attrs, inputs, output_type): + def _compute_dft(attrs, inputs, _): return [ topi_compute( inputs[0], - attrs.axis, + inputs[1], attrs.inverse, - attrs.onesided, ) ] diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 35dc1f82fe73..3b3fab522484 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1987,9 +1987,9 @@ def stft( return _make.stft(data, n_fft, hop_length, win_length, window, normalized, onesided) -# TODO(agladyshev): add description, dft_length? -def dft(data, axis=1, inverse=False, onesided=False): - return _make.dft(data, axis, inverse, onesided) +# TODO(agladyshev): add description +def dft(re_data, im_data, inverse=False): + return TupleWrapper(_make.dft(re_data, im_data, inverse), 2) def trilu(data, k, upper=True): diff --git a/python/tvm/topi/stft.py b/python/tvm/topi/stft.py index 2c6fa285e828..54cf9e8ca34b 100644 --- a/python/tvm/topi/stft.py +++ b/python/tvm/topi/stft.py @@ -126,9 +126,55 @@ def gen_ir( def dft( - data: te.Tensor, - axis: int, - inverse: bool, - onesided: bool, + re_data: te.Tensor, + im_data: te.Tensor, + inverse: tir.IntImm, ): - raise NotImplementedError(data) + def gen_ir( + re_data_buf, + im_data_buf, + re_output_buf, + im_output_buf, + ): + ib = tir.ir_builder.create() + re_data_ptr = ib.buffer_ptr(re_data_buf) + im_data_ptr = ib.buffer_ptr(im_data_buf) + re_output_ptr = ib.buffer_ptr(re_output_buf) + im_output_ptr = ib.buffer_ptr(im_output_buf) + + shape = re_data.shape + n_fft = shape[len(shape) - 1] + base_range = 1 + for i in range(len(shape) - 1): + base_range *= shape[i] + + with ib.for_range( + 0, base_range, kind="parallel" + ) as i: + base_idx = i * n_fft + with ib.for_range(0, n_fft) as n: + n_idx = base_idx + n + re_output_ptr[n_idx] = tir.Cast(re_output_ptr.dtype, 0) + im_output_ptr[n_idx] = tir.Cast(im_output_ptr.dtype, 0) + with ib.for_range(0, n_fft) as k: + k_idx = base_idx + k + w = -2 * pi * k * n / n_fft + cos_w = tir.cos(w) + sin_w = tir.sin(w) + re_output_ptr[n_idx] += re_data_ptr[k_idx] * cos_w - im_data_ptr[k_idx] * sin_w + im_output_ptr[n_idx] += re_data_ptr[k_idx] * sin_w + im_data_ptr[k_idx] * cos_w + + return ib.get() + + output_shape = [re_data.shape] * 2 + + return te.extern( + shape=output_shape, + inputs=[re_data, im_data], + fcompute=lambda ins, outs: gen_ir( + ins[0], ins[1], outs[0], outs[1] + ), + dtype=[re_data.dtype, im_data.dtype], + name="dft_cpu", + tag="dft_cpu", + ) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 40934f2e33bc..e746fc9e86ab 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1925,37 +1925,37 @@ RELAY_REGISTER_OP("stft") // DFT TVM_REGISTER_NODE_TYPE(DFTAttrs); bool DFTRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - // types: [data, output] - // TODO(agladyshev): add support for dft_length input? - ICHECK_EQ(types.size(), 2) << "Expects two types, one for the input and another for the output"; - const auto* data = types[0].as(); - if (data == nullptr) { + // types: [re_data, im_data, output] + ICHECK_EQ(types.size(), 3) << "DFT: expects three types, two for the input and one for the output"; + ICHECK_EQ(num_inputs, 2) << "DFT: expect 2 inputs but " << num_inputs << " provided"; + const auto* re_data = types[0].as(); + const auto* im_data = types[1].as(); + + if (re_data == nullptr) { ICHECK(types[0].as()) - << "DFT: expect input type to be TensorType but get " << types[0]; + << "DFT: expect re_data type to be TensorType but get " << types[0]; return false; } - - const auto* param = attrs.as(); - - std::vector output_shape(data->shape.begin(), data->shape.end()); - output_shape[data->shape.size() - 1] = 2; - - if (param->onesided) { - output_shape[param->axis.IntValue()] = floordiv(output_shape[param->axis.IntValue()], 2) + 1; + if (im_data == nullptr) { + ICHECK(types[1].as()) + << "DFT: expect im_data type to be TensorType but get " << types[1]; + return false; } - reporter->Assign(types[1], TensorType({output_shape}, data->dtype)); + std::vector shapes; + shapes.push_back(TensorType(re_data->shape, re_data->dtype)); + shapes.push_back(TensorType(im_data->shape, im_data->dtype)); + + reporter->Assign(types[2], TupleType(Array(shapes))); return true; } -Expr MakeDFT(Expr data, int axis, bool inverse, bool onesided) { +Expr MakeDFT(Expr re_data, Expr im_data, Bool inverse) { auto attrs = make_object(); - attrs->axis = axis; - attrs->inverse = Bool(inverse); - attrs->onesided = Bool(onesided); + attrs->inverse = inverse; static const Op& op = Op::Get("dft"); - return Call(op, {data}, Attrs(attrs), {}); + return Call(op, {re_data, im_data}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op._make.dft").set_body_typed(MakeDFT); @@ -1963,8 +1963,9 @@ TVM_REGISTER_GLOBAL("relay.op._make.dft").set_body_typed(MakeDFT); RELAY_REGISTER_OP("dft") .describe( R"doc(Computes the discrete Fourier transform of input.)doc" TVM_ADD_FILELINE) - .set_num_inputs(1) - .add_argument("input", "Tensor", "The input tensor.") + .set_num_inputs(2) + .add_argument("re_data", "Tensor", "Real part of input tensor.") + .add_argument("im_data", "Tensor", "Imaginary part of input tensor.") .set_support_level(3) .set_attr("TOpPattern", kOpaque) .add_type_rel("DFT", DFTRel); diff --git a/tests/python/topi/python/test_topi_dft.py b/tests/python/topi/python/test_topi_dft.py new file mode 100644 index 000000000000..ce626b3d8dac --- /dev/null +++ b/tests/python/topi/python/test_topi_dft.py @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +import tvm +import tvm.testing +from tvm import topi +import tvm.topi.testing + + +def dft(): + shape = (3, 7, 7) + dtype = "float32" + dev = tvm.runtime.Device(tvm.runtime.Device.kDLCPU, 0) + target = "llvm" + + Re = tvm.te.placeholder(shape, dtype=dtype, name="Re") + Im = tvm.te.placeholder(shape, dtype=dtype, name="Im") + + with tvm.target.Target(target): + fcompute = lambda re_x, im_x: topi.dft(re_x, im_x, inverse=False) + fschedule = lambda outs: tvm.te.create_schedule([x.op for x in outs]) + + outs = fcompute(Re, Im) + s = fschedule(outs) + + print(tvm.lower(s, [Re, Im, *outs], simple_mode=False)) + f = tvm.build(s, [Re, Im, *outs], target) + + re_np = np.random.normal(size=shape).astype(dtype) + im_np = np.random.normal(size=shape).astype(dtype) + + re = tvm.nd.array(re_np, device=dev) + im = tvm.nd.array(im_np, device=dev) + re_out = tvm.nd.array(np.zeros(shape).astype(dtype), device=dev) + im_out = tvm.nd.array(np.zeros(shape).astype(dtype), device=dev) + + f(re, im, re_out, im_out) + + ref_dft = np.fft.fft(re_np + 1j * im_np) + tvm.testing.assert_allclose(re_out.numpy(), np.real(ref_dft), rtol=5e-4) + tvm.testing.assert_allclose(im_out.numpy(), np.imag(ref_dft), rtol=5e-4) + + +if __name__ == '__main__': + dft() From 3b061603e10994cf9c506d5bd22b95e385438d16 Mon Sep 17 00:00:00 2001 From: KJlaccHoeUM9l Date: Wed, 15 Feb 2023 13:39:16 +0300 Subject: [PATCH 05/23] clean up --- include/tvm/relay/attrs/transform.h | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index d46413d12bc1..eb22c70098c9 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -607,27 +607,12 @@ struct StftAttrs : public tvm::AttrsNode { /*! \brief Attributes used in DFT operator */ struct DFTAttrs : public tvm::AttrsNode { -// Integer n_fft; -// Integer axis; Bool inverse = Bool(false); -// Bool onesided = Bool(false); TVM_DECLARE_ATTRS(DFTAttrs, "relay.attrs.DFTAttrs") { -// TVM_ATTR_FIELD(n_fft) -// .describe("The size of Fourier transform") -// .set_default(-1); -// TVM_ATTR_FIELD(axis) -// .describe("The axis on which to perform the DFT") -// .set_default(1); TVM_ATTR_FIELD(inverse) - .describe("Whether to perform the inverse discrete fourier transform") + .describe("Whether to perform the inverse discrete Fourier transform") .set_default(Bool(false)); -// TVM_ATTR_FIELD(onesided) -// .describe( -// "If onesided is True, only values for w in [0, 1, 2, ..., floor(n_fft/2) + 1] " -// "are returned because the real-to-complex Fourier transform satisfies the conjugate " -// "symmetry, i.e., X[m, w] = X[m,w]=X[m,n_fft-w]*") -// .set_default(Bool(false)); } }; // struct DFTAttrs From 3c10eb61ca19a4f1ce71be31e1ab433c6cd06f97 Mon Sep 17 00:00:00 2001 From: KJlaccHoeUM9l Date: Wed, 15 Feb 2023 13:40:32 +0300 Subject: [PATCH 06/23] update ONNX frontend --- python/tvm/relay/frontend/onnx.py | 62 ++++++++++++++++++++++++++++++- 1 file changed, 61 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index f7d9be1796e7..2924c2d0cd3c 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -4900,6 +4900,7 @@ def _impl_v17(cls, inputs, attr, params): input_shape = infer_shape(input_tensor) assert len(input_shape) >= 3 n = len(input_shape) - 2 + assert 1 <= axis <= len(input_shape) - 1 # dft_length if dft_length is not None: @@ -4908,7 +4909,66 @@ def _impl_v17(cls, inputs, attr, params): assert inverse == 0, "inverse not supported" assert onesided == 0, "onesided not supported" - return _op.dft(input_tensor, axis, inverse, onesided) + # ************************ + swap_axis = -1 + re_input_tensor, im_input_tensor = cls._split_real_and_imag_parts(input_tensor) + + re_input_tensor = cls._swap_axes(re_input_tensor, axis, swap_axis) + im_input_tensor = cls._swap_axes(im_input_tensor, axis, swap_axis) + + re_input_tensor, im_input_tensor = _op.dft(re_input_tensor, im_input_tensor, inverse) + + re_input_tensor = cls._swap_axes(re_input_tensor, axis, swap_axis) + im_input_tensor = cls._swap_axes(im_input_tensor, axis, swap_axis) + + output = cls._merge_real_and_imag_parts(re_input_tensor, im_input_tensor) + + return output + + @classmethod + def _maybe_crop_or_pad(cls, input_tensor, axis, n_fft): + if input_tensor.shape[axis] != n_fft: + s = list(input_tensor.shape) + index = [slice(None)] * len(s) + if s[axis] > n_fft: + index[axis] = slice(0, n_fft) + _input_tensor = input_tensor[tuple(index)] + else: + index[axis] = slice(0, s[axis]) + s[axis] = n_fft + z = np.zeros(s, input_tensor.dtype.char) + z[tuple(index)] = input_tensor + _input_tensor = z + return input_tensor + + @classmethod + def _swap_axes(cls, tensor, axis1, axis2): + permutation = list(range(len(infer_shape(tensor)))) + permutation[axis1] = axis2 + permutation[axis2] = axis1 + return _op.transpose(tensor, permutation) + + @classmethod + def _split_real_and_imag_parts(cls, tensor): + shape = infer_shape(tensor) + dtype = infer_type(tensor).checked_type.dtype + if shape[-1] == 1: + re = tensor + im = _op.const(np.zeros(shape), dtype=dtype) + else: + re, im = _op.split(tensor, 2, -1) + + re = _op.squeeze(re, -1) + im = _op.squeeze(im, -1) + + return re, im + + @classmethod + def _merge_real_and_imag_parts(cls, re, im): + re = _op.expand_dims(re, axis=-1) + im = _op.expand_dims(im, axis=-1) + output = _op.concatenate([re, im], axis=-1) + return output class NonMaxSuppression(OnnxOpConverter): From 554044001015f4f3e352deef8236b6e207a03d05 Mon Sep 17 00:00:00 2001 From: KJlaccHoeUM9l Date: Wed, 15 Feb 2023 14:47:53 +0300 Subject: [PATCH 07/23] support attribute --- python/tvm/topi/stft.py | 8 +++++++- tests/python/topi/python/test_topi_dft.py | 24 +++++++++++++---------- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/python/tvm/topi/stft.py b/python/tvm/topi/stft.py index 54cf9e8ca34b..efc01cb71cf5 100644 --- a/python/tvm/topi/stft.py +++ b/python/tvm/topi/stft.py @@ -148,6 +148,9 @@ def gen_ir( for i in range(len(shape) - 1): base_range *= shape[i] + sign = -1 if inverse else 1 + factor = 1. / n_fft if inverse else 1. + with ib.for_range( 0, base_range, kind="parallel" ) as i: @@ -158,12 +161,15 @@ def gen_ir( im_output_ptr[n_idx] = tir.Cast(im_output_ptr.dtype, 0) with ib.for_range(0, n_fft) as k: k_idx = base_idx + k - w = -2 * pi * k * n / n_fft + w = sign * -2 * pi * k * n / n_fft cos_w = tir.cos(w) sin_w = tir.sin(w) re_output_ptr[n_idx] += re_data_ptr[k_idx] * cos_w - im_data_ptr[k_idx] * sin_w im_output_ptr[n_idx] += re_data_ptr[k_idx] * sin_w + im_data_ptr[k_idx] * cos_w + re_output_ptr[n_idx] *= tir.Cast(re_output_ptr.dtype, factor) + im_output_ptr[n_idx] *= tir.Cast(im_output_ptr.dtype, factor) + return ib.get() output_shape = [re_data.shape] * 2 diff --git a/tests/python/topi/python/test_topi_dft.py b/tests/python/topi/python/test_topi_dft.py index ce626b3d8dac..e96862a67aaa 100644 --- a/tests/python/topi/python/test_topi_dft.py +++ b/tests/python/topi/python/test_topi_dft.py @@ -21,17 +21,20 @@ import tvm.topi.testing -def dft(): - shape = (3, 7, 7) - dtype = "float32" - dev = tvm.runtime.Device(tvm.runtime.Device.kDLCPU, 0) - target = "llvm" +def numpy_reference(inverse, re: np.ndarray, im: np.ndarray): + if inverse: + reference = np.fft.ifft(re + 1j * im) + else: + reference = np.fft.fft(re + 1j * im) + return np.real(reference), np.imag(reference) + +def dft(inverse, shape, dtype, dev, target): Re = tvm.te.placeholder(shape, dtype=dtype, name="Re") Im = tvm.te.placeholder(shape, dtype=dtype, name="Im") with tvm.target.Target(target): - fcompute = lambda re_x, im_x: topi.dft(re_x, im_x, inverse=False) + fcompute = lambda re_x, im_x: topi.dft(re_x, im_x, inverse=inverse) fschedule = lambda outs: tvm.te.create_schedule([x.op for x in outs]) outs = fcompute(Re, Im) @@ -50,10 +53,11 @@ def dft(): f(re, im, re_out, im_out) - ref_dft = np.fft.fft(re_np + 1j * im_np) - tvm.testing.assert_allclose(re_out.numpy(), np.real(ref_dft), rtol=5e-4) - tvm.testing.assert_allclose(im_out.numpy(), np.imag(ref_dft), rtol=5e-4) + re_reference, im_reference = numpy_reference(inverse, re_np, im_np) + tvm.testing.assert_allclose(re_out.numpy(), re_reference, rtol=1e-3) + tvm.testing.assert_allclose(im_out.numpy(), im_reference, rtol=1e-3) if __name__ == '__main__': - dft() + dft(False, (3, 7, 7), "float32", tvm.runtime.Device(tvm.runtime.Device.kDLCPU, 0), "llvm") + dft(True, (3, 7, 7), "float32", tvm.runtime.Device(tvm.runtime.Device.kDLCPU, 0), "llvm") From 93b3faba1e85df33049a2f9a08058d2059ecf12c Mon Sep 17 00:00:00 2001 From: KJlaccHoeUM9l Date: Thu, 16 Feb 2023 12:30:54 +0300 Subject: [PATCH 08/23] fix error: Expected Array[Tensor], but got Array[index 0: Array] --- python/tvm/relay/op/strategy/generic.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 6448964a0987..4811cae2ab7f 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1484,15 +1484,12 @@ def dft_strategy(attrs, outs, out_type, target): def wrap_compute_dft(topi_compute): """Wrap DFT compute""" - # TODO(agladyshev): output_type not used def _compute_dft(attrs, inputs, _): - return [ - topi_compute( - inputs[0], - inputs[1], - attrs.inverse, - ) - ] + return topi_compute( + inputs[0], + inputs[1], + attrs.inverse, + ) return _compute_dft From 911c9634d63aeadb546a4d845f010ab4b84faa50 Mon Sep 17 00:00:00 2001 From: KJlaccHoeUM9l Date: Thu, 16 Feb 2023 17:21:14 +0300 Subject: [PATCH 09/23] support inverse, onsided, dft_lenght --- python/tvm/relay/frontend/onnx.py | 56 +++++++++++++++++++++---------- 1 file changed, 39 insertions(+), 17 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 2924c2d0cd3c..136d784d8aa0 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -4899,17 +4899,21 @@ def _impl_v17(cls, inputs, attr, params): assert infer_type(input_tensor).checked_type.dtype in t1 input_shape = infer_shape(input_tensor) assert len(input_shape) >= 3 - n = len(input_shape) - 2 + if axis < 0: + axis = len(input_shape) - axis assert 1 <= axis <= len(input_shape) - 1 # dft_length - if dft_length is not None: - raise NotImplementedError("dft_length") - - assert inverse == 0, "inverse not supported" - assert onesided == 0, "onesided not supported" + if dft_length is None: + dft_length = input_shape[axis] + else: + dft_length_dtype = infer_type(dft_length).checked_type.dtype + assert dft_length_dtype in t2 + dft_length = int(infer_value(dft_length, params).numpy()) # ************************ + input_tensor = cls._maybe_crop_or_pad(input_tensor, axis, dft_length) + swap_axis = -1 re_input_tensor, im_input_tensor = cls._split_real_and_imag_parts(input_tensor) @@ -4921,24 +4925,37 @@ def _impl_v17(cls, inputs, attr, params): re_input_tensor = cls._swap_axes(re_input_tensor, axis, swap_axis) im_input_tensor = cls._swap_axes(im_input_tensor, axis, swap_axis) + if onesided: + re_input_tensor = cls._crop_onesided(re_input_tensor, axis) + im_input_tensor = cls._crop_onesided(im_input_tensor, axis) + output = cls._merge_real_and_imag_parts(re_input_tensor, im_input_tensor) return output + @classmethod + def _crop_axis(cls, tensor, axis, new_dim): + shape = infer_shape(tensor) + slices = [slice(0, a, 1) for a in shape] + slices[axis] = slice(0, new_dim, 1) + return _op.strided_slice( + tensor, + begin=[s.start for s in slices], + end=[s.stop for s in slices], + strides=[s.step for s in slices], + axes=list(range(len(shape))), + ) + @classmethod def _maybe_crop_or_pad(cls, input_tensor, axis, n_fft): - if input_tensor.shape[axis] != n_fft: - s = list(input_tensor.shape) - index = [slice(None)] * len(s) - if s[axis] > n_fft: - index[axis] = slice(0, n_fft) - _input_tensor = input_tensor[tuple(index)] + shape = infer_shape(input_tensor) + if shape[axis] != n_fft: + if shape[axis] > n_fft: + return cls._crop_axis(input_tensor, axis, n_fft) else: - index[axis] = slice(0, s[axis]) - s[axis] = n_fft - z = np.zeros(s, input_tensor.dtype.char) - z[tuple(index)] = input_tensor - _input_tensor = z + pad_width = [(0, 0)] * len(shape) + pad_width[axis] = (0, n_fft - shape[axis]) + return _op.nn.pad(input_tensor, pad_width) return input_tensor @classmethod @@ -4970,6 +4987,11 @@ def _merge_real_and_imag_parts(cls, re, im): output = _op.concatenate([re, im], axis=-1) return output + @classmethod + def _crop_onesided(cls, tensor, axis): + shape = infer_shape(tensor) + return cls._crop_axis(tensor, axis, shape[axis] // 2 + 1) + class NonMaxSuppression(OnnxOpConverter): """Operator converter for NonMaxSuppression.""" From 0b27f1887549bc4f0e0be9a89ee86c7b25dff89d Mon Sep 17 00:00:00 2001 From: KJlaccHoeUM9l Date: Thu, 16 Feb 2023 17:21:48 +0300 Subject: [PATCH 10/23] update tests for DFT --- tests/python/frontend/onnx/test_forward.py | 64 ++++++++++------------ 1 file changed, 29 insertions(+), 35 deletions(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index dcd5de948862..a0cc4d41f5a2 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -5429,9 +5429,6 @@ def verify_eyelike(indata, dynamic=False): "test_cumsum_2d_negative_axis", "test_det_2d", "test_det_nd", - "test_dft", - "test_dft_axis", - "test_dft_inverse", "test_dropout_default", "test_dropout_default_mask", "test_dropout_default_mask_ratio", @@ -5591,6 +5588,9 @@ def test_onnx_nodes(target, dev, onnx_test): # satisfies onnx precision for bicubic interpolation atol = 1e-4 + if "dft" in test_dir: + atol = 1e-3 + model = onnx.load(os.path.join(test_dir, "model.onnx")) for test_data_dir in glob.glob(os.path.join(test_dir, "test_data_set*")): inputs = [] @@ -7945,32 +7945,30 @@ def verify_dft( ): input_names = ["input"] if _dft_length is not None: - input_names.append("_dft_length") + input_names.append("dft_length") node = onnx.helper.make_node( "DFT", inputs=input_names, outputs=["output"], - # domain="com.microsoft", axis=_axis, inverse=_inverse, onesided=_onesided, ) - inputs_info = [ - helper.make_tensor_value_info("input", TensorProto.FLOAT, input_shape), - ] + nodes = [] if _dft_length is not None: - inputs_info.append( - helper.make_tensor_value_info( - "dft_length", TensorProto.INT32, [] - ) + nodes.append( + make_constant_node("dft_length", TensorProto.INT32, [], [_dft_length]), ) + nodes.append(node) graph = helper.make_graph( - [node], + nodes, "dft_test", - inputs=inputs_info, + inputs=[ + helper.make_tensor_value_info("input", TensorProto.FLOAT, input_shape), + ], outputs=[ helper.make_tensor_value_info("output", TensorProto.FLOAT, output_shape), ], @@ -7978,36 +7976,32 @@ def verify_dft( model = helper.make_model(graph, producer_name="dft_test") - inputs = [_input] - if _dft_length is not None: - inputs.append(_dft_length) - verify_with_ort_with_inputs( model, - inputs, + [_input], [input_shape], target=target, dev=dev, - rtol=1e-4, - atol=1e-4, + rtol=1e-5, + atol=1e-5, + use_vm=False, ) - axis = 1 - inverse = 0 - onesided = 0 - - batch_size = 1 - n = 3 + batch_size = 5 + n = 2 D = 7 - input_shape = [batch_size] + n * [D] + [1] - output_shape = [batch_size] + n * [D] + [2] - if onesided == 1: - output_shape[axis] = output_shape[axis] // 2 + 1 - - input_tensor = np.random.normal(size=input_shape).astype("float32") - - verify_dft(axis, inverse, onesided, input_tensor) + for axis in range(1, n): + for inverse, onesided in [(0, 0), (0, 1), (1, 0)]: + for n_fft in [D, D - 1, D + 1]: + for c in [1, 2]: + input_shape = [batch_size] + n * [D] + [c] + output_shape = [batch_size] + n * [D] + [2] + if onesided == 1: + output_shape[axis] = output_shape[axis] // 2 + 1 + input_tensor = np.random.normal(size=input_shape).astype("float32") + verify_dft(axis, inverse, onesided, input_tensor, n_fft) + print("Local success!") @tvm.testing.parametrize_targets From aaca2b94274d2dbdd81055c3727c2678dc7d9adf Mon Sep 17 00:00:00 2001 From: KJlaccHoeUM9l Date: Thu, 16 Feb 2023 18:11:00 +0300 Subject: [PATCH 11/23] update TOPI test for DFT --- python/tvm/topi/stft.py | 4 ++-- tests/python/topi/python/test_topi_dft.py | 21 ++++++++++++--------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/python/tvm/topi/stft.py b/python/tvm/topi/stft.py index efc01cb71cf5..9bcef98454ab 100644 --- a/python/tvm/topi/stft.py +++ b/python/tvm/topi/stft.py @@ -162,8 +162,8 @@ def gen_ir( with ib.for_range(0, n_fft) as k: k_idx = base_idx + k w = sign * -2 * pi * k * n / n_fft - cos_w = tir.cos(w) - sin_w = tir.sin(w) + cos_w = tir.Cast(re_output_ptr.dtype, tir.cos(w)) + sin_w = tir.Cast(re_output_ptr.dtype, tir.sin(w)) re_output_ptr[n_idx] += re_data_ptr[k_idx] * cos_w - im_data_ptr[k_idx] * sin_w im_output_ptr[n_idx] += re_data_ptr[k_idx] * sin_w + im_data_ptr[k_idx] * cos_w diff --git a/tests/python/topi/python/test_topi_dft.py b/tests/python/topi/python/test_topi_dft.py index e96862a67aaa..3acc32cf460b 100644 --- a/tests/python/topi/python/test_topi_dft.py +++ b/tests/python/topi/python/test_topi_dft.py @@ -21,6 +21,11 @@ import tvm.topi.testing +inverse = tvm.testing.parameter(False, True) +shape = tvm.testing.parameter((7,), (3, 7), (3, 4, 5)) +dtype = tvm.testing.parameter("float16", "float32", "float64") + + def numpy_reference(inverse, re: np.ndarray, im: np.ndarray): if inverse: reference = np.fft.ifft(re + 1j * im) @@ -29,18 +34,17 @@ def numpy_reference(inverse, re: np.ndarray, im: np.ndarray): return np.real(reference), np.imag(reference) -def dft(inverse, shape, dtype, dev, target): +def test_dft(target, dev, inverse, shape, dtype): Re = tvm.te.placeholder(shape, dtype=dtype, name="Re") Im = tvm.te.placeholder(shape, dtype=dtype, name="Im") with tvm.target.Target(target): - fcompute = lambda re_x, im_x: topi.dft(re_x, im_x, inverse=inverse) - fschedule = lambda outs: tvm.te.create_schedule([x.op for x in outs]) + fcompute = topi.dft + fschedule = topi.generic.schedule_extern - outs = fcompute(Re, Im) + outs = fcompute(Re, Im, inverse) s = fschedule(outs) - print(tvm.lower(s, [Re, Im, *outs], simple_mode=False)) f = tvm.build(s, [Re, Im, *outs], target) re_np = np.random.normal(size=shape).astype(dtype) @@ -54,10 +58,9 @@ def dft(inverse, shape, dtype, dev, target): f(re, im, re_out, im_out) re_reference, im_reference = numpy_reference(inverse, re_np, im_np) - tvm.testing.assert_allclose(re_out.numpy(), re_reference, rtol=1e-3) - tvm.testing.assert_allclose(im_out.numpy(), im_reference, rtol=1e-3) + tvm.testing.assert_allclose(re_out.numpy(), re_reference, rtol=1e-3, atol=1e-3) + tvm.testing.assert_allclose(im_out.numpy(), im_reference, rtol=1e-3, atol=1e-3) if __name__ == '__main__': - dft(False, (3, 7, 7), "float32", tvm.runtime.Device(tvm.runtime.Device.kDLCPU, 0), "llvm") - dft(True, (3, 7, 7), "float32", tvm.runtime.Device(tvm.runtime.Device.kDLCPU, 0), "llvm") + tvm.testing.main() From f14e420e94fa9686a7f0c9532dd3efae0a1373d5 Mon Sep 17 00:00:00 2001 From: KJlaccHoeUM9l Date: Thu, 16 Feb 2023 18:42:11 +0300 Subject: [PATCH 12/23] add documentation --- python/tvm/relay/op/_transform.py | 21 +-------------------- python/tvm/relay/op/transform.py | 20 +++++++++++++++++++- python/tvm/topi/stft.py | 20 ++++++++++++++++++++ 3 files changed, 40 insertions(+), 21 deletions(-) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 3c2fd7d41436..140b9835df6d 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -193,9 +193,8 @@ def stft_shape_func(attrs, inputs, _): # DFT @_reg.register_compute("dft") -def compute_dft(attrs, inputs, output_type): +def compute_dft(attrs, inputs, _): """Compute definition of DFT""" - # TODO(agladyshev): output_type not used return topi.dft( inputs[0], inputs[1], @@ -206,24 +205,6 @@ def compute_dft(attrs, inputs, output_type): _reg.register_strategy("dft", strategy.dft_strategy) -# TODO(agladyshev): remove? -# @script -# def _dft_shape_func(re_data, im_data): -# return (re_data.shape, im_data.shape) -# -# -# @_reg.register_shape_func("dft", True) -# def dft_shape_func(attrs, inputs, _): -# """ -# Shape func for DFT. -# """ -# return [ -# _dft_shape_func( -# inputs[0], inputs[1], -# ) -# ] - - # trilu _reg.register_strategy("trilu", strategy.trilu_strategy) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 3b3fab522484..b9aa36330151 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1987,8 +1987,26 @@ def stft( return _make.stft(data, n_fft, hop_length, win_length, window, normalized, onesided) -# TODO(agladyshev): add description def dft(re_data, im_data, inverse=False): + """ + Computes the discrete Fourier transform of input (calculation along the last axis). + This gives frequency components of the signal as they change over time. + Parameters + ---------- + re_data : relay.Expr + N-D tensor, real part of the input signal. + im_data : relay.Expr + N-D tensor, imaginary part of the input signal. + If the signal is real, then the values of this tensor are zeros. + inverse : bool + Whether to perform the inverse discrete fourier transform. + Returns + ------- + re_output : relay.Expr + The Fourier Transform of the input (Real part). + im_output : relay.Expr + The Fourier Transform of the input (Imaginary part). + """ return TupleWrapper(_make.dft(re_data, im_data, inverse), 2) diff --git a/python/tvm/topi/stft.py b/python/tvm/topi/stft.py index 9bcef98454ab..4efb9377d976 100644 --- a/python/tvm/topi/stft.py +++ b/python/tvm/topi/stft.py @@ -130,6 +130,26 @@ def dft( im_data: te.Tensor, inverse: tir.IntImm, ): + """ + Computes the discrete Fourier transform of input (calculation along the last axis). + This gives frequency components of the signal as they change over time. + Parameters + ---------- + re_data : relay.Expr + N-D tensor, real part of the input signal. + im_data : relay.Expr + N-D tensor, imaginary part of the input signal. + If the signal is real, then the values of this tensor are zeros. + inverse : bool + Whether to perform the inverse discrete fourier transform. + Returns + ------- + re_output : relay.Expr + The Fourier Transform of the input (Real part). + im_output : relay.Expr + The Fourier Transform of the input (Imaginary part). + """ + def gen_ir( re_data_buf, im_data_buf, From 25009bb979de1457458831b018a71148eee98f2c Mon Sep 17 00:00:00 2001 From: KJlaccHoeUM9l Date: Thu, 16 Feb 2023 19:11:19 +0300 Subject: [PATCH 13/23] fix pylint --- python/tvm/topi/stft.py | 18 +++++++----------- tests/python/frontend/onnx/test_forward.py | 16 ++++++++-------- tests/python/topi/python/test_topi_dft.py | 6 +++++- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/python/tvm/topi/stft.py b/python/tvm/topi/stft.py index 4efb9377d976..237f771ee727 100644 --- a/python/tvm/topi/stft.py +++ b/python/tvm/topi/stft.py @@ -151,10 +151,10 @@ def dft( """ def gen_ir( - re_data_buf, - im_data_buf, - re_output_buf, - im_output_buf, + re_data_buf, + im_data_buf, + re_output_buf, + im_output_buf, ): ib = tir.ir_builder.create() re_data_ptr = ib.buffer_ptr(re_data_buf) @@ -169,11 +169,9 @@ def gen_ir( base_range *= shape[i] sign = -1 if inverse else 1 - factor = 1. / n_fft if inverse else 1. + factor = 1.0 / n_fft if inverse else 1.0 - with ib.for_range( - 0, base_range, kind="parallel" - ) as i: + with ib.for_range(0, base_range, kind="parallel") as i: base_idx = i * n_fft with ib.for_range(0, n_fft) as n: n_idx = base_idx + n @@ -197,9 +195,7 @@ def gen_ir( return te.extern( shape=output_shape, inputs=[re_data, im_data], - fcompute=lambda ins, outs: gen_ir( - ins[0], ins[1], outs[0], outs[1] - ), + fcompute=lambda ins, outs: gen_ir(ins[0], ins[1], outs[0], outs[1]), dtype=[re_data.dtype, im_data.dtype], name="dft_cpu", tag="dft_cpu", diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index a0cc4d41f5a2..20201a110d19 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -7936,12 +7936,13 @@ def verify_linear_regressor(a_shape, c_shape, i_shape, targets=1, batch=1): @tvm.testing.parametrize_targets def test_dft(target, dev): """test_dft""" + def verify_dft( - _axis, - _inverse, - _onesided, - _input, - _dft_length=None, + _axis, + _inverse, + _onesided, + _input, + _dft_length=None, ): input_names = ["input"] if _dft_length is not None: @@ -7982,8 +7983,8 @@ def verify_dft( [input_shape], target=target, dev=dev, - rtol=1e-5, - atol=1e-5, + rtol=1e-4, + atol=1e-4, use_vm=False, ) @@ -8001,7 +8002,6 @@ def verify_dft( output_shape[axis] = output_shape[axis] // 2 + 1 input_tensor = np.random.normal(size=input_shape).astype("float32") verify_dft(axis, inverse, onesided, input_tensor, n_fft) - print("Local success!") @tvm.testing.parametrize_targets diff --git a/tests/python/topi/python/test_topi_dft.py b/tests/python/topi/python/test_topi_dft.py index 3acc32cf460b..60632d0e17f3 100644 --- a/tests/python/topi/python/test_topi_dft.py +++ b/tests/python/topi/python/test_topi_dft.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""Test code for discrete Fourier transform.""" import numpy as np import tvm import tvm.testing @@ -26,6 +27,7 @@ dtype = tvm.testing.parameter("float16", "float32", "float64") +# pylint: disable=redefined-outer-name, invalid-name def numpy_reference(inverse, re: np.ndarray, im: np.ndarray): if inverse: reference = np.fft.ifft(re + 1j * im) @@ -35,6 +37,8 @@ def numpy_reference(inverse, re: np.ndarray, im: np.ndarray): def test_dft(target, dev, inverse, shape, dtype): + """Test for discrete Fourier transform.""" + Re = tvm.te.placeholder(shape, dtype=dtype, name="Re") Im = tvm.te.placeholder(shape, dtype=dtype, name="Im") @@ -62,5 +66,5 @@ def test_dft(target, dev, inverse, shape, dtype): tvm.testing.assert_allclose(im_out.numpy(), im_reference, rtol=1e-3, atol=1e-3) -if __name__ == '__main__': +if __name__ == "__main__": tvm.testing.main() From 312b97ea83caca72f33655047043577fd49c5b49 Mon Sep 17 00:00:00 2001 From: KJlaccHoeUM9l Date: Thu, 16 Feb 2023 21:05:39 +0300 Subject: [PATCH 14/23] fix cpplint --- src/relay/op/tensor/transform.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index e746fc9e86ab..78668eb28859 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1924,9 +1924,11 @@ RELAY_REGISTER_OP("stft") // DFT TVM_REGISTER_NODE_TYPE(DFTAttrs); -bool DFTRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { +bool DFTRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { // types: [re_data, im_data, output] - ICHECK_EQ(types.size(), 3) << "DFT: expects three types, two for the input and one for the output"; + ICHECK_EQ(types.size(), 3) + << "DFT: expects three types, two for the input and one for the output"; ICHECK_EQ(num_inputs, 2) << "DFT: expect 2 inputs but " << num_inputs << " provided"; const auto* re_data = types[0].as(); const auto* im_data = types[1].as(); From 5d044c6c990f0c08d8eb076e225fdfe7b26a1046 Mon Sep 17 00:00:00 2001 From: KJlaccHoeUM9l Date: Thu, 16 Feb 2023 21:26:19 +0300 Subject: [PATCH 15/23] fix cpplint --- src/relay/op/tensor/transform.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 78668eb28859..806a17442903 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1963,8 +1963,7 @@ Expr MakeDFT(Expr re_data, Expr im_data, Bool inverse) { TVM_REGISTER_GLOBAL("relay.op._make.dft").set_body_typed(MakeDFT); RELAY_REGISTER_OP("dft") - .describe( - R"doc(Computes the discrete Fourier transform of input.)doc" TVM_ADD_FILELINE) + .describe(R"doc(Computes the discrete Fourier transform of input.)doc" TVM_ADD_FILELINE) .set_num_inputs(2) .add_argument("re_data", "Tensor", "Real part of input tensor.") .add_argument("im_data", "Tensor", "Imaginary part of input tensor.") From f68894fd1589f641eee426239e5838962c1a42be Mon Sep 17 00:00:00 2001 From: KJlaccHoeUM9l Date: Fri, 17 Feb 2023 09:50:49 +0300 Subject: [PATCH 16/23] fix threshold for FP16 (ARM) --- tests/python/topi/python/test_topi_dft.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/python/topi/python/test_topi_dft.py b/tests/python/topi/python/test_topi_dft.py index 60632d0e17f3..ae82fdbe02dc 100644 --- a/tests/python/topi/python/test_topi_dft.py +++ b/tests/python/topi/python/test_topi_dft.py @@ -62,8 +62,13 @@ def test_dft(target, dev, inverse, shape, dtype): f(re, im, re_out, im_out) re_reference, im_reference = numpy_reference(inverse, re_np, im_np) - tvm.testing.assert_allclose(re_out.numpy(), re_reference, rtol=1e-3, atol=1e-3) - tvm.testing.assert_allclose(im_out.numpy(), im_reference, rtol=1e-3, atol=1e-3) + + atol = rtol = 1e-3 + if dtype == "float16": + atol = rtol = 1e-1 + + tvm.testing.assert_allclose(re_out.numpy(), re_reference, rtol=rtol, atol=atol) + tvm.testing.assert_allclose(im_out.numpy(), im_reference, rtol=rtol, atol=atol) if __name__ == "__main__": From 4471e122aabb0e81b1c98a8a9c2e289edeec3c17 Mon Sep 17 00:00:00 2001 From: KJlaccHoeUM9l Date: Fri, 17 Feb 2023 13:22:14 +0300 Subject: [PATCH 17/23] add CUDA compute --- python/tvm/relay/op/strategy/cuda.py | 11 +++ python/tvm/topi/cuda/stft.py | 87 +++++++++++++++++++++++ tests/python/topi/python/test_topi_dft.py | 17 ++++- 3 files changed, 113 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index fc1691fe9ef0..856505050b0a 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -1412,3 +1412,14 @@ def stft_strategy_cuda(attrs, inputs, out_type, target): name="stft.cuda", ) return strategy + + +@stft_strategy.register(["cuda", "gpu"]) +def stft_strategy_cuda(attrs, inputs, out_type, target): + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_dft(topi.cuda.dft), + wrap_topi_schedule(topi.generic.schedule_extern), + name="dft.cuda", + ) + return strategy diff --git a/python/tvm/topi/cuda/stft.py b/python/tvm/topi/cuda/stft.py index 573c2ae39956..8e13be3f3e4d 100644 --- a/python/tvm/topi/cuda/stft.py +++ b/python/tvm/topi/cuda/stft.py @@ -133,3 +133,90 @@ def gen_ir( name="stft_cuda", tag="stft_cuda", ) + + +def dft( + re_data: te.Tensor, + im_data: te.Tensor, + inverse: tir.IntImm, +): + """ + Computes the discrete Fourier transform of input (calculation along the last axis). + This gives frequency components of the signal as they change over time. + Parameters + ---------- + re_data : relay.Expr + N-D tensor, real part of the input signal. + im_data : relay.Expr + N-D tensor, imaginary part of the input signal. + If the signal is real, then the values of this tensor are zeros. + inverse : bool + Whether to perform the inverse discrete fourier transform. + Returns + ------- + re_output : relay.Expr + The Fourier Transform of the input (Real part). + im_output : relay.Expr + The Fourier Transform of the input (Imaginary part). + """ + + def gen_ir( + re_data_buf, + im_data_buf, + re_output_buf, + im_output_buf, + ): + ib = tir.ir_builder.create() + re_data_ptr = ib.buffer_ptr(re_data_buf) + im_data_ptr = ib.buffer_ptr(im_data_buf) + re_output_ptr = ib.buffer_ptr(re_output_buf) + im_output_ptr = ib.buffer_ptr(im_output_buf) + + shape = re_data.shape + n_fft = shape[len(shape) - 1] + base_range = 1 + for i in range(len(shape) - 1): + base_range *= shape[i] + + sign = -1 if inverse else 1 + factor = 1.0 / n_fft if inverse else 1.0 + + max_threads = _get_max_threads(base_range) + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(base_range, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + + tid = bx * max_threads + tx + with ib.if_scope(tid < base_range): + base_idx = tid * n_fft + with ib.for_range(0, n_fft) as n: + n_idx = base_idx + n + re_output_ptr[n_idx] = tir.Cast(re_output_ptr.dtype, 0) + im_output_ptr[n_idx] = tir.Cast(im_output_ptr.dtype, 0) + with ib.for_range(0, n_fft) as k: + k_idx = base_idx + k + w = sign * -2 * pi * k * n / n_fft + cos_w = tir.Cast(re_output_ptr.dtype, tir.cos(w)) + sin_w = tir.Cast(re_output_ptr.dtype, tir.sin(w)) + re_output_ptr[n_idx] += re_data_ptr[k_idx] * cos_w - im_data_ptr[k_idx] * sin_w + im_output_ptr[n_idx] += re_data_ptr[k_idx] * sin_w + im_data_ptr[k_idx] * cos_w + + re_output_ptr[n_idx] *= tir.Cast(re_output_ptr.dtype, factor) + im_output_ptr[n_idx] *= tir.Cast(im_output_ptr.dtype, factor) + + return ib.get() + + output_shape = [re_data.shape] * 2 + + return te.extern( + shape=output_shape, + inputs=[re_data, im_data], + fcompute=lambda ins, outs: gen_ir(ins[0], ins[1], outs[0], outs[1]), + dtype=[re_data.dtype, im_data.dtype], + name="dft_cuda", + tag="dft_cuda", + ) diff --git a/tests/python/topi/python/test_topi_dft.py b/tests/python/topi/python/test_topi_dft.py index ae82fdbe02dc..9d5a55832350 100644 --- a/tests/python/topi/python/test_topi_dft.py +++ b/tests/python/topi/python/test_topi_dft.py @@ -38,13 +38,26 @@ def numpy_reference(inverse, re: np.ndarray, im: np.ndarray): def test_dft(target, dev, inverse, shape, dtype): """Test for discrete Fourier transform.""" + implementations = { + "generic": ( + topi.dft, + topi.generic.schedule_extern, + ), + "gpu": ( + topi.cuda.dft, + topi.cuda.schedule_scan, + ), + "nvptx": ( + topi.cuda.dft, + topi.cuda.schedule_scan, + ), + } Re = tvm.te.placeholder(shape, dtype=dtype, name="Re") Im = tvm.te.placeholder(shape, dtype=dtype, name="Im") with tvm.target.Target(target): - fcompute = topi.dft - fschedule = topi.generic.schedule_extern + fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations) outs = fcompute(Re, Im, inverse) s = fschedule(outs) From 2f933be865cfb8d23da1602e7e969cda82fd65c2 Mon Sep 17 00:00:00 2001 From: KJlaccHoeUM9l Date: Fri, 17 Feb 2023 13:39:18 +0300 Subject: [PATCH 18/23] fix pylint --- python/tvm/relay/op/strategy/cuda.py | 4 ++-- python/tvm/topi/cuda/stft.py | 22 +++++++++++++--------- tests/python/topi/python/test_topi_dft.py | 4 ++-- 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 856505050b0a..01c0654c3eb5 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -1414,8 +1414,8 @@ def stft_strategy_cuda(attrs, inputs, out_type, target): return strategy -@stft_strategy.register(["cuda", "gpu"]) -def stft_strategy_cuda(attrs, inputs, out_type, target): +@dft_strategy.register(["cuda", "gpu"]) +def dft_strategy_cuda(attrs, inputs, out_type, target): strategy = _op.OpStrategy() strategy.add_implementation( wrap_compute_dft(topi.cuda.dft), diff --git a/python/tvm/topi/cuda/stft.py b/python/tvm/topi/cuda/stft.py index 8e13be3f3e4d..6b16e334446d 100644 --- a/python/tvm/topi/cuda/stft.py +++ b/python/tvm/topi/cuda/stft.py @@ -136,9 +136,9 @@ def gen_ir( def dft( - re_data: te.Tensor, - im_data: te.Tensor, - inverse: tir.IntImm, + re_data: te.Tensor, + im_data: te.Tensor, + inverse: tir.IntImm, ): """ Computes the discrete Fourier transform of input (calculation along the last axis). @@ -161,10 +161,10 @@ def dft( """ def gen_ir( - re_data_buf, - im_data_buf, - re_output_buf, - im_output_buf, + re_data_buf, + im_data_buf, + re_output_buf, + im_output_buf, ): ib = tir.ir_builder.create() re_data_ptr = ib.buffer_ptr(re_data_buf) @@ -202,8 +202,12 @@ def gen_ir( w = sign * -2 * pi * k * n / n_fft cos_w = tir.Cast(re_output_ptr.dtype, tir.cos(w)) sin_w = tir.Cast(re_output_ptr.dtype, tir.sin(w)) - re_output_ptr[n_idx] += re_data_ptr[k_idx] * cos_w - im_data_ptr[k_idx] * sin_w - im_output_ptr[n_idx] += re_data_ptr[k_idx] * sin_w + im_data_ptr[k_idx] * cos_w + re_output_ptr[n_idx] += ( + re_data_ptr[k_idx] * cos_w - im_data_ptr[k_idx] * sin_w + ) + im_output_ptr[n_idx] += ( + re_data_ptr[k_idx] * sin_w + im_data_ptr[k_idx] * cos_w + ) re_output_ptr[n_idx] *= tir.Cast(re_output_ptr.dtype, factor) im_output_ptr[n_idx] *= tir.Cast(im_output_ptr.dtype, factor) diff --git a/tests/python/topi/python/test_topi_dft.py b/tests/python/topi/python/test_topi_dft.py index 9d5a55832350..abab272e601d 100644 --- a/tests/python/topi/python/test_topi_dft.py +++ b/tests/python/topi/python/test_topi_dft.py @@ -45,11 +45,11 @@ def test_dft(target, dev, inverse, shape, dtype): ), "gpu": ( topi.cuda.dft, - topi.cuda.schedule_scan, + topi.cuda.schedule_extern, ), "nvptx": ( topi.cuda.dft, - topi.cuda.schedule_scan, + topi.cuda.schedule_extern, ), } From df652908a7f904b9acddb3d67a85b70bdb48bbcd Mon Sep 17 00:00:00 2001 From: KJlaccHoeUM9l Date: Wed, 22 Feb 2023 12:17:34 +0300 Subject: [PATCH 19/23] fix doc string --- python/tvm/relay/op/transform.py | 4 ++++ python/tvm/topi/cuda/stft.py | 4 ++++ python/tvm/topi/stft.py | 4 ++++ 3 files changed, 12 insertions(+) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index b9aa36330151..1b0cc3588014 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1991,15 +1991,19 @@ def dft(re_data, im_data, inverse=False): """ Computes the discrete Fourier transform of input (calculation along the last axis). This gives frequency components of the signal as they change over time. + Parameters ---------- re_data : relay.Expr N-D tensor, real part of the input signal. + im_data : relay.Expr N-D tensor, imaginary part of the input signal. If the signal is real, then the values of this tensor are zeros. + inverse : bool Whether to perform the inverse discrete fourier transform. + Returns ------- re_output : relay.Expr diff --git a/python/tvm/topi/cuda/stft.py b/python/tvm/topi/cuda/stft.py index 6b16e334446d..2f588f4a8fed 100644 --- a/python/tvm/topi/cuda/stft.py +++ b/python/tvm/topi/cuda/stft.py @@ -143,15 +143,19 @@ def dft( """ Computes the discrete Fourier transform of input (calculation along the last axis). This gives frequency components of the signal as they change over time. + Parameters ---------- re_data : relay.Expr N-D tensor, real part of the input signal. + im_data : relay.Expr N-D tensor, imaginary part of the input signal. If the signal is real, then the values of this tensor are zeros. + inverse : bool Whether to perform the inverse discrete fourier transform. + Returns ------- re_output : relay.Expr diff --git a/python/tvm/topi/stft.py b/python/tvm/topi/stft.py index 237f771ee727..4e44c33803f5 100644 --- a/python/tvm/topi/stft.py +++ b/python/tvm/topi/stft.py @@ -133,15 +133,19 @@ def dft( """ Computes the discrete Fourier transform of input (calculation along the last axis). This gives frequency components of the signal as they change over time. + Parameters ---------- re_data : relay.Expr N-D tensor, real part of the input signal. + im_data : relay.Expr N-D tensor, imaginary part of the input signal. If the signal is real, then the values of this tensor are zeros. + inverse : bool Whether to perform the inverse discrete fourier transform. + Returns ------- re_output : relay.Expr From cce4151eba589f4d5eaed83fe073c50c17f1999e Mon Sep 17 00:00:00 2001 From: KJlaccHoeUM9l Date: Wed, 22 Feb 2023 13:53:38 +0300 Subject: [PATCH 20/23] code review fixes for ONNX front-end --- python/tvm/relay/frontend/onnx.py | 16 +++++----------- tests/python/frontend/onnx/test_forward.py | 2 +- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 136d784d8aa0..d3a542282c08 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -4900,8 +4900,8 @@ def _impl_v17(cls, inputs, attr, params): input_shape = infer_shape(input_tensor) assert len(input_shape) >= 3 if axis < 0: - axis = len(input_shape) - axis - assert 1 <= axis <= len(input_shape) - 1 + axis = len(input_shape) + axis + assert 1 <= axis <= len(input_shape) - 1, "axis is out of bounds" # dft_length if dft_length is None: @@ -4929,9 +4929,7 @@ def _impl_v17(cls, inputs, attr, params): re_input_tensor = cls._crop_onesided(re_input_tensor, axis) im_input_tensor = cls._crop_onesided(im_input_tensor, axis) - output = cls._merge_real_and_imag_parts(re_input_tensor, im_input_tensor) - - return output + return cls._merge_real_and_imag_parts(re_input_tensor, im_input_tensor) @classmethod def _crop_axis(cls, tensor, axis, new_dim): @@ -4975,17 +4973,13 @@ def _split_real_and_imag_parts(cls, tensor): else: re, im = _op.split(tensor, 2, -1) - re = _op.squeeze(re, -1) - im = _op.squeeze(im, -1) - - return re, im + return _op.squeeze(re, -1), _op.squeeze(im, -1) @classmethod def _merge_real_and_imag_parts(cls, re, im): re = _op.expand_dims(re, axis=-1) im = _op.expand_dims(im, axis=-1) - output = _op.concatenate([re, im], axis=-1) - return output + return _op.concatenate([re, im], axis=-1) @classmethod def _crop_onesided(cls, tensor, axis): diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 20201a110d19..f6e3984aa03e 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -7992,7 +7992,7 @@ def verify_dft( n = 2 D = 7 - for axis in range(1, n): + for axis in list(range(1, n)) + [-2]: for inverse, onesided in [(0, 0), (0, 1), (1, 0)]: for n_fft in [D, D - 1, D + 1]: for c in [1, 2]: From 0c81b0ad9c22e98dbd4d0fb24aecf96b3aa17a26 Mon Sep 17 00:00:00 2001 From: KJlaccHoeUM9l Date: Wed, 22 Feb 2023 13:59:04 +0300 Subject: [PATCH 21/23] code review fixes for TOPI --- python/tvm/topi/cuda/stft.py | 3 ++- python/tvm/topi/stft.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/cuda/stft.py b/python/tvm/topi/cuda/stft.py index 2f588f4a8fed..d08f41ab8912 100644 --- a/python/tvm/topi/cuda/stft.py +++ b/python/tvm/topi/cuda/stft.py @@ -201,9 +201,10 @@ def gen_ir( n_idx = base_idx + n re_output_ptr[n_idx] = tir.Cast(re_output_ptr.dtype, 0) im_output_ptr[n_idx] = tir.Cast(im_output_ptr.dtype, 0) + _w = sign * -2 * pi * n / n_fft with ib.for_range(0, n_fft) as k: k_idx = base_idx + k - w = sign * -2 * pi * k * n / n_fft + w = _w * k cos_w = tir.Cast(re_output_ptr.dtype, tir.cos(w)) sin_w = tir.Cast(re_output_ptr.dtype, tir.sin(w)) re_output_ptr[n_idx] += ( diff --git a/python/tvm/topi/stft.py b/python/tvm/topi/stft.py index 4e44c33803f5..64a804a851ab 100644 --- a/python/tvm/topi/stft.py +++ b/python/tvm/topi/stft.py @@ -181,9 +181,10 @@ def gen_ir( n_idx = base_idx + n re_output_ptr[n_idx] = tir.Cast(re_output_ptr.dtype, 0) im_output_ptr[n_idx] = tir.Cast(im_output_ptr.dtype, 0) + _w = sign * -2 * pi * n / n_fft with ib.for_range(0, n_fft) as k: k_idx = base_idx + k - w = sign * -2 * pi * k * n / n_fft + w = _w * k cos_w = tir.Cast(re_output_ptr.dtype, tir.cos(w)) sin_w = tir.Cast(re_output_ptr.dtype, tir.sin(w)) re_output_ptr[n_idx] += re_data_ptr[k_idx] * cos_w - im_data_ptr[k_idx] * sin_w From 9f0f486c5a94e91aac349edffd4b93da2ce86bba Mon Sep 17 00:00:00 2001 From: KJlaccHoeUM9l Date: Tue, 28 Feb 2023 10:24:39 +0300 Subject: [PATCH 22/23] rename: stft.py -> signal.py --- python/tvm/topi/__init__.py | 2 +- python/tvm/topi/cuda/__init__.py | 2 +- python/tvm/topi/cuda/{stft.py => signal.py} | 0 python/tvm/topi/{stft.py => signal.py} | 0 4 files changed, 2 insertions(+), 2 deletions(-) rename python/tvm/topi/cuda/{stft.py => signal.py} (100%) rename python/tvm/topi/{stft.py => signal.py} (100%) diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py index 75867136e09e..3584191b86cc 100644 --- a/python/tvm/topi/__init__.py +++ b/python/tvm/topi/__init__.py @@ -47,7 +47,7 @@ from .einsum import * from .unique import * from .searchsorted import * -from .stft import * +from .signal import * from . import generic from . import nn from . import x86 diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py index b746c95c0fc1..a6ced5bcf9bc 100644 --- a/python/tvm/topi/cuda/__init__.py +++ b/python/tvm/topi/cuda/__init__.py @@ -61,4 +61,4 @@ from .transform import * from .unique import * from .searchsorted import * -from .stft import * +from .signal import * diff --git a/python/tvm/topi/cuda/stft.py b/python/tvm/topi/cuda/signal.py similarity index 100% rename from python/tvm/topi/cuda/stft.py rename to python/tvm/topi/cuda/signal.py diff --git a/python/tvm/topi/stft.py b/python/tvm/topi/signal.py similarity index 100% rename from python/tvm/topi/stft.py rename to python/tvm/topi/signal.py From 580437735c13bb0e4e11ca9889816919b8fff0d1 Mon Sep 17 00:00:00 2001 From: KJlaccHoeUM9l Date: Tue, 28 Feb 2023 14:24:45 +0300 Subject: [PATCH 23/23] pass input_shape and output_shape to verify_dft --- tests/python/frontend/onnx/test_forward.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index f6e3984aa03e..116c023caadb 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -7941,8 +7941,9 @@ def verify_dft( _axis, _inverse, _onesided, - _input, - _dft_length=None, + _dft_length, + _input_shape, + _output_shape, ): input_names = ["input"] if _dft_length is not None: @@ -7968,19 +7969,20 @@ def verify_dft( nodes, "dft_test", inputs=[ - helper.make_tensor_value_info("input", TensorProto.FLOAT, input_shape), + helper.make_tensor_value_info("input", TensorProto.FLOAT, _input_shape), ], outputs=[ - helper.make_tensor_value_info("output", TensorProto.FLOAT, output_shape), + helper.make_tensor_value_info("output", TensorProto.FLOAT, _output_shape), ], ) model = helper.make_model(graph, producer_name="dft_test") + _input = np.random.normal(size=_input_shape).astype("float32") verify_with_ort_with_inputs( model, [_input], - [input_shape], + [_input_shape], target=target, dev=dev, rtol=1e-4, @@ -8000,8 +8002,7 @@ def verify_dft( output_shape = [batch_size] + n * [D] + [2] if onesided == 1: output_shape[axis] = output_shape[axis] // 2 + 1 - input_tensor = np.random.normal(size=input_shape).astype("float32") - verify_dft(axis, inverse, onesided, input_tensor, n_fft) + verify_dft(axis, inverse, onesided, n_fft, input_shape, output_shape) @tvm.testing.parametrize_targets